torchrl 0.11.0__cp314-cp314-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/data/map/hash.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from collections.abc import Callable
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
from torch.nn import Module
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class BinaryToDecimal(Module):
    """A Module to convert binaries encoded tensors to decimals.

    This is a utility class that allows converting a binary encoding tensor
    (e.g. `1001`) to its decimal value (e.g. `9`).

    Args:
        num_bits (int): the number of bits to use for the bases table.
            The number of bits must be lower or equal to the input length and the input length
            must be divisible by ``num_bits``. If ``num_bits`` is lower than the number of
            bits in the input, the end result will be aggregated on the last dimension using
            :func:`~torch.sum`.
        device (torch.device): the device where inputs and outputs are to be expected.
        dtype (torch.dtype): the output dtype.
        convert_to_binary (bool, optional): if ``True``, the input to the ``forward``
            method will be cast to a binary input using :func:`~torch.heaviside`.
            Defaults to ``False``.

    Examples:
        >>> binary_to_decimal = BinaryToDecimal(
        ...     num_bits=4, device="cpu", dtype=torch.int32, convert_to_binary=True
        ... )
        >>> binary = torch.Tensor([[0, 0, 1, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 10, 0]])
        >>> decimal = binary_to_decimal(binary)
        >>> assert decimal.shape == (2,)
        >>> assert (decimal == torch.Tensor([3, 2])).all()
    """

    def __init__(
        self,
        num_bits: int,
        device: torch.device,
        dtype: torch.dtype,
        convert_to_binary: bool = False,
    ):
        super().__init__()
        self.convert_to_binary = convert_to_binary
        # Powers of two, most-significant bit first: [2^(num_bits-1), ..., 2, 1].
        self.bases = 2 ** torch.arange(num_bits - 1, -1, -1, device=device, dtype=dtype)
        self.num_bits = num_bits
        # Reference tensor used as the `values` argument of torch.heaviside.
        self.zero_tensor = torch.zeros((1,), device=device)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Convert the last dimension of ``features`` to decimal values.

        Raises:
            ValueError: if the last dimension is smaller than ``num_bits`` or
                not divisible by it.
        """
        num_features = features.shape[-1]
        if self.num_bits > num_features:
            raise ValueError(f"{num_features=} is less than {self.num_bits=}")
        if num_features % self.num_bits != 0:
            raise ValueError(f"{num_features=} is not divisible by {self.num_bits=}")

        if self.convert_to_binary:
            bits = torch.heaviside(features, self.zero_tensor)
        else:
            bits = features
        # Split the trailing dimension into groups of `num_bits` bits each.
        bit_groups = bits.reshape(shape=(-1, self.num_bits))
        # Dot each group with the bases table to get its decimal value.
        group_values = torch.vmap(torch.dot, (None, 0))(
            self.bases, bit_groups.to(self.bases.dtype)
        )
        # Regroup per input row and aggregate the groups with a sum.
        group_values = group_values.reshape(
            shape=(-1, num_features // self.num_bits)
        )
        return group_values.sum(dim=-1)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class SipHash(Module):
|
|
76
|
+
"""A Module to Compute SipHash values for given tensors.
|
|
77
|
+
|
|
78
|
+
A hash function module based on SipHash implementation in python. Input tensors should have shape ``[batch_size, num_features]``
|
|
79
|
+
and the output shape will be ``[batch_size]``.
|
|
80
|
+
|
|
81
|
+
Args:
|
|
82
|
+
as_tensor (bool, optional): if ``True``, the bytes will be turned into integers
|
|
83
|
+
through the builtin ``hash`` function and mapped to a tensor. Default: ``True``.
|
|
84
|
+
|
|
85
|
+
.. warning:: This module relies on the builtin ``hash`` function.
|
|
86
|
+
To get reproducible results across runs, the ``PYTHONHASHSEED`` environment
|
|
87
|
+
variable must be set before the code is run (changing this value during code
|
|
88
|
+
execution is without effect).
|
|
89
|
+
|
|
90
|
+
Examples:
|
|
91
|
+
>>> # Assuming we set PYTHONHASHSEED=0 prior to running this code
|
|
92
|
+
>>> a = torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
|
|
93
|
+
>>> b = a.clone()
|
|
94
|
+
>>> hash_module = SipHash(as_tensor=True)
|
|
95
|
+
>>> hash_a = hash_module(a)
|
|
96
|
+
>>> hash_a
|
|
97
|
+
tensor([-4669941682990263259, -3778166555168484291, -9122128731510687521])
|
|
98
|
+
>>> hash_b = hash_module(b)
|
|
99
|
+
>>> assert (hash_a == hash_b).all()
|
|
100
|
+
"""
|
|
101
|
+
|
|
102
|
+
def __init__(self, as_tensor: bool = True):
|
|
103
|
+
super().__init__()
|
|
104
|
+
self.as_tensor = as_tensor
|
|
105
|
+
|
|
106
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor | list[bytes]:
|
|
107
|
+
hash_values = []
|
|
108
|
+
if x.dtype in (torch.bfloat16,):
|
|
109
|
+
x = x.to(torch.float16)
|
|
110
|
+
for x_i in x.detach().cpu().numpy():
|
|
111
|
+
hash_value = x_i.tobytes()
|
|
112
|
+
hash_values.append(hash_value)
|
|
113
|
+
if not self.as_tensor:
|
|
114
|
+
return hash_values
|
|
115
|
+
result = torch.tensor([hash(x) for x in hash_values], dtype=torch.int64)
|
|
116
|
+
return result
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class RandomProjectionHash(SipHash):
    """A module that combines random projections with SipHash to get a low-dimensional tensor, easier to embed through :class:`~.SipHash`.

    Keyword Args:
        n_components (int, optional): the low-dimensional number of components of the projections.
            Defaults to 16.
        dtype_cast (torch.dtype, optional): the dtype to cast the projection to.
            Defaults to ``torch.bfloat16``.
        as_tensor (bool, optional): if ``True``, the bytes will be turned into integers
            through the builtin ``hash`` function and mapped to a tensor. Default: ``True``.
        init_method (Callable[[torch.Tensor], torch.Tensor | None], optional): the
            in-place initializer used to fill the projection matrix once it is
            materialized by :meth:`fit`. Defaults to :func:`torch.nn.init.normal_`.

    .. warning:: This module relies on the builtin ``hash`` function.
        To get reproducible results across runs, the ``PYTHONHASHSEED`` environment
        variable must be set before the code is run (changing this value during code
        execution is without effect).
    """

    _N_COMPONENTS_DEFAULT = 16

    def __init__(
        self,
        *,
        n_components: int | None = None,
        dtype_cast=torch.bfloat16,
        as_tensor: bool = True,
        init_method: Callable[[torch.Tensor], torch.Tensor | None] | None = None,
        # NOTE(review): extra kwargs are accepted but silently ignored - confirm intended.
        **kwargs,
    ):
        if n_components is None:
            n_components = self._N_COMPONENTS_DEFAULT

        super().__init__(as_tensor=as_tensor)
        # Stored as a buffer so it is saved/restored with the state dict.
        self.register_buffer("_n_components", torch.as_tensor(n_components))

        self._init = False
        if init_method is None:
            init_method = torch.nn.init.normal_
        self.init_method = init_method

        self.dtype_cast = dtype_cast
        # Lazily materialized in `fit` once the input feature size is known.
        self.register_buffer("transform", torch.nn.UninitializedBuffer())

    @property
    def n_components(self):
        """The number of output components of the random projection."""
        return self._n_components.item()

    def fit(self, x):
        """Fits the random projection to the input data.

        Materializes the ``transform`` buffer with shape
        ``(x.shape[-1], n_components)`` and fills it with ``init_method``.
        """
        self.transform.materialize(
            (x.shape[-1], self.n_components), dtype=self.dtype_cast, device=x.device
        )
        self.init_method(self.transform)
        self._init = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Projects ``x`` to a low-dimensional space and hashes it via SipHash."""
        # Lazily fit on the first call. (A previous version had an unreachable
        # `elif not self._init: raise RuntimeError(...)` branch here - same
        # condition as the `if` above it - which has been removed.)
        if not self._init:
            self.fit(x)
        x = x.to(self.dtype_cast) @ self.transform
        return super().forward(x)
