torchrl 0.11.0__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,214 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import importlib.util
|
|
8
|
+
|
|
9
|
+
import os
|
|
10
|
+
from collections.abc import Sequence
|
|
11
|
+
|
|
12
|
+
from torch import Tensor
|
|
13
|
+
|
|
14
|
+
from .common import Logger
|
|
15
|
+
|
|
16
|
+
# Availability flags for the optional third-party packages used in this
# module. ``find_spec`` only locates a package without importing it, so
# evaluating these flags stays cheap at import time.
_has_wandb, _has_omegaconf, _has_moviepy = (
    importlib.util.find_spec(_pkg) is not None
    for _pkg in ("wandb", "omegaconf", "moviepy")
)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class WandbLogger(Logger):
    """Wrapper for the wandb logger.

    The keyword arguments are mainly based on the :func:`wandb.init` kwargs.
    See the doc `here <https://docs.wandb.ai/ref/python/init>`__.

    Args:
        exp_name (str): The name of the experiment.
        offline (bool, optional): if ``True``, the logs will be stored locally
            only. Defaults to ``False``.
        save_dir (path, optional): the directory where to save data. Exclusive with
            ``log_dir``.
        id (str, optional): A unique ID for this run, used for resuming.
            It must be unique in the project, and if you delete a run you can't reuse the ID.
        project (str, optional): The name of the project where you're sending
            the new run. If the project is not specified, the run is put in
            an ``"Uncategorized"`` project.

    Keyword Args:
        video_fps (int, optional): Number of frames per second when recording
            videos. Defaults to ``32``.
        log_dir (path, optional): the directory where to save data. Exclusive with
            ``save_dir``. Only accepted as a keyword argument.
        **kwargs: Extra keyword arguments for ``wandb.init``. See relevant page for
            more info.

    """

    # NOTE(review): ``__new__`` is wrapped in ``classmethod`` and drops every
    # constructor argument before delegating to the parent; left untouched
    # because the rationale is not visible from this file.
    @classmethod
    def __new__(cls, *args, **kwargs):
        return super().__new__(cls)

    def __init__(
        self,
        exp_name: str,
        offline: bool = False,
        save_dir: str | None = None,
        id: str | None = None,
        project: str | None = None,
        *,
        video_fps: int = 32,
        **kwargs,
    ) -> None:
        if not _has_wandb:
            raise ImportError("wandb could not be imported")

        # ``log_dir`` is an alias of ``save_dir`` and must not be forwarded to
        # ``wandb.init``, hence it is removed from ``kwargs`` here.
        log_dir = kwargs.pop("log_dir", None)
        self.offline = offline
        if save_dir and log_dir:
            raise ValueError(
                "log_dir and save_dir point to the same value in "
                "WandbLogger. Both cannot be specified."
            )
        # At most one of the two is truthy at this point: keep that one.
        save_dir = save_dir if save_dir and not log_dir else log_dir
        self.save_dir = save_dir
        self.id = id
        self.project = project
        self.video_fps = video_fps
        # Arguments handed to ``wandb.init`` by ``_create_experiment``. Entries
        # in ``kwargs`` are unpacked last and therefore override the defaults.
        self._wandb_kwargs = {
            "name": exp_name,
            "dir": save_dir,
            "id": id,
            "project": project,
            "resume": "allow",
            **kwargs,
        }

        super().__init__(exp_name=exp_name, log_dir=save_dir)
        if self.offline:
            os.environ["WANDB_MODE"] = "dryrun"

    def _create_experiment(self):
        """Creates a wandb experiment.

        The run is configured from ``self._wandb_kwargs``. When
        ``self.offline`` is set, the ``WANDB_MODE`` environment variable is
        set to ``"dryrun"`` before the run is created.

        Returns:
            The object returned by :func:`wandb.init`.

        Raises:
            ImportError: If wandb is not installed.
        """
        if not _has_wandb:
            raise ImportError("Wandb is not installed")
        import wandb

        if self.offline:
            os.environ["WANDB_MODE"] = "dryrun"

        return wandb.init(**self._wandb_kwargs)

    def log_scalar(
        self, name: str, value: float, step: int | None = None, commit: bool = False
    ) -> None:
        """Logs a scalar value to wandb.

        Args:
            name (str): The name of the scalar.
            value (float): The value of the scalar.
            step (int, optional): The step at which the scalar is logged.
                Defaults to None.
            commit: If true, data for current step is assumed to be final (and
                no further data for this step should be logged).
        """
        self.experiment.log({name: value}, step=step, commit=commit)

    def log_video(self, name: str, video: Tensor, **kwargs) -> None:
        """Log videos inputs to wandb.

        Args:
            name (str): The name of the video.
            video (Tensor): The video to be logged.
            **kwargs: Other keyword arguments. By construction, log_video
                supports 'step' (integer indicating the step index), 'format'
                (default is 'mp4') and 'fps' (defaults to ``self.video_fps``). Other kwargs are
                passed as-is to the :obj:`experiment.log` method.

        Raises:
            ImportError: If moviepy is not installed (required by wandb for video encoding).
        """
        if not _has_moviepy:
            raise ImportError(
                "Video logging with wandb requires moviepy. "
                "Install with: pip install moviepy\n"
                "Or install wandb with media support: pip install 'wandb[media]'"
            )
        import wandb

        # 'fps' and 'format' configure the ``wandb.Video`` object; whatever is
        # left in ``kwargs`` (e.g. 'step') goes to ``experiment.log``.
        fps = kwargs.pop("fps", self.video_fps)
        format = kwargs.pop("format", "mp4")
        self.experiment.log(
            {name: wandb.Video(video, fps=fps, format=format)},
            **kwargs,
        )

    def log_hparams(self, cfg: DictConfig | dict) -> None:  # noqa: F821
        """Logs the hyperparameters of the experiment.

        OmegaConf configs are converted to plain containers (with
        interpolations resolved) before being written to the run config. Any
        other object is handed to ``experiment.config.update`` unchanged.

        Args:
            cfg (DictConfig or dict): The configuration of the experiment.

        """
        # Only genuine OmegaConf containers can go through ``to_container``;
        # dicts (and dict subclasses) skip the omegaconf import altogether.
        if _has_omegaconf and not isinstance(cfg, dict):
            from omegaconf import OmegaConf

            if OmegaConf.is_config(cfg):
                cfg = OmegaConf.to_container(cfg, resolve=True)
        self.experiment.config.update(cfg, allow_val_change=True)

    def __repr__(self) -> str:
        return f"WandbLogger(experiment={self.experiment!r})"

    def log_histogram(self, name: str, data: Sequence, **kwargs):
        """Add histogram to log.

        Args:
            name (str): Data identifier
            data (torch.Tensor, numpy.ndarray, or string/blobname): Values to build histogram

        Keyword Args:
            step (int): Global step value to record. When given, it is logged
                alongside the histogram under the ``"trainer/step"`` key.
            bins: Forwarded to :class:`wandb.Histogram` as ``num_bins``. When
                omitted, wandb's own default is used.
                NOTE(review): presumably any ``bins`` value accepted by
                ``numpy.histogram`` works here; confirm against the wandb docs.

        """
        import wandb

        num_bins = kwargs.pop("bins", None)
        step = kwargs.pop("step", None)
        # Passing ``num_bins=None`` would override wandb's default bin count,
        # so the argument is only forwarded when the caller provided it.
        if num_bins is None:
            histogram = wandb.Histogram(data)
        else:
            histogram = wandb.Histogram(data, num_bins=num_bins)
        payload = {name: histogram}
        if step is not None:
            payload["trainer/step"] = step
        self.experiment.log(payload)

    def log_str(self, name: str, value: str, step: int | None = None) -> None:
        """Logs a string value to wandb using a table format for better visualization.

        Args:
            name (str): The name of the string data.
            value (str): The string value to log.
            step (int, optional): The step at which the string is logged.
                Defaults to None.
        """
        import wandb

        # Create a table with a single row
        table = wandb.Table(columns=["text"], data=[[value]])

        # The table is logged in both branches so the output format does not
        # depend on whether a step was supplied.
        if step is not None:
            self.experiment.log({name: table}, step=step)
        else:
            self.experiment.log({name: table})
|