torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,360 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import torch
8
+
9
+ from tensordict import TensorDictBase
10
+ from tensordict.utils import expand_right
11
+
12
+
13
+ def _custom_conv1d(tensor: torch.Tensor, filter: torch.Tensor):
14
+ """Computes a conv1d filter over a value.
15
+
16
+ This is usually used to compute a discounted return:
17
+
18
+ Tensor: Filter Result (discounted return)
19
+ [ r_0, [ 1.0, [ r_0 + g r_1 + g^2 r_2 + r^3 r_3,
20
+ r_1, g, r_1 + g r_2 + g^2 r_3,
21
+ r_2, g^2, r_2 + g r_3,
22
+ r_3, g^3 ] r_3 ]
23
+ 0, | |
24
+ 0, | zero padding | direction of filter
25
+ 0 ] | v
26
+
27
+ This function takes care of applying the one-sided zero padding. In this example,
28
+ `Filter_dim` = :obj:`Time` = 4, but in practice Filter_dim can be <= to :obj:`Time`.
29
+
30
+ Args:
31
+ tensor (torch.Tensor): a [ Batch x 1 x Time ] floating-point tensor
32
+ filter (torch.Tensor): a [ Filter_dim x 1 ] floating-point filter
33
+
34
+ Returns: a filtered tensor of the same shape as the input tensor.
35
+
36
+ """
37
+ if filter.ndimension() > 2:
38
+ # filter will have shape batch_dims x timesteps x filter_dim x 1
39
+ # reshape to batch_dims x timesteps x 1 x filter_dim ready for convolving
40
+ filter = filter.view(*filter.shape[:-2], 1, filter.shape[-2])
41
+
42
+ # because time is represented on two different dimensions, we don't
43
+ # need all convolutions, just those lying along a diagonal
44
+ # rather than compute them all and discard, we stack just the slices
45
+ # of val_pad that we care about, and apply the filter manually
46
+
47
+ # STACK VERSION: val_pad is computed as in the block below
48
+ # batched_val_pad = torch.stack(
49
+ # [val_pad[..., i : i + filter.shape[-1]] for i in range(tensor.shape[-1])],
50
+ # dim=1,
51
+ # )
52
+
53
+ # roll version
54
+ T = tensor.shape[-1]
55
+ device = tensor.device
56
+ batched_val_pad = (
57
+ roll_by_gather(
58
+ tensor.expand(tensor.shape[0], filter.shape[-1], T).transpose(-2, -1),
59
+ 0,
60
+ -torch.arange(filter.shape[-1], device=device),
61
+ )
62
+ .flip(-1)
63
+ .triu(filter.shape[-1] - T)
64
+ .flip(-1)
65
+ .unsqueeze(-2)
66
+ )
67
+
68
+ # this is just a batched matrix multiplication, but einsum makes it
69
+ # easy to keep the many dimensions under control. Here b = batch,
70
+ # t = timestep, s = singleton, j is the filter dimension that should
71
+ # get summed out. we swap the order of s and t here rather than
72
+ # reshape / create a view later.
73
+ # this is essentially identical to (batched_val_pad @ filter.transpose(-2, -1)).squeeze().unsqueeze(-2)
74
+ # out = (batched_val_pad @ filter.transpose(-2, -1)).squeeze().unsqueeze(-2)
75
+ out = torch.einsum("btsj,btsj->bst", batched_val_pad, filter)
76
+ else:
77
+ val_pad = torch.nn.functional.pad(tensor, [0, filter.shape[-2] - 1])
78
+
79
+ # shape = val.shape
80
+ filter = filter.squeeze(-1).unsqueeze(0).unsqueeze(0) # 1 x 1 x T
81
+ out = torch.conv1d(val_pad, filter)
82
+ # out = out.view(shape)
83
+ if out.shape != tensor.shape:
84
+ raise RuntimeError(
85
+ f"wrong output shape: input shape: {tensor.shape}, output shape: {out.shape}"
86
+ )
87
+ return out
88
+
89
+
90
def roll_by_gather(mat: torch.Tensor, dim: int, shifts: torch.LongTensor):
    """Rolls a batched matrix along the last or last-but-one dimension.

    Args:
        mat (torch.Tensor): A batched matrix to roll
        dim (int): 0 or -2 indicates the last but one dimension,
            1 or -1 the last dimension.
        shifts (torch.LongTensor): A tensor containing the shifts. Must have the same number of
            elements as the unchosen dimension.

    Examples:
        >>> x = torch.arange(12).view(3, 4)
        >>> roll_by_gather(x, 0, -torch.arange(4)) # shifts the values in each column
        tensor([[ 0,  5, 10,  3],
                [ 4,  9,  2,  7],
                [ 8,  1,  6, 11]])
        >>> roll_by_gather(x, 1, -torch.arange(3)) # shifts the values in each row
        tensor([[ 0,  1,  2,  3],
                [ 5,  6,  7,  4],
                [10, 11,  8,  9]])

    """
    # operates on the trailing 2D matrix; leading dims are broadcast
    *batch, n_rows, n_cols = mat.shape
    device = mat.device

    if dim in (0, -2):
        # roll each column by its own shift: build a grid of row indices and
        # offset it (mod n_rows) before gathering
        row_grid = (
            torch.arange(n_rows, device=device).view(-1, 1).expand(n_rows, n_cols)
        )
        gather_idx = (row_grid - shifts) % n_rows
        return mat.gather(-2, gather_idx.expand(*batch, n_rows, n_cols))
    if dim in (1, -1):
        # roll each row by its own shift along the last dimension
        col_grid = torch.arange(n_cols, device=device).expand(n_rows, n_cols)
        gather_idx = (col_grid - shifts.unsqueeze(-1)) % n_cols
        return mat.gather(-1, gather_idx.expand(*batch, n_rows, n_cols))
    raise NotImplementedError(f"dim {dim} is not supported.")
