torchrl 0.11.0__cp314-cp314-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/objectives/value/utils.py
@@ -0,0 +1,360 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import torch
+
+from tensordict import TensorDictBase
+from tensordict.utils import expand_right
+
+
+def _custom_conv1d(tensor: torch.Tensor, filter: torch.Tensor):
+    """Computes a conv1d filter over a value.
+
+    This is usually used to compute a discounted return:
+
+        Tensor:            Filter:         Result (discounted return):
+        [ r_0,             [ 1.0,          [ r_0 + g r_1 + g^2 r_2 + g^3 r_3,
+          r_1,               g,              r_1 + g r_2 + g^2 r_3,
+          r_2,               g^2,            r_2 + g r_3,
+          r_3,               g^3 ]           r_3 ]
+          0,    |                 |
+          0,    | zero padding    | direction of filter
+          0 ]   |                 v
+
+    This function takes care of applying the one-sided zero padding. In this example,
+    `Filter_dim` = :obj:`Time` = 4, but in practice Filter_dim can be <= :obj:`Time`.
+
+    Args:
+        tensor (torch.Tensor): a [ Batch x 1 x Time ] floating-point tensor
+        filter (torch.Tensor): a [ Filter_dim x 1 ] floating-point filter
+
+    Returns: a filtered tensor of the same shape as the input tensor.
+
+    """
+    if filter.ndimension() > 2:
+        # filter will have shape batch_dims x timesteps x filter_dim x 1
+        # reshape to batch_dims x timesteps x 1 x filter_dim ready for convolving
+        filter = filter.view(*filter.shape[:-2], 1, filter.shape[-2])
+
+        # because time is represented on two different dimensions, we don't
+        # need all convolutions, just those lying along a diagonal.
+        # rather than compute them all and discard, we stack just the slices
+        # of val_pad that we care about, and apply the filter manually
+
+        # STACK VERSION: val_pad is computed as in the block below
+        # batched_val_pad = torch.stack(
+        #     [val_pad[..., i : i + filter.shape[-1]] for i in range(tensor.shape[-1])],
+        #     dim=1,
+        # )
+
+        # roll version
+        T = tensor.shape[-1]
+        device = tensor.device
+        batched_val_pad = (
+            roll_by_gather(
+                tensor.expand(tensor.shape[0], filter.shape[-1], T).transpose(-2, -1),
+                0,
+                -torch.arange(filter.shape[-1], device=device),
+            )
+            .flip(-1)
+            .triu(filter.shape[-1] - T)
+            .flip(-1)
+            .unsqueeze(-2)
+        )
+
+        # this is just a batched matrix multiplication, but einsum makes it
+        # easy to keep the many dimensions under control. Here b = batch,
+        # t = timestep, s = singleton, j is the filter dimension that should
+        # get summed out. we swap the order of s and t here rather than
+        # reshape / create a view later.
+        # this is essentially identical to (batched_val_pad @ filter.transpose(-2, -1)).squeeze().unsqueeze(-2)
+        # out = (batched_val_pad @ filter.transpose(-2, -1)).squeeze().unsqueeze(-2)
+        out = torch.einsum("btsj,btsj->bst", batched_val_pad, filter)
+    else:
+        val_pad = torch.nn.functional.pad(tensor, [0, filter.shape[-2] - 1])
+
+        # shape = val.shape
+        filter = filter.squeeze(-1).unsqueeze(0).unsqueeze(0)  # 1 x 1 x T
+        out = torch.conv1d(val_pad, filter)
+    # out = out.view(shape)
+    if out.shape != tensor.shape:
+        raise RuntimeError(
+            f"wrong output shape: input shape: {tensor.shape}, output shape: {out.shape}"
+        )
+    return out
+
+
+def roll_by_gather(mat: torch.Tensor, dim: int, shifts: torch.LongTensor):
+    """Rolls a batched matrix along the last or last-but-one dimension.
+
+    Args:
+        mat (torch.Tensor): A batched matrix to roll
+        dim (int): 0 or -2 indicates the last-but-one dimension,
+            1 or -1 the last dimension.
+        shifts (torch.LongTensor): A tensor containing the shifts. Must have the same number of
+            elements as the unchosen dimension.
+
+    Examples:
+        >>> x = torch.arange(12).view(3, 4)
+        >>> roll_by_gather(x, 0, -torch.arange(4))  # shifts the values in each column
+        tensor([[ 0,  5, 10,  3],
+                [ 4,  9,  2,  7],
+                [ 8,  1,  6, 11]])
+        >>> roll_by_gather(x, 1, -torch.arange(3))  # shifts the values in each row
+        tensor([[ 0,  1,  2,  3],
+                [ 5,  6,  7,  4],
+                [10, 11,  8,  9]])
+
+    """
+    # assumes 2D array
+    *batch, n_rows, n_cols = mat.shape
+    device = mat.device
+
+    if dim in (0, -2):
+        arange1 = (
+            torch.arange(n_rows, device=device).unsqueeze(-1).expand((n_rows, n_cols))
+        )
+        arange2 = (arange1 - shifts) % n_rows
+        return torch.gather(mat, -2, arange2.expand(*batch, *arange2.shape))
+    elif dim in (1, -1):
+        arange1 = torch.arange(n_cols, device=device).expand((n_rows, n_cols))
+        arange2 = (arange1 - shifts.unsqueeze(-1)) % n_cols
+        return torch.gather(mat, -1, arange2.expand(*batch, n_rows, n_cols))
+    else:
+        raise NotImplementedError(f"dim {dim} is not supported.")
+
+
+def _make_gammas_tensor(gamma: torch.Tensor, T: int, rolling_gamma: bool):
+    """Prepares a decay tensor for a matrix multiplication.
+
+    Given a tensor gamma of size [*batch, T, D],
+    it will return a new tensor with size [*batch, T, T+1, D].
+    In the rolling_gamma case, a rolling of the gamma values will be performed
+    along the T axis, e.g.:
+    [[ 1, g1, g2, g3],
+     [ 1, g2, g3,  0],
+     [ 1, g3,  0,  0]]
+
+    Args:
+        gamma (torch.Tensor): the gamma tensor to be prepared.
+        T (int): the time length
+        rolling_gamma (bool): if ``True``, the gamma value is set for each step
+            independently. If ``False``, the gamma value at (i, t) will be used for the
+            trajectory following (i, t).
+
+    Returns: the prepared gamma decay tensor
+
+    """
+    # some reshaping code vendored from vec_td_lambda_return_estimate
+    gamma = gamma.transpose(-2, -1).contiguous()
+    gamma = gamma.view(-1, T)
+    dtype = gamma.dtype
+    device = gamma.device
+    if rolling_gamma:
+        # # loop
+        # gammas = gamma.unsqueeze(-2).expand(gamma.shape[0], T, T).contiguous()
+        # for i in range(1, T):
+        #     s = gammas[:, i].clone()
+        #     gammas[:, i] = 0
+        #     gammas[:, i, :-i] = s[:, i:]
+        # gammas = torch.cumprod(gammas.unsqueeze(-1), -2)
+        # gammas_cont = torch.ones(gammas.shape[0], T, T, 1)
+        # gammas_cont[..., 1:, :] = gammas[..., :-1, :]
+        # gammas = gammas_cont
+
+        # vectorized version
+        gammas = torch.ones(gamma.shape[0], T, T + 1, 1, dtype=dtype, device=device)
+        s0 = gamma.unsqueeze(-1).expand(gamma.shape[0], T, T)
+        s1 = roll_by_gather(s0, 0, shifts=-torch.arange(T, device=device))
+
+        # we should triu here, but it's useless since there is a triu on the values
+        # happening in _custom_conv1d
+        # s2 = s1.flip(-1).triu().flip(-1).transpose(-2, -1)
+        s2 = s1.transpose(-2, -1)
+        gammas[..., 1:, :] = s2.unsqueeze(-1)
+    else:
+        gammas = torch.ones(*gamma.shape, T + 1, 1, device=device, dtype=dtype)
+        gammas[..., 1:, :] = gamma[..., None, None]
+    return gammas
+
+
+def _flatten_batch(tensor, time_dim=-1):
+    """Because we mark the end of each batch with a truncated signal, we can concatenate them.
+
+    Args:
+        tensor (torch.Tensor): a tensor of shape [*B, T, *F]
+        time_dim (int, optional): the time dimension T. Defaults to -1.
+
+    """
+    return tensor.flatten(0, time_dim)
+
+
+def _get_num_per_traj(done):
+    """Counts the number of steps in each trajectory, marking the end of each batch as truncated.
+
+    Args:
+        done (torch.Tensor): A done or truncated mark of shape [*B, T]
+
+    Returns:
+        A list of integers representing the number of steps in each trajectory
+
+    """
+    done = done.clone()
+    done[..., -1] = True
+    # TODO: find a way of copying once only, eg not using reshape
+    num_per_traj = torch.where(done.reshape(-1))[0] + 1
+    num_per_traj[1:] = num_per_traj[1:] - num_per_traj[:-1]
+    return num_per_traj
+
+
+def _split_and_pad_sequence(
+    tensor: torch.Tensor | TensorDictBase,
+    splits: torch.Tensor,
+    return_mask=False,
+    time_dim=-1,
+):
+    """Given a tensor of size [*B, T, F] and the corresponding traj lengths (flattened), returns the padded trajectories [NPad, Tmax, *other].
+
+    Compatible with tensordict inputs.
+
+    Examples:
+        >>> from tensordict import TensorDict
+        >>> is_init = torch.zeros(4, 5, dtype=torch.bool)
+        >>> is_init[:, 0] = True
+        >>> is_init[0, 3] = True
+        >>> is_init[1, 2] = True
+        >>> tensordict = TensorDict({
+        ...     "is_init": is_init,
+        ...     "obs": torch.arange(20).view(4, 5).unsqueeze(-1).expand(4, 5, 3),
+        ... }, [4, 5])
+        >>> splits = _get_num_per_traj_init(is_init)
+        >>> print(splits)
+        tensor([3, 2, 2, 3, 5, 5])
+        >>> td = _split_and_pad_sequence(tensordict, splits)
+        >>> print(td)
+        TensorDict(
+            fields={
+                is_init: Tensor(shape=torch.Size([6, 5]), device=cpu, dtype=torch.bool, is_shared=False),
+                obs: Tensor(shape=torch.Size([6, 5, 3]), device=cpu, dtype=torch.int64, is_shared=False)},
+            batch_size=torch.Size([6, 5]),
+            device=None,
+            is_shared=False)
+        >>> print(td["obs"])
+        tensor([[[ 0,  0,  0],
+                 [ 1,  1,  1],
+                 [ 2,  2,  2],
+                 [ 0,  0,  0],
+                 [ 0,  0,  0]],
+        <BLANKLINE>
+                [[ 3,  3,  3],
+                 [ 4,  4,  4],
+                 [ 0,  0,  0],
+                 [ 0,  0,  0],
+                 [ 0,  0,  0]],
+        <BLANKLINE>
+                [[ 5,  5,  5],
+                 [ 6,  6,  6],
+                 [ 0,  0,  0],
+                 [ 0,  0,  0],
+                 [ 0,  0,  0]],
+        <BLANKLINE>
+                [[ 7,  7,  7],
+                 [ 8,  8,  8],
+                 [ 9,  9,  9],
+                 [ 0,  0,  0],
+                 [ 0,  0,  0]],
+        <BLANKLINE>
+                [[10, 10, 10],
+                 [11, 11, 11],
+                 [12, 12, 12],
+                 [13, 13, 13],
+                 [14, 14, 14]],
+        <BLANKLINE>
+                [[15, 15, 15],
+                 [16, 16, 16],
+                 [17, 17, 17],
+                 [18, 18, 18],
+                 [19, 19, 19]]])
+
+    """
+    max_seq_len = torch.max(splits)
+    shape = (len(splits), max_seq_len)
+
+    # int16 supports length up to 32767
+    dtype = (
+        torch.int16
+        if tensor.size(time_dim) < torch.iinfo(torch.int16).max
+        else torch.int32
+    )
+    arange = torch.arange(max_seq_len, device=tensor.device, dtype=dtype).unsqueeze(0)
+    mask = arange < splits.unsqueeze(1)
+
+    tensor = _flatten_batch(tensor, time_dim=time_dim)
+
+    def _fill_tensor(tensor):
+        empty_tensor = torch.zeros(
+            *shape,
+            *tensor.shape[1:],
+            dtype=tensor.dtype,
+            device=tensor.device,
+        )
+        mask_expand = expand_right(mask, (*mask.shape, *tensor.shape[1:]))
+        # We need to use masked-scatter to accommodate vmap
+        return torch.masked_scatter(empty_tensor, mask_expand, tensor.reshape(-1))
+        # empty_tensor[mask_expand] = tensor.reshape(-1)
+        # return empty_tensor
+
+    if isinstance(tensor, TensorDictBase):
+        tensor = tensor.apply(_fill_tensor, batch_size=list(shape))
+    else:
+        tensor = _fill_tensor(tensor)
+    if return_mask:
+        return tensor, mask
+    return tensor
+
+
+def _inv_pad_sequence(
+    tensor: torch.Tensor | TensorDictBase,
+    splits: torch.Tensor,
+    mask: torch.Tensor = None,
+):
+    """Inverts a pad_sequence operation.
+
+    If tensor is of shape [B, T], then splits must be of shape [B] with all elements
+    being integers in [1, T].
+    The result will be flattened along the batch dimension(s) and must be reshaped into
+    the original shape (if necessary).
+
+    Examples:
+        >>> rewards = torch.randn(100, 20)
+        >>> num_per_traj = _get_num_per_traj(torch.zeros(100, 20).bernoulli_(0.1))
+        >>> padded = _split_and_pad_sequence(rewards, num_per_traj)
+        >>> reconstructed = _inv_pad_sequence(padded, num_per_traj)
+        >>> assert (reconstructed == rewards).all()
+    """
+    if splits.numel() == 1:
+        return tensor
+
+    if mask is None:
+        # int16 supports length up to 32767
+        dtype = (
+            torch.int16
+            if tensor.shape[-1] < torch.iinfo(torch.int16).max
+            else torch.int32
+        )
+        arange = torch.arange(
+            tensor.shape[-1], device=tensor.device, dtype=dtype
+        ).unsqueeze(0)
+        mask = arange < splits.unsqueeze(1)
+
+    return tensor[mask]
+
+
+def _get_num_per_traj_init(is_init):
+    """Like _get_num_per_traj, but with an is_init signal."""
+    done = torch.zeros_like(is_init)
+    done[..., :-1][is_init[..., 1:]] = 1
+    return _get_num_per_traj(done)
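The helpers above make up TorchRL's vectorized discounted-return machinery. A minimal sketch of the conv1d trick described in the `_custom_conv1d` docstring, assuming the private import path from this wheel's layout (these are internal helpers, so the interface may change):

```python
import torch

from torchrl.objectives.value.utils import _custom_conv1d

g = 0.9
rewards = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])  # [Batch=1, 1, Time=4]
# Column filter [g^0, g^1, g^2, g^3] of shape [Filter_dim=4, 1]
filt = (g ** torch.arange(4, dtype=torch.float32)).unsqueeze(-1)

returns = _custom_conv1d(rewards, filt)
# returns[0, 0, 0] == 1 + g*2 + g**2*3 + g**3*4  (full discounted return)
# returns[0, 0, 3] == 4.0                        (last step, no future rewards left)
print(returns)
```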
torchrl/record/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .loggers import CSVLogger, MLFlowLogger, TensorboardLogger, WandbLogger
+from .recorder import PixelRenderTransform, TensorDictRecorder, VideoRecorder
+
+__all__ = [
+    "CSVLogger",
+    "MLFlowLogger",
+    "TensorboardLogger",
+    "WandbLogger",
+    "PixelRenderTransform",
+    "TensorDictRecorder",
+    "VideoRecorder",
+]
torchrl/record/loggers/__init__.py
@@ -0,0 +1,23 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .common import Logger
+
+from .csv import CSVLogger
+from .mlflow import MLFlowLogger
+from .tensorboard import TensorboardLogger
+from .utils import generate_exp_name, get_logger
+
+from .wandb import WandbLogger
+
+__all__ = [
+    "Logger",
+    "CSVLogger",
+    "MLFlowLogger",
+    "TensorboardLogger",
+    "generate_exp_name",
+    "get_logger",
+    "WandbLogger",
+]
torchrl/record/loggers/common.py
@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import abc
+from collections.abc import Sequence
+
+from torch import Tensor
+
+
+__all__ = ["Logger"]
+
+
+class Logger:
+    """A template for loggers."""
+
+    def __init__(self, exp_name: str, log_dir: str) -> None:
+        self.exp_name = exp_name
+        self.log_dir = log_dir
+        self.experiment = self._create_experiment()
+
+    @abc.abstractmethod
+    def _create_experiment(self) -> Experiment:  # noqa: F821
+        ...
+
+    @abc.abstractmethod
+    def log_scalar(self, name: str, value: float, step: int | None = None) -> None:
+        ...
+
+    @abc.abstractmethod
+    def log_video(
+        self, name: str, video: Tensor, step: int | None = None, **kwargs
+    ) -> None:
+        ...
+
+    @abc.abstractmethod
+    def log_hparams(self, cfg: DictConfig | dict) -> None:  # noqa: F821
+        ...
+
+    @abc.abstractmethod
+    def __repr__(self) -> str:
+        ...
+
+    @abc.abstractmethod
+    def log_histogram(self, name: str, data: Sequence, **kwargs):
+        ...
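`Logger` above is an abstract template: concrete backends supply `_create_experiment` plus the `log_*` methods. A minimal, hypothetical subclass sketch assuming only the interface shown above (the print-based backend is illustrative and not part of torchrl):

```python
from torchrl.record.loggers.common import Logger


class PrintLogger(Logger):
    """A toy backend that prints instead of persisting (illustrative only)."""

    def _create_experiment(self):
        # No backend handle is needed for printing.
        return None

    def log_scalar(self, name, value, step=None):
        print(f"[{step}] {name} = {value}")

    def log_video(self, name, video, step=None, **kwargs):
        raise NotImplementedError("PrintLogger does not store videos.")

    def log_hparams(self, cfg):
        for key, val in sorted(dict(cfg).items()):
            print(f"hparam {key} = {val}")

    def log_histogram(self, name, data, **kwargs):
        raise NotImplementedError("PrintLogger does not store histograms.")

    def __repr__(self):
        return f"PrintLogger(exp_name={self.exp_name})"


logger = PrintLogger(exp_name="demo", log_dir="/tmp/demo")
logger.log_scalar("reward", 1.0, step=0)
```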
torchrl/record/loggers/csv.py
@@ -0,0 +1,226 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import os
+from collections import defaultdict
+from collections.abc import Sequence
+from pathlib import Path
+
+import tensordict.utils
+import torch
+from tensordict import MemoryMappedTensor
+from torch import Tensor
+
+from .common import Logger
+
+
+class CSVExperiment:
+    """A CSV logger experiment class."""
+
+    def __init__(self, log_dir: str, *, video_format="pt", video_fps: int = 30):
+        self.scalars = defaultdict(list)
+        self.videos_counter = defaultdict(int)
+        self.text_counter = defaultdict(int)
+        self.log_dir = log_dir
+        self.video_format = video_format
+        self.video_fps = video_fps
+        os.makedirs(self.log_dir, exist_ok=True)
+        os.makedirs(os.path.join(self.log_dir, "scalars"), exist_ok=True)
+        os.makedirs(os.path.join(self.log_dir, "videos"), exist_ok=True)
+        os.makedirs(os.path.join(self.log_dir, "texts"), exist_ok=True)
+
+        self.files = {}
+
+    def add_scalar(self, name: str, value: float, global_step: int | None = None):
+        if global_step is None:
+            global_step = len(self.scalars[name])
+        value = float(value)
+        self.scalars[name].append((global_step, value))
+        filepath = os.path.join(self.log_dir, "scalars", "".join([name, ".csv"]))
+        if not os.path.isfile(filepath):
+            os.makedirs(Path(filepath).parent, exist_ok=True)
+        if filepath not in self.files:
+            os.makedirs(Path(filepath).parent, exist_ok=True)
+            self.files[filepath] = open(filepath, "a+")
+        fd = self.files[filepath]
+        fd.write(",".join([str(global_step), str(value)]) + "\n")
+        fd.flush()
+
+    def add_video(self, tag, vid_tensor, global_step: int | None = None, **kwargs):
+        """Writes a video to a file on disk.
+
+        The video format can be one of
+
+        - `"pt"`: uses :func:`~torch.save` to save the video tensor;
+        - `"memmap"`: saves the file as a memory-mapped array (reading this file will require
+          the dtype and shape to be known at read time);
+        - `"mp4"`: saves the file as an `.mp4` file using the torchvision :func:`~torchvision.io.write_video`
+          API. Any ``kwargs`` passed to ``add_video`` will be transmitted to ``write_video``.
+          These include ``preset``, ``crf`` and others.
+          See ffmpeg's doc (https://trac.ffmpeg.org/wiki/Encode/H.264) for more information on the video format options.
+
+        """
+        if global_step is None:
+            global_step = self.videos_counter[tag]
+            self.videos_counter[tag] += 1
+        if self.video_format == "pt":
+            extension = ".pt"
+        elif self.video_format == "memmap":
+            extension = ".memmap"
+        elif self.video_format == "mp4":
+            extension = ".mp4"
+        else:
+            raise ValueError(
+                f"Unknown video format {self.video_format}. Must be one of 'pt', 'memmap' or 'mp4'."
+            )
+
+        filepath = os.path.join(
+            self.log_dir, "videos", "_".join([tag, str(global_step)]) + extension
+        )
+        path_to_create = Path(str(filepath)).parent
+        os.makedirs(path_to_create, exist_ok=True)
+        if self.video_format == "pt":
+            torch.save(vid_tensor, filepath)
+        elif self.video_format == "memmap":
+            MemoryMappedTensor.from_tensor(vid_tensor, filename=filepath)
+        elif self.video_format == "mp4":
+            import torchvision
+
+            if vid_tensor.shape[-3] not in (3, 1):
+                raise RuntimeError(
+                    "expected the video tensor to be of format [T, C, H, W] but the third channel "
+                    f"starting from the end isn't in (1, 3) but is {vid_tensor.shape[-3]}."
+                )
+            if vid_tensor.ndim > 4:
+                vid_tensor = vid_tensor.flatten(0, vid_tensor.ndim - 4)
+            vid_tensor = vid_tensor.permute((0, 2, 3, 1))
+            vid_tensor = vid_tensor.expand(*vid_tensor.shape[:-1], 3)
+            kwargs.setdefault("fps", self.video_fps)
+            torchvision.io.write_video(filepath, vid_tensor, **kwargs)
+        else:
+            raise ValueError(
+                f"Unknown video format {self.video_format}. Must be one of 'pt', 'memmap' or 'mp4'."
+            )
+
+    def add_text(self, tag, text, global_step: int | None = None):
+        if global_step is None:
+            global_step = self.text_counter[tag]
+            self.text_counter[tag] += 1
+        filepath = os.path.join(
+            self.log_dir, "texts", "".join([tag, str(global_step)]) + ".txt"
+        )
+        if not os.path.isfile(filepath):
+            os.makedirs(Path(filepath).parent, exist_ok=True)
+        if filepath not in self.files:
+            self.files[filepath] = open(filepath, "w+")
+        fd = self.files[filepath]
+        fd.writelines(text)
+        fd.flush()
+
+    def __repr__(self) -> str:
+        return f"CSVExperiment(log_dir={self.log_dir})"
+
+    def __del__(self):
+        for val in getattr(self, "files", {}).values():
+            val.close()
+
+
+class CSVLogger(Logger):
+    """A minimal-dependency CSV logger.
+
+    Args:
+        exp_name (str): The name of the experiment.
+        log_dir (str or Path, optional): where the experiment should be saved.
+            Defaults to ``<cur_dir>/csv_logs``.
+        video_format (str, optional): how videos should be saved when calling :meth:`~torchrl.record.loggers.csv.CSVExperiment.add_video`. Must be one of
+            ``"pt"`` (video saved as a `video_<tag>_<step>.pt` file with torch.save),
+            ``"memmap"`` (video saved as a `video_<tag>_<step>.memmap` file with :class:`~tensordict.MemoryMappedTensor`),
+            ``"mp4"`` (video saved as a `video_<tag>_<step>.mp4` file, requires torchvision to be installed).
+            Defaults to ``"pt"``.
+        video_fps (int, optional): the video frames per second if `video_format="mp4"`. Defaults to 30.
+
+    """
+
+    experiment: CSVExperiment
+
+    def __init__(
+        self,
+        exp_name: str,
+        log_dir: str | None = None,
+        video_format: str = "pt",
+        video_fps: int = 30,
+    ) -> None:
+        if log_dir is None:
+            log_dir = "csv_logs"
+        self.video_format = video_format
+        self.video_fps = video_fps
+        super().__init__(exp_name=exp_name, log_dir=log_dir)
+        self._has_imported_moviepy = False
+
+    def _create_experiment(self) -> CSVExperiment:
+        """Creates a CSV experiment."""
+        log_dir = str(os.path.join(self.log_dir, self.exp_name))
+        return CSVExperiment(
+            log_dir, video_format=self.video_format, video_fps=self.video_fps
+        )
+
+    def log_scalar(self, name: str, value: float, step: int | None = None) -> None:
+        """Logs a scalar value to a CSV file.
+
+        Args:
+            name (str): The name of the scalar.
+            value (float): The value of the scalar.
+            step (int, optional): The step at which the scalar is logged. Defaults to None.
+        """
+        self.experiment.add_scalar(name, value, global_step=step)
+
+    def log_video(
+        self, name: str, video: Tensor, step: int | None = None, **kwargs
+    ) -> None:
+        """Logs video inputs to a .pt (or other format) file.
+
+        Args:
+            name (str): The name of the video.
+            video (Tensor): The video to be logged.
+            step (int, optional): The step at which the video is logged. Defaults to None.
+            **kwargs: other kwargs passed to the underlying video logger.
+
+        .. note:: If the video format is `mp4`, many more arguments can be passed to the :meth:`~torchvision.io.write_video`
+            function.
+            For more information on video logging with :class:`~torchrl.record.loggers.csv.CSVLogger`,
+            see the :meth:`~torchrl.record.loggers.csv.CSVExperiment.add_video` documentation.
+        """
+        # check for correct format of the video tensor ((N), T, C, H, W)
+        # check that the color channel (C) is either 1 or 3
+        if video.dim() != 5 or video.size(dim=2) not in {1, 3}:
+            raise Exception(
+                "Wrong format of the video tensor. Should be ((N), T, C, H, W)"
+            )
+        self.experiment.add_video(
+            tag=name,
+            vid_tensor=video,
+            global_step=step,
+            **kwargs,
+        )
+
+    def log_hparams(self, cfg: DictConfig | dict) -> None:  # noqa: F821
+        """Logs the hyperparameters of the experiment.
+
+        Args:
+            cfg (DictConfig or dict): The configuration of the experiment.
+        """
+        txt = "\n".join([f"{k}: {val}" for k, val in sorted(cfg.items())])
+        self.experiment.add_text("hparams", txt)
+
+    def __repr__(self) -> str:
+        return f"CSVLogger(exp_name={self.exp_name}, experiment={self.experiment.__repr__()})"
+
+    def log_histogram(self, name: str, data: Sequence, **kwargs):
+        raise NotImplementedError("Logging histograms with CSVLogger is not supported.")
+
+    def print_log_dir(self):
+        """Prints the log directory content."""
+        tensordict.utils.print_directory_tree(self.log_dir)
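A short usage sketch for the CSV backend, assuming the import path from this wheel's layout; scalars are appended to `<log_dir>/<exp_name>/scalars/<name>.csv` and videos are written under `videos/`:

```python
import torch

from torchrl.record.loggers.csv import CSVLogger

logger = CSVLogger(exp_name="demo", log_dir="./csv_logs", video_format="pt")

# Each call appends one "step,value" row to scalars/reward.csv
for step in range(3):
    logger.log_scalar("reward", 1.0 * step, step=step)

# log_video expects a 5D tensor of shape (N, T, C, H, W) with C in {1, 3}
video = torch.zeros(1, 16, 3, 64, 64, dtype=torch.uint8)
logger.log_video("rollout", video, step=0)

# Hyperparameters are serialized as "key: value" lines under texts/
logger.log_hparams({"lr": 3e-4, "gamma": 0.99})
```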