torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314t-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import importlib.util
|
|
8
|
+
|
|
9
|
+
import os
|
|
10
|
+
from collections.abc import Sequence
|
|
11
|
+
from tempfile import TemporaryDirectory
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
from torch import Tensor
|
|
15
|
+
|
|
16
|
+
from torchrl.record.loggers.common import Logger
|
|
17
|
+
|
|
18
|
+
_has_tv = importlib.util.find_spec("torchvision") is not None
|
|
19
|
+
|
|
20
|
+
_has_mlflow = importlib.util.find_spec("mlflow") is not None
|
|
21
|
+
_has_omegaconf = importlib.util.find_spec("omegaconf") is not None
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class MLFlowLogger(Logger):
    """Wrapper for the mlflow logger.

    Args:
        exp_name (str): The name of the experiment.
        tracking_uri (str): A tracking URI to a datastore that supports MLFlow or a local directory.
        tags (dict, optional): Tags forwarded to ``mlflow.create_experiment`` when
            the experiment does not exist yet. Defaults to ``None``.

    Keyword Args:
        video_fps (int, optional): Number of frames per second when recording videos. Defaults to ``30``.
        **kwargs: Accepted for signature compatibility; not used by this logger.

    """

    def __init__(
        self,
        exp_name: str,
        tracking_uri: str,
        tags: dict[str, Any] | None = None,
        *,
        video_fps: int = 30,
        **kwargs,
    ) -> None:
        import mlflow

        # Arguments replayed to ``mlflow.create_experiment`` in ``_create_experiment``.
        self._mlflow_kwargs = {
            "name": exp_name,
            "artifact_location": tracking_uri,
            "tags": tags,
        }
        mlflow.set_tracking_uri(tracking_uri)
        # NOTE(review): the base ``Logger`` constructor presumably calls
        # ``_create_experiment`` and stores the result as ``self.experiment``
        # -- confirm against ``Logger`` (not visible here).
        super().__init__(exp_name=exp_name, log_dir=tracking_uri)
        self.video_log_counter = 0
        self.video_fps = video_fps

    def _create_experiment(self) -> mlflow.ActiveRun:  # noqa
        """Creates an mlflow experiment.

        The experiment is looked up by name and only created when it does not
        exist yet; its id is stored in ``self.id``.

        Returns:
            mlflow.ActiveRun: The mlflow run started for the experiment.

        Raises:
            ImportError: If mlflow is not installed.
        """
        # Check availability before importing so the explicit message is raised.
        if not _has_mlflow:
            raise ImportError("MLFlow is not installed")
        import mlflow

        # Only create experiment if it doesnt exist
        experiment = mlflow.get_experiment_by_name(self._mlflow_kwargs["name"])
        if experiment is None:
            self.id = mlflow.create_experiment(**self._mlflow_kwargs)
        else:
            self.id = experiment.experiment_id
        return mlflow.start_run(experiment_id=self.id)

    def log_scalar(self, name: str, value: float, step: int | None = None) -> None:
        """Logs a scalar value to mlflow.

        Args:
            name (str): The name of the scalar.
            value (float): The value of the scalar.
            step (int, optional): The step at which the scalar is logged.
                Defaults to None.
        """
        import mlflow

        mlflow.set_experiment(experiment_id=self.id)
        mlflow.log_metric(key=name, value=value, step=step)

    def log_video(self, name: str, video: Tensor, **kwargs) -> None:
        """Log video inputs to mlflow.

        Args:
            name (str): The name of the video.
            video (Tensor): The video to be logged, expected to be in (T, C, H, W) format
                for consistency with other loggers. For a 5-dimensional
                (N, T, C, H, W) input only the last element along the first
                dimension is logged.
            **kwargs: Other keyword arguments. By construction, log_video
                supports 'step' (integer indicating the step index) and 'fps' (defaults to ``self.video_fps``).

        Raises:
            ImportError: If torchvision is not installed.
            ValueError: If the video does not have exactly 3 color channels.
        """
        # Check availability before importing so the explicit message is raised.
        if not _has_tv:
            raise ImportError(
                "Logging a video with MLFlow requires torchvision to be installed."
            )
        import mlflow
        import torchvision

        mlflow.set_experiment(experiment_id=self.id)
        if video.ndim == 5:
            video = video[-1]  # N T C H W -> T C H W
        video = video.permute(0, 2, 3, 1)  # T C H W -> T H W C
        if video.size(dim=-1) != 3:
            raise ValueError(
                "The MLFlow logger only supports videos with 3 color channels."
            )
        self.video_log_counter += 1
        fps = kwargs.pop("fps", self.video_fps)
        step = kwargs.pop("step", None)
        with TemporaryDirectory() as temp_dir:
            # ``is not None`` so that step 0 still gets its step suffix.
            if step is not None:
                video_name = f"{name}_step_{step:04}.mp4"
            else:
                video_name = f"{name}.mp4"
            video_path = os.path.join(temp_dir, video_name)
            # Let the writer create the file itself rather than keeping a
            # second open handle on the same path.
            torchvision.io.write_video(
                filename=video_path, video_array=video, fps=fps
            )
            mlflow.log_artifact(video_path, "videos")

    def log_hparams(self, cfg: DictConfig | dict) -> None:  # noqa: F821
        """Logs the hyperparameters of the experiment.

        Args:
            cfg (DictConfig or dict): The configuration of the experiment.
        """
        import mlflow

        mlflow.set_experiment(experiment_id=self.id)
        if type(cfg) is not dict and _has_omegaconf:
            # Imported lazily so plain ``dict`` configs work without omegaconf.
            from omegaconf import OmegaConf

            cfg = OmegaConf.to_container(cfg, resolve=True)
        mlflow.log_params(cfg)

    def __repr__(self) -> str:
        return f"MLFlowLogger(experiment={self.experiment!r})"

    def log_histogram(self, name: str, data: Sequence, **kwargs):
        """Histogram logging is not available for this logger.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Logging histograms with MLFlow is not supported.")
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import importlib.util
|
|
8
|
+
|
|
9
|
+
import os
|
|
10
|
+
from collections.abc import Sequence
|
|
11
|
+
|
|
12
|
+
from torch import Tensor
|
|
13
|
+
|
|
14
|
+
from .common import Logger
|
|
15
|
+
|
|
16
|
+
_has_tb = importlib.util.find_spec("tensorboard") is not None
|
|
17
|
+
_has_omgaconf = importlib.util.find_spec("omegaconf") is not None
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class TensorboardLogger(Logger):
    """Wrapper for the Tensorboard logger.

    Args:
        exp_name (str): The name of the experiment.
        log_dir (str): the tensorboard log_dir. Defaults to ``tb_logs``.

    """

    def __init__(self, exp_name: str, log_dir: str = "tb_logs") -> None:
        super().__init__(exp_name=exp_name, log_dir=log_dir)
        # re-write log_dir with the directory actually used by the writer
        self.log_dir = self.experiment.log_dir

        # moviepy is only probed on the first call to ``log_video``.
        self._has_imported_moviepy = False

    def _create_experiment(self) -> SummaryWriter:  # noqa
        """Creates a tensorboard experiment.

        The writer logs to ``<self.log_dir>/<self.exp_name>``.

        Returns:
            SummaryWriter: The tensorboard experiment.

        Raises:
            ImportError: If tensorboard is not installed.

        """
        if not _has_tb:
            raise ImportError("torch.utils.tensorboard could not be imported")

        from torch.utils.tensorboard import SummaryWriter

        log_dir = str(os.path.join(self.log_dir, self.exp_name))
        return SummaryWriter(log_dir=log_dir)

    def log_scalar(self, name: str, value: float, step: int | None = None) -> None:
        """Logs a scalar value to the tensorboard.

        Args:
            name (str): The name of the scalar.
            value (float): The value of the scalar.
            step (int, optional): The step at which the scalar is logged. Defaults to None.

        """
        self.experiment.add_scalar(name, value, global_step=step)

    def log_video(
        self, name: str, video: Tensor, step: int | None = None, **kwargs
    ) -> None:
        """Log videos inputs to the tensorboard.

        Args:
            name (str): The name of the video.
            video (Tensor): The video to be logged, in ((N), T, C, H, W)
                format with 1 or 3 color channels. A 4-dimensional
                (T, C, H, W) input is given a leading batch dimension.
            step (int, optional): The step at which the video is logged. Defaults to None.
            **kwargs: Extra keyword arguments passed to ``add_video``.

        Raises:
            Exception: If the video has the wrong format or moviepy is missing.

        """
        # The batch dimension is optional: promote (T, C, H, W) to (1, T, C, H, W).
        if video.dim() == 4:
            video = video.unsqueeze(0)
        # check for correct format of the video tensor ((N), T, C, H, W)
        # check that the color channel (C) is either 1 or 3
        if video.dim() != 5 or video.size(dim=2) not in {1, 3}:
            raise Exception(
                "Wrong format of the video tensor. Should be ((N), T, C, H, W)"
            )
        if not self._has_imported_moviepy:
            try:
                import moviepy  # noqa

                self._has_imported_moviepy = True
            except ImportError as err:
                # Keep the original exception type but preserve the cause.
                raise Exception(
                    "moviepy not found, videos cannot be logged with TensorboardLogger"
                ) from err
        self.experiment.add_video(
            tag=name,
            vid_tensor=video,
            global_step=step,
            **kwargs,
        )

    def log_hparams(self, cfg: DictConfig | dict) -> None:  # noqa: F821
        """Logs the hyperparameters of the experiment.

        Args:
            cfg (DictConfig or dict): The configuration of the experiment.

        """
        # A non-dict config is converted only when omegaconf is available; a
        # ``DictConfig`` instance cannot exist without omegaconf, so no
        # separate "not installed" error path is needed here.
        if type(cfg) is not dict and _has_omgaconf:
            from omegaconf import OmegaConf

            cfg = OmegaConf.to_container(cfg, resolve=True)
        self.experiment.add_hparams(cfg, metric_dict={})

    def __repr__(self) -> str:
        return f"TensorboardLogger(experiment={self.experiment!r})"

    def log_histogram(self, name: str, data: Sequence, **kwargs):
        """Add histogram to summary.

        Args:
            name (str): Data identifier
            data (torch.Tensor, numpy.ndarray, or string/blobname): Values to build histogram

        Keyword Args:
            step (int): Global step value to record
            bins (str): One of {'tensorflow', 'auto', 'fd', ...}. This determines how the bins are made.
                Defaults to ``"tensorflow"``. You can find other options in:
                https://docs.scipy.org/doc/numpy/reference/generated/numpy.histogram.html
            walltime (:obj:`float`): Optional override default walltime (time.time()) seconds after epoch of event

        Raises:
            TypeError: If unrecognised keyword arguments are passed.

        """
        global_step = kwargs.pop("step", None)
        # ``bins`` is optional; fall back to SummaryWriter.add_histogram's default.
        bins = kwargs.pop("bins", "tensorflow")
        walltime = kwargs.pop("walltime", None)
        if kwargs:
            raise TypeError(f"Unrecognised arguments {kwargs}.")
        self.experiment.add_histogram(
            tag=name, values=data, global_step=global_step, bins=bins, walltime=walltime
        )
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import importlib.util
|
|
8
|
+
|
|
9
|
+
from collections.abc import Sequence
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
from torch import Tensor
|
|
14
|
+
|
|
15
|
+
from .common import Logger
|
|
16
|
+
|
|
17
|
+
_has_trackio = importlib.util.find_spec("trackio") is not None
|
|
18
|
+
_has_omegaconf = importlib.util.find_spec("omegaconf") is not None
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class TrackioLogger(Logger):
|
|
22
|
+
"""Wrapper for the trackio logger.
|
|
23
|
+
|
|
24
|
+
Args:
|
|
25
|
+
exp_name (str): The name of the experiment.
|
|
26
|
+
project (str): The name of the project.
|
|
27
|
+
|
|
28
|
+
Keyword Args:
|
|
29
|
+
fps (int, optional): Number of frames per second when recording videos. Defaults to ``30``.
|
|
30
|
+
**kwargs: Extra keyword arguments for ``trackio.init``.
|
|
31
|
+
|
|
32
|
+
"""
|
|
33
|
+
|
|
34
|
+
@classmethod
def __new__(cls, *args, **kwargs):
    """Allocate a bare instance; positional and keyword arguments are ignored here."""
    instance = super().__new__(cls)
    return instance
|
|
37
|
+
|
|
38
|
+
def __init__(
    self,
    exp_name: str,
    project: str,
    *,
    video_fps: int = 32,
    **kwargs,
) -> None:
    """Store the trackio settings and initialise the base logger.

    Args:
        exp_name (str): The name of the experiment.
        project (str): The name of the project; also used as ``log_dir``.

    Keyword Args:
        video_fps (int, optional): Default frame rate used by ``log_video``.
        **kwargs: Extra keyword arguments stored for ``trackio.init``; they
            override the ``name``/``project``/``resume`` entries on collision.

    Raises:
        ImportError: If trackio is not installed.
    """
    if not _has_trackio:
        raise ImportError("trackio could not be imported")

    self.video_fps = video_fps

    # Defaults first, then caller overrides (same precedence as a dict literal
    # ending in ``**kwargs``).
    init_kwargs = {"name": exp_name, "project": project, "resume": "allow"}
    init_kwargs.update(kwargs)
    self._trackio_kwargs = init_kwargs

    # Both attributes above are set before the base constructor runs.
    super().__init__(exp_name=exp_name, log_dir=project)
|
|
58
|
+
|
|
59
|
+
def _create_experiment(self):
    """Creates a trackio experiment.

    Returns:
        The object returned by ``trackio.init`` called with the stored
        ``self._trackio_kwargs``.

    Raises:
        ImportError: If trackio is not installed.
    """
    if _has_trackio:
        import trackio

        run = trackio.init(**self._trackio_kwargs)
        return run
    raise ImportError("Trackio is not installed")
|
|
73
|
+
|
|
74
|
+
def log_scalar(self, name: str, value: float, step: int | None = None) -> None:
|
|
75
|
+
"""Logs a scalar value to trackio.
|
|
76
|
+
|
|
77
|
+
Args:
|
|
78
|
+
name (str): The name of the scalar.
|
|
79
|
+
value (float): The value of the scalar.
|
|
80
|
+
step (int, optional): The step at which the scalar is logged.
|
|
81
|
+
Defaults to None.
|
|
82
|
+
"""
|
|
83
|
+
self.experiment.log({name: value}, step=step)
|
|
84
|
+
|
|
85
|
+
def log_video(self, name: str, video: Tensor, **kwargs) -> None:
|
|
86
|
+
"""Log videos inputs to trackio.
|
|
87
|
+
|
|
88
|
+
Args:
|
|
89
|
+
name (str): The name of the video.
|
|
90
|
+
video (Tensor): The video to be logged.
|
|
91
|
+
**kwargs: Other keyword arguments. By construction, log_video
|
|
92
|
+
supports 'step' (integer indicating the step index), 'format'
|
|
93
|
+
(default is 'mp4') and 'fps' (defaults to ``self.video_fps``). Other kwargs are
|
|
94
|
+
passed as-is to the :obj:`experiment.log` method.
|
|
95
|
+
"""
|
|
96
|
+
import trackio
|
|
97
|
+
|
|
98
|
+
fps = kwargs.pop("fps", self.video_fps)
|
|
99
|
+
format = kwargs.pop("format", "mp4")
|
|
100
|
+
self.experiment.log(
|
|
101
|
+
{
|
|
102
|
+
name: trackio.Video(
|
|
103
|
+
video.numpy().astype(np.uint8), fps=fps, format=format
|
|
104
|
+
)
|
|
105
|
+
},
|
|
106
|
+
**kwargs,
|
|
107
|
+
)
|
|
108
|
+
|
|
109
|
+
def log_hparams(self, cfg: DictConfig | dict) -> None: # noqa: F821
|
|
110
|
+
"""Logs the hyperparameters of the experiment.
|
|
111
|
+
|
|
112
|
+
Args:
|
|
113
|
+
cfg (DictConfig or dict): The configuration of the experiment.
|
|
114
|
+
|
|
115
|
+
"""
|
|
116
|
+
if type(cfg) is not dict and _has_omegaconf:
|
|
117
|
+
if not _has_omegaconf:
|
|
118
|
+
raise ImportError(
|
|
119
|
+
"OmegaConf could not be imported. "
|
|
120
|
+
"Cannot log hydra configs without OmegaConf."
|
|
121
|
+
)
|
|
122
|
+
from omegaconf import OmegaConf
|
|
123
|
+
|
|
124
|
+
cfg = OmegaConf.to_container(cfg, resolve=True)
|
|
125
|
+
self.experiment.config.update(cfg)
|
|
126
|
+
|
|
127
|
+
def __repr__(self) -> str:
|
|
128
|
+
return f"TrackioLogger(experiment={self.experiment.__repr__()})"
|
|
129
|
+
|
|
130
|
+
def log_histogram(self, name: str, data: Sequence, **kwargs):
|
|
131
|
+
"""Add histogram to log.
|
|
132
|
+
|
|
133
|
+
Args:
|
|
134
|
+
name (str): Data identifier
|
|
135
|
+
data (torch.Tensor, numpy.ndarray): Values to build histogram
|
|
136
|
+
|
|
137
|
+
Keyword Args:
|
|
138
|
+
step (int): Global step value to record
|
|
139
|
+
bins (int): Number of bins to use for the histogram
|
|
140
|
+
|
|
141
|
+
"""
|
|
142
|
+
import trackio
|
|
143
|
+
|
|
144
|
+
num_bins = kwargs.pop("bins", None)
|
|
145
|
+
step = kwargs.pop("step", None)
|
|
146
|
+
self.experiment.log(
|
|
147
|
+
{name: trackio.Histogram(data, num_bins=num_bins)}, step=step
|
|
148
|
+
)
|
|
149
|
+
|
|
150
|
+
def log_str(self, name: str, value: str, step: int | None = None) -> None:
|
|
151
|
+
"""Logs a string value to trackio using a table format for better visualization.
|
|
152
|
+
|
|
153
|
+
Args:
|
|
154
|
+
name (str): The name of the string data.
|
|
155
|
+
value (str): The string value to log.
|
|
156
|
+
step (int, optional): The step at which the string is logged.
|
|
157
|
+
Defaults to None.
|
|
158
|
+
"""
|
|
159
|
+
import trackio
|
|
160
|
+
|
|
161
|
+
# Create a table with a single row
|
|
162
|
+
table = trackio.Table(columns=["text"], data=[[value]])
|
|
163
|
+
self.experiment.log({name: table}, step=step)
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import os
|
|
8
|
+
import pathlib
|
|
9
|
+
import uuid
|
|
10
|
+
from datetime import datetime
|
|
11
|
+
|
|
12
|
+
from torchrl.record.loggers.common import Logger
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def generate_exp_name(model_name: str, experiment_name: str) -> str:
    """Build a unique identifier (str) for the described experiment.

    The result is ``<model_name>_<experiment_name>_<uid>_<timestamp>``, where
    ``<uid>`` is the first 8 characters of a random UUID4 and ``<timestamp>``
    is the current local time formatted as ``%y_%m_%d-%H_%M_%S``.
    """
    short_uid = str(uuid.uuid4())[:8]
    timestamp = datetime.now().strftime("%y_%m_%d-%H_%M_%S")
    return "_".join([model_name, experiment_name, short_uid, timestamp])
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def get_logger(
    logger_type: str, logger_name: str, experiment_name: str, **kwargs
) -> Logger | None:
    """Get a logger instance of the provided `logger_type`.

    The backend module is imported lazily, inside the matching branch, so
    only the selected backend needs to be installed.

    Args:
        logger_type (str): One of tensorboard / csv / wandb / mlflow / trackio.
            If empty or ``None``, ``None`` is returned.
        logger_name (str): Name to be used as a log_dir. For mlflow it is
            turned into an absolute ``file://`` tracking URI; it is not used
            by the trackio backend.
        experiment_name (str): Name of the experiment
        kwargs (dict[str]): might contain either `wandb_kwargs`, `mlflow_kwargs` or `trackio_kwargs`.
            A missing or ``None`` entry is treated as an empty dict.
            ``trackio_kwargs`` may hold a ``project`` entry (defaults to
            ``"torchrl"``); the dict passed by the caller is not modified.

    Returns:
        The logger instance, or ``None`` when no logger type is given.

    Raises:
        NotImplementedError: if ``logger_type`` is not supported.
    """
    if logger_type in ("", None):
        return None

    if logger_type == "tensorboard":
        from torchrl.record.loggers.tensorboard import TensorboardLogger

        return TensorboardLogger(log_dir=logger_name, exp_name=experiment_name)

    if logger_type == "csv":
        from torchrl.record.loggers.csv import CSVLogger

        return CSVLogger(
            log_dir=logger_name, exp_name=experiment_name, video_format="mp4"
        )

    if logger_type == "wandb":
        from torchrl.record.loggers.wandb import WandbLogger

        wandb_kwargs = kwargs.get("wandb_kwargs") or {}
        return WandbLogger(
            log_dir=logger_name, exp_name=experiment_name, **wandb_kwargs
        )

    if logger_type == "mlflow":
        from torchrl.record.loggers.mlflow import MLFlowLogger

        mlflow_kwargs = kwargs.get("mlflow_kwargs") or {}
        return MLFlowLogger(
            tracking_uri=pathlib.Path(os.path.abspath(logger_name)).as_uri(),
            exp_name=experiment_name,
            **mlflow_kwargs,
        )

    if logger_type == "trackio":
        from torchrl.record.loggers.trackio import TrackioLogger

        # Work on a copy: popping "project" from the caller's dict would make
        # a second call with the same kwargs silently fall back to "torchrl".
        trackio_kwargs = dict(kwargs.get("trackio_kwargs") or {})
        project = trackio_kwargs.pop("project", "torchrl")
        return TrackioLogger(
            project=project, exp_name=experiment_name, **trackio_kwargs
        )

    raise NotImplementedError(f"Unsupported logger_type: '{logger_type}'")
|