128
+
129
+
130
def _make_gammas_tensor(gamma: torch.Tensor, T: int, rolling_gamma: bool):
    """Prepares a decay tensor for a matrix multiplication.

    Given a tensor gamma of size [*batch, T, D],
    it will return a new tensor with size [*batch, T, T+1, D].
    In the rolling_gamma case, a rolling of the gamma values will be performed
    along the T axis, e.g.:
    [[ 1, g1, g2, g3],
     [ 1, g2, g3, 0],
     [ 1, g3, 0, 0]]

    Args:
        gamma (torch.tensor): the gamma tensor to be prepared.
        T (int): the time length
        rolling_gamma (bool): if ``True``, the gamma value is set for each step
            independently. If False, the gamma value at (i, t) will be used for the
            trajectory following (i, t).

    Returns: the prepared gamma decay tensor

    """
    # some reshaping code vendored from vec_td_lambda_return_estimate
    # Flatten all leading dims so gamma is [B, T] from here on.
    gamma = gamma.transpose(-2, -1).contiguous()
    gamma = gamma.view(-1, T)
    dtype = gamma.dtype
    device = gamma.device
    if rolling_gamma:
        # # loop (reference implementation, kept for readability)
        # gammas = gamma.unsqueeze(-2).expand(gamma.shape[0], T, T).contiguous()
        # for i in range(1, T):
        #     s = gammas[:, i].clone()
        #     gammas[:, i] = 0
        #     gammas[:, i, :-i] = s[:, i:]
        # gammas = torch.cumprod(gammas.unsqueeze(-1), -2)
        # gammas_cont = torch.ones(gammas.shape[0], T, T, 1)
        # gammas_cont[..., 1:, :] = gammas[..., :-1, :]
        # gammas = gammas_cont

        # vectorized version
        # Column 0 stays at 1 (no decay for the current step); columns 1..T
        # receive the rolled gamma values.
        gammas = torch.ones(gamma.shape[0], T, T + 1, 1, dtype=dtype, device=device)
        s0 = gamma.unsqueeze(-1).expand(gamma.shape[0], T, T)
        # Shift row t left by t so each row starts at its own gamma value.
        s1 = roll_by_gather(s0, 0, shifts=-torch.arange(T, device=device))

        # we should triu here, but it's useless since there is a triu on the values
        # happening in _custom_conv1d
        # s2 = s1.flip(-1).triu().flip(-1).transpose(-2, -1)
        s2 = s1.transpose(-2, -1)
        gammas[..., 1:, :] = s2.unsqueeze(-1)
    else:
        # Single gamma per trajectory: broadcast it across the T+1 columns,
        # keeping column 0 at 1.
        gammas = torch.ones(*gamma.shape, T + 1, 1, device=device, dtype=dtype)
        gammas[..., 1:, :] = gamma[..., None, None]
    return gammas