|
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from collections.abc import Callable, Mapping
|
|
8
|
+
|
|
9
|
+
from copy import deepcopy
|
|
10
|
+
from typing import Any, TypeVar
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
import torch.nn as nn
|
|
14
|
+
from tensordict import NestedKey, TensorDictBase
|
|
15
|
+
from tensordict.nn.common import TensorDictModuleBase
|
|
16
|
+
from torchrl._utils import logger as torchrl_logger
|
|
17
|
+
from torchrl.data.map.hash import SipHash
|
|
18
|
+
|
|
19
|
+
K = TypeVar("K")
|
|
20
|
+
V = TypeVar("V")
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class HashToInt(nn.Module):
    """Converts a hash value to an integer that can be used for indexing a contiguous storage."""

    def __init__(self):
        super().__init__()
        # Maps hash value -> dense index, in first-seen order.
        self._index_to_index: dict[int, int] = {}

    def __call__(self, key: torch.Tensor, extend: bool = False) -> torch.Tensor:
        """Map a tensor of hash values to dense integer indices.

        Args:
            key (torch.Tensor): tensor of hash values.
            extend (bool, optional): if ``True``, unseen hashes are registered
                and receive a new index. If ``False``, unseen hashes map to the
                current size of the table without being registered.
                Defaults to ``False``.
        """
        result = []
        if extend:
            for _item in key.tolist():
                result.append(
                    self._index_to_index.setdefault(_item, len(self._index_to_index))
                )
        else:
            for _item in key.tolist():
                result.append(
                    self._index_to_index.get(_item, len(self._index_to_index))
                )
        return torch.tensor(result, device=key.device, dtype=key.dtype)

    def state_dict(self) -> dict[str, torch.Tensor]:
        # Bug fix: torch.tensor() cannot consume dict views (``.keys()`` /
        # ``.values()``) directly - it raised a TypeError. Convert to lists first.
        values = torch.tensor(list(self._index_to_index.values()))
        keys = torch.tensor(list(self._index_to_index.keys()))
        return {"keys": keys, "values": values}

    def load_state_dict(
        self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False
    ):
        # Rebuild the lookup table from the saved key/value tensors.
        keys = state_dict["keys"]
        values = state_dict["values"]
        self._index_to_index = {
            key: val for key, val in zip(keys.tolist(), values.tolist())
        }
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class QueryModule(TensorDictModuleBase):
    """A Module to generate compatible indices for storage.

    A module that queries a storage and return required index of that storage.
    Currently, it only outputs integer indices (torch.int64).

    Args:
        in_keys (list of NestedKeys): keys of the input tensordict that
            will be used to generate the hash value.
        index_key (NestedKey): the output key where the index value will be written.
            Defaults to ``"_index"``.

    Keyword Args:
        hash_key (NestedKey): the output key where the hash value will be written.
            Defaults to ``"_hash"``.
        hash_module (Callable[[Any], int] or a list of these, optional): a hash
            module similar to :class:`~tensordict.nn.SipHash` (default).
            If a list of callables is provided, its length must equate the number of in_keys.
        hash_to_int (Callable[[int], int], optional): a stateful function that
            maps a hash value to a non-negative integer corresponding to an index in a
            storage. Defaults to :class:`~torchrl.data.map.HashToInt`.
        aggregator (Callable[[int], int], optional): a hash function to group multiple hashes
            together. This argument should only be passed when there is more than one ``in_keys``.
            If a single ``hash_module`` is provided but no aggregator is passed, it will take
            the value of the hash_module. If no ``hash_module`` or a list of ``hash_modules`` is
            provided but no aggregator is passed, it will default to ``SipHash``.
        clone (bool, optional): if ``True``, a shallow clone of the input TensorDict will be
            returned. This can be used to retrieve the integer index within the storage,
            corresponding to a given input tensordict. This can be overridden at runtime by
            providing the ``clone`` argument to the forward method.
            Defaults to ``False``.

    Examples:
        >>> query_module = QueryModule(
        ...     in_keys=["key1", "key2"],
        ...     index_key="index",
        ...     hash_module=SipHash(),
        ... )
        >>> query = TensorDict(
        ...     {
        ...         "key1": torch.Tensor([[1], [1], [1], [2]]),
        ...         "key2": torch.Tensor([[3], [3], [2], [3]]),
        ...         "other": torch.randn(4),
        ...     },
        ...     batch_size=(4,),
        ... )
        >>> res = query_module(query)
        >>> # The first two pairs of key1 and key2 match
        >>> assert res["index"][0] == res["index"][1]
        >>> # The last three pairs of key1 and key2 have at least one mismatching value
        >>> assert res["index"][1] != res["index"][2]
        >>> assert res["index"][2] != res["index"][3]

    """

    def __init__(
        self,
        in_keys: list[NestedKey],
        index_key: NestedKey = "_index",
        hash_key: NestedKey = "_hash",
        *,
        hash_module: Callable[[Any], int] | list[Callable[[Any], int]] | None = None,
        hash_to_int: Callable[[int], int] | None = None,
        aggregator: Callable[[Any], int] | None = None,
        clone: bool = False,
    ):
        # NOTE(review): the emptiness check runs before the single-key wrapping
        # below, so a non-list in_keys is measured with len() here - confirm
        # intended for non-list NestedKey inputs.
        if len(in_keys) == 0:
            raise ValueError("`in_keys` cannot be empty.")
        in_keys = in_keys if isinstance(in_keys, list) else [in_keys]

        super().__init__()
        in_keys = self.in_keys = in_keys
        # out_keys[0] is the storage index, out_keys[1] the raw hash.
        self.out_keys = [index_key, hash_key]
        index_key = self.out_keys[0]
        self.hash_key = self.out_keys[1]

        # Resolve the aggregator: with a single in_key no aggregation is
        # needed, so a user-provided aggregator is ignored (with a warning).
        if aggregator is not None and len(self.in_keys) == 1:
            torchrl_logger.warn(
                "An aggregator was provided but there is only one in-key to be read. "
                "This module will be ignored."
            )
        elif aggregator is None:
            # A single (non-list) hash_module doubles as the aggregator;
            # otherwise fall back to SipHash.
            if hash_module is not None and not isinstance(hash_module, list):
                aggregator = hash_module
            else:
                aggregator = SipHash()
        # Resolve the per-key hash modules: one module per in_key.
        if hash_module is None:
            hash_module = [SipHash() for _ in range(len(self.in_keys))]
        elif not isinstance(hash_module, list):
            # Replicate a single module across keys; deepcopy so stateful
            # modules do not share state between keys (skipped when there is
            # only one key).
            try:
                hash_module = [
                    deepcopy(hash_module) if len(self.in_keys) > 1 else hash_module
                    for _ in range(len(self.in_keys))
                ]
            except Exception as err:
                raise RuntimeError(
                    "failed to deepcopy the hash module. Please provide a list of hash modules instead."
                ) from err
        elif len(hash_module) != len(self.in_keys):
            raise ValueError(
                "The number of hash_modules must match the number of in_keys. "
                f"Got {len(hash_module)} hash modules but {len(in_keys)} in_keys."
            )
        if hash_to_int is None:
            hash_to_int = HashToInt()

        self.aggregator = aggregator
        # Keyed by in_key so forward can look up the right module per entry.
        self.hash_module = dict(zip(self.in_keys, hash_module))
        self.hash_to_int = hash_to_int

        self.index_key = index_key
        self.clone = clone

    def forward(
        self,
        tensordict: TensorDictBase,
        *,
        extend: bool = True,
        write_hash: bool = True,
        clone: bool | None = None,
    ) -> TensorDictBase:
        """Compute hash and index values for the entries of ``tensordict``.

        Keyword Args:
            extend (bool, optional): forwarded to ``hash_to_int``; when ``True``
                unseen hashes are registered in the index table. Defaults to ``True``.
            write_hash (bool, optional): if ``True``, the raw hash values are
                written under ``hash_key``. Defaults to ``True``.
            clone (bool, optional): overrides the constructor-level ``clone``
                setting when not ``None``. Defaults to ``None``.
        """
        hash_values = []

        # One hash per in_key, each computed by its dedicated hash module.
        for k in self.in_keys:
            hash_values.append(self.hash_module[k](tensordict.get(k)))
        if len(self.in_keys) > 1:
            # Stack the per-key hashes on the last dim and reduce them to a
            # single hash per batch element via the aggregator.
            hash_values = torch.stack(
                hash_values,
                dim=-1,
            )
            hash_values = self.aggregator(hash_values)
        else:
            hash_values = hash_values[0]

        # Convert hashes to contiguous storage indices.
        td_hash_value = self.hash_to_int(hash_values, extend=extend)

        clone = clone if clone is not None else self.clone
        if clone:
            output = tensordict.copy()
        else:
            # Write in place on the input tensordict.
            output = tensordict

        output.set(self.index_key, td_hash_value)
        if write_hash:
            output.set(self.hash_key, hash_values)
        return output
|