+ return gammas
182
+
183
+
184
+ def _flatten_batch(tensor, time_dim=-1):
185
+ """Because we mark the end of each batch with a truncated signal, we can concatenate them.
186
+
187
+ Args:
188
+ tensor (torch.Tensor): a tensor of shape [*B, T, *F]
189
+ time_dim (int, optional): the time dimension T. Defaults to -1.
190
+
191
+ """
192
+ return tensor.flatten(0, time_dim)
193
+
194
+
195
+ def _get_num_per_traj(done):
196
+ """Because we mark the end of each batch with a truncated signal, we can concatenate them.
197
+
198
+ Args:
199
+ done (torch.Tensor): A done or truncated mark of shape [*B, T]
200
+
201
+ Returns:
202
+ A list of integers representing the number of steps in each trajectory
203
+
204
+ """
205
+ done = done.clone()
206
+ done[..., -1] = True
207
+ # TODO: find a way of copying once only, eg not using reshape
208
+ num_per_traj = torch.where(done.reshape(-1))[0] + 1
209
+ num_per_traj[1:] = num_per_traj[1:] - num_per_traj[:-1]
210
+ return num_per_traj
211
+
212
+
213
def _split_and_pad_sequence(
    tensor: torch.Tensor | TensorDictBase,
    splits: torch.Tensor,
    return_mask=False,
    time_dim=-1,
):
    """Given a tensor of size [*B, T, F] and the corresponding traj lengths (flattened), returns the padded trajectories [NPad, Tmax, *other].

    Compatible with tensordict inputs.

    Args:
        tensor (torch.Tensor or TensorDictBase): the data to split and pad.
        splits (torch.Tensor): the flattened trajectory lengths, e.g. as
            produced by ``_get_num_per_traj``.
        return_mask (bool, optional): if ``True``, also return the boolean
            mask of valid (non-padded) entries. Defaults to ``False``.
        time_dim (int, optional): the time dimension T. Defaults to -1.

    Examples:
        >>> from tensordict import TensorDict
        >>> is_init = torch.zeros(4, 5, dtype=torch.bool)
        >>> is_init[:, 0] = True
        >>> is_init[0, 3] = True
        >>> is_init[1, 2] = True
        >>> tensordict = TensorDict({
        ...     "is_init": is_init,
        ...     "obs": torch.arange(20).view(4, 5).unsqueeze(-1).expand(4, 5, 3),
        ... }, [4, 5])
        >>> splits = _get_num_per_traj_init(is_init)
        >>> print(splits)
        tensor([3, 2, 2, 3, 5, 5])
        >>> td = _split_and_pad_sequence(tensordict, splits)
        >>> print(td)
        TensorDict(
            fields={
                is_init: Tensor(shape=torch.Size([6, 5]), device=cpu, dtype=torch.bool, is_shared=False),
                obs: Tensor(shape=torch.Size([6, 5, 3]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([6, 5]),
            device=None,
            is_shared=False)
        >>> print(td["obs"])
        tensor([[[ 0,  0,  0],
                 [ 1,  1,  1],
                 [ 2,  2,  2],
                 [ 0,  0,  0],
                 [ 0,  0,  0]],
        <BLANKLINE>
                [[ 3,  3,  3],
                 [ 4,  4,  4],
                 [ 0,  0,  0],
                 [ 0,  0,  0],
                 [ 0,  0,  0]],
        <BLANKLINE>
                [[ 5,  5,  5],
                 [ 6,  6,  6],
                 [ 0,  0,  0],
                 [ 0,  0,  0],
                 [ 0,  0,  0]],
        <BLANKLINE>
                [[ 7,  7,  7],
                 [ 8,  8,  8],
                 [ 9,  9,  9],
                 [ 0,  0,  0],
                 [ 0,  0,  0]],
        <BLANKLINE>
                [[10, 10, 10],
                 [11, 11, 11],
                 [12, 12, 12],
                 [13, 13, 13],
                 [14, 14, 14]],
        <BLANKLINE>
                [[15, 15, 15],
                 [16, 16, 16],
                 [17, 17, 17],
                 [18, 18, 18],
                 [19, 19, 19]]])

    """
    # output layout: one row per trajectory, padded to the longest one
    max_seq_len = torch.max(splits)
    shape = (len(splits), max_seq_len)

    # int16 supports length up to 32767
    dtype = (
        torch.int16
        if tensor.size(time_dim) < torch.iinfo(torch.int16).max
        else torch.int32
    )
    # mask[i, t] is True for the first splits[i] steps of trajectory i
    arange = torch.arange(max_seq_len, device=tensor.device, dtype=dtype).unsqueeze(0)
    mask = arange < splits.unsqueeze(1)

    # collapse batch and time so the data lines up with the flattened splits
    tensor = _flatten_batch(tensor, time_dim=time_dim)

    def _fill_tensor(tensor):
        # scatter the flattened values into a zero-padded [NPad, Tmax, *F] buffer
        empty_tensor = torch.zeros(
            *shape,
            *tensor.shape[1:],
            dtype=tensor.dtype,
            device=tensor.device,
        )
        mask_expand = expand_right(mask, (*mask.shape, *tensor.shape[1:]))
        # We need to use masked-scatter to accommodate vmap
        return torch.masked_scatter(empty_tensor, mask_expand, tensor.reshape(-1))
        # empty_tensor[mask_expand] = tensor.reshape(-1)
        # return empty_tensor

    if isinstance(tensor, TensorDictBase):
        # apply the padding leaf-by-leaf, resizing the batch dims accordingly
        tensor = tensor.apply(_fill_tensor, batch_size=list(shape))
    else:
        tensor = _fill_tensor(tensor)
    if return_mask:
        return tensor, mask
    return tensor
317
+
318
+
319
+ def _inv_pad_sequence(
320
+ tensor: torch.Tensor | TensorDictBase,
321
+ splits: torch.Tensor,
322
+ mask: torch.Tensor = None,
323
+ ):
324
+ """Inverse a pad_sequence operation.
325
+
326
+ If tensor is of shape [B, T], than splits must be of of shape [B] with all elements
327
+ and integer between [1, T].
328
+ The result will be flattened along the batch dimension(s) and must be reshaped into
329
+ the original shape (if necessary).
330
+
331
+ Examples:
332
+ >>> rewards = torch.randn(100, 20)
333
+ >>> num_per_traj = _get_num_per_traj(torch.zeros(100, 20).bernoulli_(0.1))
334
+ >>> padded = _split_and_pad_sequence(rewards, num_per_traj)
335
+ >>> reconstructed = _inv_pad_sequence(padded, num_per_traj)
336
+ >>> assert (reconstructed==rewards).all()
337
+ """
338
+ if splits.numel() == 1:
339
+ return tensor
340
+
341
+ if mask is None:
342
+ # int16 supports length up to 32767
343
+ dtype = (
344
+ torch.int16
345
+ if tensor.shape[-1] < torch.iinfo(torch.int16).max
346
+ else torch.int32
347
+ )
348
+ arange = torch.arange(
349
+ tensor.shape[-1], device=tensor.device, dtype=dtype
350
+ ).unsqueeze(0)
351
+ mask = arange < splits.unsqueeze(1)
352
+
353
+ return tensor[mask]
354
+
355
+
356
def _get_num_per_traj_init(is_init):
    """Like _get_num_per_traj, but with is_init signal.

    A trajectory starting at step t implies that step t-1 ended one, so the
    init flags are shifted one step to the left to build a done-like signal.
    """
    terminal = torch.zeros_like(is_init)
    shifted_init = is_init[..., 1:]
    terminal[..., :-1][shifted_init] = 1
    return _get_num_per_traj(terminal)
@@ -0,0 +1,17 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .loggers import CSVLogger, MLFlowLogger, TensorboardLogger, WandbLogger
7
+ from .recorder import PixelRenderTransform, TensorDictRecorder, VideoRecorder
8
+
9
+ __all__ = [
10
+ "CSVLogger",
11
+ "MLFlowLogger",
12
+ "TensorboardLogger",
13
+ "WandbLogger",
14
+ "PixelRenderTransform",
15
+ "TensorDictRecorder",
16
+ "VideoRecorder",
17
+ ]
@@ -0,0 +1,23 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .common import Logger
7
+
8
+ from .csv import CSVLogger
9
+ from .mlflow import MLFlowLogger
10
+ from .tensorboard import TensorboardLogger
11
+ from .utils import generate_exp_name, get_logger
12
+
13
+ from .wandb import WandbLogger
14
+
15
+ __all__ = [
16
+ "Logger",
17
+ "CSVLogger",
18
+ "MLFlowLogger",
19
+ "TensorboardLogger",
20
+ "generate_exp_name",
21
+ "get_logger",
22
+ "WandbLogger",
23
+ ]
@@ -0,0 +1,48 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import abc
8
+ from collections.abc import Sequence
9
+
10
+ from torch import Tensor
11
+
12
+
13
+ __all__ = ["Logger"]
14
+
15
+
16
class Logger:
    """A template for loggers.

    Concrete loggers provide a backend ``experiment`` object (built once in
    ``__init__`` via :meth:`_create_experiment`) and implement the ``log_*``
    methods below.
    """

    def __init__(self, exp_name: str, log_dir: str) -> None:
        self.exp_name = exp_name
        self.log_dir = log_dir
        # Eagerly build the backend handle so subclasses can use it right away.
        self.experiment = self._create_experiment()

    @abc.abstractmethod
    def _create_experiment(self) -> Experiment:  # noqa: F821
        """Instantiates and returns the backend-specific experiment object."""
        ...

    @abc.abstractmethod
    def log_scalar(self, name: str, value: float, step: int | None = None) -> None:
        """Records a single scalar ``value`` under ``name`` at ``step``."""
        ...

    @abc.abstractmethod
    def log_video(
        self, name: str, video: Tensor, step: int | None = None, **kwargs
    ) -> None:
        """Records a video tensor under ``name`` at ``step``."""
        ...

    @abc.abstractmethod
    def log_hparams(self, cfg: DictConfig | dict) -> None:  # noqa: F821
        """Records the experiment hyperparameters."""
        ...

    @abc.abstractmethod
    def __repr__(self) -> str:
        ...

    @abc.abstractmethod
    def log_histogram(self, name: str, data: Sequence, **kwargs):
        """Records a histogram of ``data`` under ``name``."""
        ...
@@ -0,0 +1,226 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import os
8
+ from collections import defaultdict
9
+ from collections.abc import Sequence
10
+ from pathlib import Path
11
+
12
+ import tensordict.utils
13
+ import torch
14
+ from tensordict import MemoryMappedTensor
15
+ from torch import Tensor
16
+
17
+ from .common import Logger
18
+
19
+
20
class CSVExperiment:
    """A CSV logger experiment class.

    Persists scalars, videos and texts under ``log_dir`` in the ``scalars/``,
    ``videos/`` and ``texts/`` sub-directories respectively.

    Args:
        log_dir (str): directory where the experiment files are written.

    Keyword Args:
        video_format (str, optional): format used by :meth:`add_video`. Must be
            one of ``"pt"``, ``"memmap"`` or ``"mp4"``. Defaults to ``"pt"``.
        video_fps (int, optional): frames-per-second used when writing ``"mp4"``
            videos. Defaults to ``30``.
    """

    # file extension written on disk for each supported video format
    _VIDEO_EXTENSIONS = {"pt": ".pt", "memmap": ".memmap", "mp4": ".mp4"}

    def __init__(self, log_dir: str, *, video_format="pt", video_fps: int = 30):
        self.scalars = defaultdict(list)
        self.videos_counter = defaultdict(int)
        self.text_counter = defaultdict(int)
        self.log_dir = log_dir
        self.video_format = video_format
        self.video_fps = video_fps
        os.makedirs(self.log_dir, exist_ok=True)
        os.makedirs(os.path.join(self.log_dir, "scalars"), exist_ok=True)
        os.makedirs(os.path.join(self.log_dir, "videos"), exist_ok=True)
        os.makedirs(os.path.join(self.log_dir, "texts"), exist_ok=True)

        # one open file handle per destination path; closed in __del__
        self.files = {}

    def add_scalar(self, name: str, value: float, global_step: int | None = None):
        """Appends ``(global_step, value)`` to the in-memory record and to ``scalars/<name>.csv``.

        If ``global_step`` is None, the next index in the series is used.
        """
        if global_step is None:
            global_step = len(self.scalars[name])
        value = float(value)
        self.scalars[name].append((global_step, value))
        filepath = os.path.join(self.log_dir, "scalars", "".join([name, ".csv"]))
        if filepath not in self.files:
            # `name` may contain sub-directories: create them on demand.
            os.makedirs(Path(filepath).parent, exist_ok=True)
            self.files[filepath] = open(filepath, "a+")
        fd = self.files[filepath]
        fd.write(",".join([str(global_step), str(value)]) + "\n")
        fd.flush()

    def add_video(self, tag, vid_tensor, global_step: int | None = None, **kwargs):
        """Writes a video on a file on disk.

        The video format can be one of

        - `"pt"`: uses :func:`~torch.save` to save the video tensor);
        - `"memmap"`: saved the file as memory-mapped array (reading this file will require
          the dtype and shape to be known at read time);
        - `"mp4"`: saves the file as an `.mp4` file using torchvision :func:`~torchvision.io.write_video`
          API. Any ``kwargs`` passed to ``add_video`` will be transmitted to ``write_video``.
          These include ``preset``, ``crf`` and others.
          See ffmpeg's doc (https://trac.ffmpeg.org/wiki/Encode/H.264) for some more information of the video format options.

        """
        if global_step is None:
            global_step = self.videos_counter[tag]
            self.videos_counter[tag] += 1
        # Validate the format once, up front (the original code re-checked it
        # in an unreachable branch further down).
        extension = self._VIDEO_EXTENSIONS.get(self.video_format)
        if extension is None:
            raise ValueError(
                f"Unknown video format {self.video_format}. Must be one of 'pt', 'memmap' or 'mp4'."
            )

        filepath = os.path.join(
            self.log_dir, "videos", "_".join([tag, str(global_step)]) + extension
        )
        os.makedirs(Path(str(filepath)).parent, exist_ok=True)
        if self.video_format == "pt":
            torch.save(vid_tensor, filepath)
        elif self.video_format == "memmap":
            MemoryMappedTensor.from_tensor(vid_tensor, filename=filepath)
        else:  # "mp4" -- guaranteed by the extension lookup above
            import torchvision

            if vid_tensor.shape[-3] not in (3, 1):
                raise RuntimeError(
                    "expected the video tensor to be of format [T, C, H, W] but the third channel "
                    f"starting from the end isn't in (1, 3) but is {vid_tensor.shape[-3]}."
                )
            if vid_tensor.ndim > 4:
                # collapse any leading batch dims into the time dimension
                vid_tensor = vid_tensor.flatten(0, vid_tensor.ndim - 4)
            # [T, C, H, W] -> [T, H, W, C] as expected by write_video
            vid_tensor = vid_tensor.permute((0, 2, 3, 1))
            # broadcast grayscale videos to 3 channels
            vid_tensor = vid_tensor.expand(*vid_tensor.shape[:-1], 3)
            kwargs.setdefault("fps", self.video_fps)
            torchvision.io.write_video(filepath, vid_tensor, **kwargs)

    def add_text(self, tag, text, global_step: int | None = None):
        """Writes ``text`` to ``texts/<tag><global_step>.txt``.

        If ``global_step`` is None, a per-tag counter is used and incremented.
        """
        if global_step is None:
            # Fix: use the text counter here -- the previous implementation
            # mistakenly read and incremented ``videos_counter``.
            global_step = self.text_counter[tag]
            self.text_counter[tag] += 1
        filepath = os.path.join(
            self.log_dir, "texts", "".join([tag, str(global_step)]) + ".txt"
        )
        if filepath not in self.files:
            os.makedirs(Path(filepath).parent, exist_ok=True)
            self.files[filepath] = open(filepath, "w+")
        fd = self.files[filepath]
        fd.writelines(text)
        fd.flush()

    def __repr__(self) -> str:
        return f"CSVExperiment(log_dir={self.log_dir})"

    def __del__(self):
        # Best-effort close of every handle opened by this experiment.
        for val in getattr(self, "files", {}).values():
            val.close()
129
+
130
+
131
class CSVLogger(Logger):
    """A minimal-dependency CSV logger.

    Args:
        exp_name (str): The name of the experiment.
        log_dir (str or Path, optional): where the experiment should be saved.
            Defaults to ``<cur_dir>/csv_logs``.
        video_format (str, optional): how videos should be saved when calling :meth:`~torchrl.record.loggers.csv.CSVExperiment.add_video`. Must be one of
            ``"pt"`` (video saved as a `video_<tag>_<step>.pt` file with torch.save),
            ``"memmap"`` (video saved as a `video_<tag>_<step>.memmap` file with :class:`~tensordict.MemoryMappedTensor`),
            ``"mp4"`` (video saved as a `video_<tag>_<step>.mp4` file, requires torchvision to be installed).
            Defaults to ``"pt"``.
        video_fps (int, optional): the video frames-per-seconds if `video_format="mp4"`. Defaults to 30.

    """

    experiment: CSVExperiment

    def __init__(
        self,
        exp_name: str,
        log_dir: str | None = None,
        video_format: str = "pt",
        video_fps: int = 30,
    ) -> None:
        if log_dir is None:
            log_dir = "csv_logs"
        # The video settings must be set before super().__init__, which
        # eagerly calls _create_experiment.
        self.video_format = video_format
        self.video_fps = video_fps
        super().__init__(exp_name=exp_name, log_dir=log_dir)
        self._has_imported_moviepy = False

    def _create_experiment(self) -> CSVExperiment:
        """Creates a CSV experiment."""
        log_dir = str(os.path.join(self.log_dir, self.exp_name))
        return CSVExperiment(
            log_dir, video_format=self.video_format, video_fps=self.video_fps
        )

    def log_scalar(self, name: str, value: float, step: int | None = None) -> None:
        """Logs a scalar value to a CSV file.

        Args:
            name (str): The name of the scalar.
            value (float): The value of the scalar.
            step (int, optional): The step at which the scalar is logged. Defaults to None.
        """
        self.experiment.add_scalar(name, value, global_step=step)

    def log_video(
        self, name: str, video: Tensor, step: int | None = None, **kwargs
    ) -> None:
        """Log videos inputs to a .pt (or other format) file.

        Args:
            name (str): The name of the video.
            video (Tensor): The video to be logged.
            step (int, optional): The step at which the video is logged. Defaults to None.
            **kwargs: other kwargs passed to the underlying video logger.

        .. note:: If the video format is `mp4`, many more arguments can be passed to the :meth:`~torchvision.io.write_video`
            function.
            For more information on video logging with :class:`~torchrl.record.loggers.csv.CSVLogger`,
            see the :meth:`~torchrl.record.loggers.csv.CSVExperiment.add_video` documentation.
        """
        # check for correct format of the video tensor ((N), T, C, H, W)
        # check that the color channel (C) is either 1 or 3
        if video.dim() != 5 or video.size(dim=2) not in {1, 3}:
            raise Exception(
                "Wrong format of the video tensor. Should be ((N), T, C, H, W)"
            )
        self.experiment.add_video(
            tag=name,
            vid_tensor=video,
            global_step=step,
            **kwargs,
        )

    def log_hparams(self, cfg: DictConfig | dict) -> None:  # noqa: F821
        """Logs the hyperparameters of the experiment.

        Args:
            cfg (DictConfig or dict): The configuration of the experiment.
        """
        txt = "\n".join([f"{k}: {val}" for k, val in sorted(cfg.items())])
        self.experiment.add_text("hparams", txt)

    def __repr__(self) -> str:
        return f"CSVLogger(exp_name={self.exp_name}, experiment={self.experiment.__repr__()})"

    def log_histogram(self, name: str, data: Sequence, **kwargs):
        """Histograms are not supported by the CSV backend."""
        # Fix: error message previously misspelled "csv" as "cvs".
        raise NotImplementedError("Logging histograms in csv is not permitted.")

    def print_log_dir(self):
        """Prints the log directory content."""
        tensordict.utils.print_directory_tree(self.log_dir)