torchrl 0.11.0__cp314-cp314-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/modules/distributions/discrete.py
@@ -0,0 +1,908 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from collections.abc import Sequence

from enum import Enum
from functools import wraps

import torch
import torch.distributions as D
import torch.nn.functional as F
from tensordict.utils import expand_as_right

from torch.distributions.utils import lazy_property, logits_to_probs, probs_to_logits

__all__ = [
    "OneHotCategorical",
    "MaskedCategorical",
    "Ordinal",
    "OneHotOrdinal",
    "LLMMaskedCategorical",
]


def _treat_categorical_params(
    params: torch.Tensor | None = None,
) -> torch.Tensor | None:
    if params is None:
        return None
    if params.shape[-1] == 1:
        params = params[..., 0]
    return params


def rand_one_hot(values: torch.Tensor, do_softmax: bool = True) -> torch.Tensor:
    if do_softmax:
        values = values.softmax(-1)
    out = values.cumsum(-1) > torch.rand_like(values[..., :1])
    out = (out.cumsum(-1) == 1).to(torch.long)
    return out


class _one_hot_wrapper:
    def __init__(self, parent_dist):
        self.parent_dist = parent_dist

    def __call__(self, func):
        @wraps(func)
        def wrapped(_self, *args, **kwargs):
            out = getattr(self.parent_dist, func.__name__)(_self, *args, **kwargs)
            n = _self.num_samples
            return torch.nn.functional.one_hot(out, n)

        return wrapped


class ReparamGradientStrategy(Enum):
    PassThrough = 1
    RelaxedOneHot = 2


class OneHotCategorical(D.Categorical):
    """One-hot categorical distribution.

    This class behaves exactly as torch.distributions.Categorical except that it reads and produces one-hot encodings
    of the discrete tensors.

    Args:
        logits (torch.Tensor): event log probabilities (unnormalized)
        probs (torch.Tensor): event probabilities
        grad_method (ReparamGradientStrategy, optional): strategy to gather
            reparameterized samples.
            ``ReparamGradientStrategy.PassThrough`` will compute the sample gradients
            by using the softmax valued log-probability as a proxy to the
            sample gradients.
            ``ReparamGradientStrategy.RelaxedOneHot`` will use
            :class:`torch.distributions.RelaxedOneHotCategorical` to sample from the distribution.

    Examples:
        >>> torch.manual_seed(0)
        >>> logits = torch.randn(4)
        >>> dist = OneHotCategorical(logits=logits)
        >>> print(dist.rsample((3,)))
        tensor([[1., 0., 0., 0.],
                [0., 0., 0., 1.],
                [1., 0., 0., 0.]])

    """

    num_params: int = 1

    # This is to make the compiler happy, see https://github.com/pytorch/pytorch/issues/140266
    @lazy_property
    def logits(self):
        return probs_to_logits(self.probs)

    @lazy_property
    def probs(self):
        return logits_to_probs(self.logits)

    def __init__(
        self,
        logits: torch.Tensor | None = None,
        probs: torch.Tensor | None = None,
        grad_method: ReparamGradientStrategy = ReparamGradientStrategy.PassThrough,
        **kwargs,
    ) -> None:
        logits = _treat_categorical_params(logits)
        probs = _treat_categorical_params(probs)
        self.grad_method = grad_method
        super().__init__(probs=probs, logits=logits, **kwargs)
        # Get num_samples from logits or probs shape
        if logits is not None:
            self.num_samples = logits.shape[-1]
        else:
            self.num_samples = probs.shape[-1]

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        return super().log_prob(value.argmax(dim=-1))

    @property
    def mode(self) -> torch.Tensor:
        if hasattr(self, "logits"):
            return (self.logits == self.logits.max(-1, True)[0]).to(torch.long)
        else:
            return (self.probs == self.probs.max(-1, True)[0]).to(torch.long)

    @property
    def deterministic_sample(self):
        return self.mode

    def entropy(self):
        min_real = torch.finfo(self.logits.dtype).min
        logits = torch.clamp(self.logits, min=min_real)
        p_log_p = logits * self.probs
        return -p_log_p.sum(-1)

    @_one_hot_wrapper(D.Categorical)
    def sample(self, sample_shape: torch.Size | Sequence | None = None) -> torch.Tensor:
        ...

    def rsample(self, sample_shape: torch.Size | Sequence = None) -> torch.Tensor:
        if sample_shape is None:
            sample_shape = torch.Size([])
        if hasattr(self, "logits") and self.logits is not None:
            logits = self.logits
            probs = None
        else:
            logits = None
            probs = self.probs
        if self.grad_method == ReparamGradientStrategy.RelaxedOneHot:
            d = D.relaxed_categorical.RelaxedOneHotCategorical(
                1.0, probs=probs, logits=logits
            )
            out = d.rsample(sample_shape)
            out.data.copy_((out == out.max(-1)[0].unsqueeze(-1)).to(out.dtype))
            return out
        elif self.grad_method == ReparamGradientStrategy.PassThrough:
            if logits is not None:
                probs = self.probs
            else:
                probs = torch.softmax(self.logits, dim=-1)
            out = self.sample(sample_shape)
            out = out + probs - probs.detach()
            return out
        else:
            raise ValueError(
                f"Unknown reparameterization strategy {self.grad_method}."
            )


class MaskedCategorical(D.Categorical):
    """MaskedCategorical distribution.

    Reference:
        https://www.tensorflow.org/agents/api_docs/python/tf_agents/distributions/masked/MaskedCategorical

    Args:
        logits (torch.Tensor): event log probabilities (unnormalized)
        probs (torch.Tensor): event probabilities. If provided, the probabilities
            corresponding to masked items will be zeroed and the probability
            re-normalized along its last dimension.

    Keyword Args:
        mask (torch.Tensor): A boolean mask of the same shape as ``logits``/``probs``
            where ``False`` entries are the ones to be masked. Alternatively,
            if ``sparse_mask`` is True, it represents the list of valid indices
            in the distribution. Exclusive with ``indices``.
        indices (torch.Tensor): A dense index tensor representing which actions
            must be taken into account. Exclusive with ``mask``.
        neg_inf (:obj:`float`, optional): The log-probability value allocated to
            invalid (out-of-mask) indices. Defaults to -inf.
        padding_value: The padding value in the mask tensor. When
            sparse_mask == True, the padding_value will be ignored.
        use_cross_entropy (bool, optional): For faster computation of the log-probability,
            the cross_entropy loss functional can be used. Defaults to ``True``.
        padding_side (str, optional): The side of the padding. Defaults to ``"left"``.

    Examples:
        >>> torch.manual_seed(0)
        >>> logits = torch.randn(4) / 100  # almost equal probabilities
        >>> mask = torch.tensor([True, False, True, True])
        >>> dist = MaskedCategorical(logits=logits, mask=mask)
        >>> sample = dist.sample((10,))
        >>> print(sample)  # no `1` in the sample
        tensor([2, 3, 0, 2, 2, 0, 2, 0, 2, 2])
        >>> print(dist.log_prob(sample))
        tensor([-1.1203, -1.0928, -1.0831, -1.1203, -1.1203, -1.0831, -1.1203, -1.0831,
                -1.1203, -1.1203])
        >>> print(dist.log_prob(torch.ones_like(sample)))
        tensor([-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf])
        >>> # with probabilities
        >>> prob = torch.ones(10)
        >>> prob = prob / prob.sum()
        >>> mask = torch.tensor([False] + 9 * [True])  # first outcome is masked
        >>> dist = MaskedCategorical(probs=prob, mask=mask)
        >>> print(dist.log_prob(torch.arange(10)))
        tensor([   -inf, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972,
                -2.1972, -2.1972])
    """

    @lazy_property
    def logits(self):
        return probs_to_logits(self.probs)

    @lazy_property
    def probs(self):
        return logits_to_probs(self.logits)

    def __init__(
        self,
        logits: torch.Tensor | None = None,
        probs: torch.Tensor | None = None,
        *,
        mask: torch.Tensor | None = None,
        indices: torch.Tensor | None = None,
        neg_inf: float = float("-inf"),
        padding_value: int | None = None,
        use_cross_entropy: bool = True,
        padding_side: str = "left",
    ) -> None:
        if not ((mask is None) ^ (indices is None)):
            raise ValueError(
                f"A ``mask`` or some ``indices`` must be provided for {type(self)}, but not both."
            )
        if mask is None:
            mask = indices
            sparse_mask = True
        else:
            sparse_mask = False

        if probs is not None:
            if logits is not None:
                raise ValueError(
                    "Either `probs` or `logits` must be specified, but not both."
                )
            # unnormalized logits
            probs = probs.clone()
            if mask.dtype == torch.bool:
                probs[~mask] = 0
            else:
                probs = torch.scatter(
                    torch.zeros_like(probs), -1, indices, probs.gather(-1, indices)
                )
            probs = probs / probs.sum(-1, keepdim=True)
            logits = probs.log()
        num_samples = logits.shape[-1]
        self.use_cross_entropy = use_cross_entropy
        logits = self._mask_logits(
            logits,
            mask,
            neg_inf=neg_inf,
            sparse_mask=sparse_mask,
            padding_value=padding_value,
        )
        self.neg_inf = neg_inf
        self._mask = mask
        self._sparse_mask = sparse_mask
        self._padding_value = padding_value
        self._padding_side = padding_side
        super().__init__(logits=logits)
        self.num_samples = num_samples

    @property
    def padding_value(self):
        """Padding value of the distribution mask.

        If the padding value is not set, it will be inferred from the logits.
        """
        return self._padding_value if self._padding_value is not None else 0

    @property
    def padding_side(self):
        return self._padding_side

    @property
    def mask(self):
        if self._sparse_mask:
            raise ValueError("MaskedCategorical.mask does not support sparse masks")
        return self._mask

    def entropy(self):
        """Compute the entropy of the distribution.

        For masked distributions, we only consider the entropy over the valid (unmasked) outcomes.
        Invalid outcomes have zero probability and don't contribute to entropy.
        """
        min_real = torch.finfo(self.logits.dtype).min

        # Clamp logits to avoid numerical issues
        logits = self.logits
        if self._mask.dtype is torch.bool:
            mask = expand_as_right(self._mask, logits)
            mask = (~mask) | (~logits.isfinite())
            logits = torch.masked_fill(logits, mask, min_real)
        else:
            # logits are already masked
            pass
        logits = logits - logits.logsumexp(-1, keepdim=True)

        # Get probabilities and mask them
        probs = logits.exp()

        # Compute entropy only for valid outcomes
        p_log_p = logits * probs
        return -p_log_p.sum(-1)

    def sample(
        self, sample_shape: torch.Size | Sequence[int] | None = None
    ) -> torch.Tensor:
        if sample_shape is None:
            sample_shape = torch.Size()
        else:
            sample_shape = torch.Size(sample_shape)

        ret = super().sample(sample_shape)
        if not self._sparse_mask:
            return ret

        size = ret.size()
        outer_dim = sample_shape.numel()
        inner_dim = self._mask.shape[:-1].numel()
        idx_3d = self._mask.expand(outer_dim, inner_dim, -1)
        ret = idx_3d.gather(dim=-1, index=ret.view(outer_dim, inner_dim, 1))
        return ret.reshape(size)

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        if not self._sparse_mask:
            if self.use_cross_entropy:
                logits = self.logits
                if logits.ndim > 2:
                    # Bring channels in 2nd dim
                    logits = logits.permute(0, -1, *range(1, logits.ndim - 1))
                original_value_shape = None
                if logits.ndim == 1 and value.ndim >= 1:
                    if value.ndim >= 2:
                        original_value_shape = value.shape
                        value = value.flatten()
                    logits = logits.unsqueeze(0).expand(value.shape + logits.shape)
                result = -torch.nn.functional.cross_entropy(logits, value, reduce=False)
                if original_value_shape is not None:
                    result = result.unflatten(0, original_value_shape)
            else:
                result = super().log_prob(value)
            result = torch.where(torch.isfinite(result), result, self.neg_inf)
            return result

        idx_3d = self._mask.view(1, -1, self._num_events)
        val_3d = value.view(-1, idx_3d.size(1), 1)
        mask = idx_3d == val_3d
        idx = mask.int().argmax(dim=-1, keepdim=True)
        idx = idx.view_as(value)
        if self.use_cross_entropy:
            logits = self.logits
            if logits.ndim > 2:
                # Bring channels in 2nd dim
                logits = logits.transpose(-1, 1)
            # possible shapes:
            # Don't work with cross_entropy (missing batch dimension)
            #   logits.shape = (C,) and idx.shape = (B,)
            #   logits.shape = (C,) and idx.shape = (B0, B1, ...) => requires flattening of idx, only one batch dimension
            # work with cross_entropy:
            #   logits.shape = (B, C) and idx.shape = (B,)
            #   logits.shape = (B, C, d1, d2, ...) and idx.shape = (B, d1, d2, ...)
            original_idx_shape = None
            if logits.ndim == 1 and idx.ndim >= 1:
                if idx.ndim >= 2:
                    original_idx_shape = idx.shape
                    idx = idx.flatten()
                logits = logits.unsqueeze(0).expand(idx.shape + logits.shape)
            ret = -torch.nn.functional.cross_entropy(logits, idx, reduce=False)
            if original_idx_shape is not None:
                ret = ret.unflatten(0, original_idx_shape)
        else:
            ret = super().log_prob(idx)
        # Fill masked values with neg_inf.
        ret = ret.view_as(val_3d)
        ret = ret.masked_fill(
            torch.logical_not(mask.any(dim=-1, keepdim=True)), self.neg_inf
        )
        return ret.view_as(value)

    @staticmethod
    def _mask_logits(
        logits: torch.Tensor,
        mask: torch.Tensor | None = None,
        neg_inf: float = float("-inf"),
        sparse_mask: bool = False,
        padding_value: int | None = None,
    ) -> torch.Tensor:
        if mask is None:
            return logits

        if not sparse_mask:
            return logits.masked_fill(~mask, neg_inf)

        if padding_value is not None:
            padding_mask = mask == padding_value
            if padding_value != 0:
                # Avoid invalid indices in mask.
                mask = mask.masked_fill(padding_mask, 0)
        logits = logits.gather(dim=-1, index=mask)
        if padding_value is not None:
            logits.masked_fill_(padding_mask, neg_inf)
        return logits

    @property
    def deterministic_sample(self):
        return self.mode


class MaskedOneHotCategorical(MaskedCategorical):
    """MaskedOneHotCategorical distribution.

    Reference:
        https://www.tensorflow.org/agents/api_docs/python/tf_agents/distributions/masked/MaskedCategorical

    Args:
        logits (torch.Tensor): event log probabilities (unnormalized)
        probs (torch.Tensor): event probabilities. If provided, the probabilities
            corresponding to masked items will be zeroed and the probability
            re-normalized along its last dimension.

    Keyword Args:
        mask (torch.Tensor): A boolean mask of the same shape as ``logits``/``probs``
            where ``False`` entries are the ones to be masked. Alternatively,
            if ``sparse_mask`` is True, it represents the list of valid indices
            in the distribution. Exclusive with ``indices``.
        indices (torch.Tensor): A dense index tensor representing which actions
            must be taken into account. Exclusive with ``mask``.
        neg_inf (:obj:`float`, optional): The log-probability value allocated to
            invalid (out-of-mask) indices. Defaults to -inf.
        padding_value: The padding value in the mask tensor. When
            sparse_mask == True, the padding_value will be ignored.
        grad_method (ReparamGradientStrategy, optional): strategy to gather
            reparameterized samples.
            ``ReparamGradientStrategy.PassThrough`` will compute the sample gradients
            by using the softmax valued log-probability as a proxy to the
            sample gradients.
            ``ReparamGradientStrategy.RelaxedOneHot`` will use
            :class:`torch.distributions.RelaxedOneHotCategorical` to sample from the distribution.

    Examples:
        >>> torch.manual_seed(0)
        >>> logits = torch.randn(4) / 100  # almost equal probabilities
        >>> mask = torch.tensor([True, False, True, True])
        >>> dist = MaskedOneHotCategorical(logits=logits, mask=mask)
        >>> sample = dist.sample((10,))
        >>> print(sample)  # index 1 is never drawn
        tensor([[0, 0, 1, 0],
                [0, 0, 0, 1],
                [1, 0, 0, 0],
                [0, 0, 1, 0],
                [0, 0, 1, 0],
                [1, 0, 0, 0],
                [0, 0, 1, 0],
                [1, 0, 0, 0],
                [0, 0, 1, 0],
                [0, 0, 1, 0]])
        >>> print(dist.log_prob(sample))
        tensor([-1.1203, -1.0928, -1.0831, -1.1203, -1.1203, -1.0831, -1.1203, -1.0831,
                -1.1203, -1.1203])
        >>> sample_non_valid = torch.zeros_like(sample)
        >>> sample_non_valid[..., 1] = 1
        >>> print(dist.log_prob(sample_non_valid))
        tensor([-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf])
        >>> # with probabilities
        >>> prob = torch.ones(10)
        >>> prob = prob / prob.sum()
        >>> mask = torch.tensor([False] + 9 * [True])  # first outcome is masked
        >>> dist = MaskedOneHotCategorical(probs=prob, mask=mask)
        >>> s = torch.arange(10)
        >>> s = torch.nn.functional.one_hot(s, 10)
        >>> print(dist.log_prob(s))
        tensor([   -inf, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972,
                -2.1972, -2.1972])
    """

    @lazy_property
    def logits(self):
        return probs_to_logits(self.probs)

    @lazy_property
    def probs(self):
        return logits_to_probs(self.logits)

    def __init__(
        self,
        logits: torch.Tensor | None = None,
        probs: torch.Tensor | None = None,
        mask: torch.Tensor = None,
        indices: torch.Tensor = None,
        neg_inf: float = float("-inf"),
        padding_value: int | None = None,
        grad_method: ReparamGradientStrategy = ReparamGradientStrategy.PassThrough,
    ) -> None:
        self.grad_method = grad_method
        super().__init__(
            logits=logits,
            probs=probs,
            mask=mask,
            indices=indices,
            neg_inf=neg_inf,
            padding_value=padding_value,
        )

    @_one_hot_wrapper(MaskedCategorical)
    def sample(
        self, sample_shape: torch.Size | Sequence[int] | None = None
    ) -> torch.Tensor:
        ...

    @property
    def deterministic_sample(self):
        return self.mode

    @property
    def mode(self) -> torch.Tensor:
        if hasattr(self, "logits"):
            return (self.logits == self.logits.max(-1, True)[0]).to(torch.long)
        else:
            return (self.probs == self.probs.max(-1, True)[0]).to(torch.long)

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        return super().log_prob(value.argmax(dim=-1))

    def rsample(self, sample_shape: torch.Size | Sequence = None) -> torch.Tensor:
        if sample_shape is None:
            sample_shape = torch.Size([])
        if hasattr(self, "logits") and self.logits is not None:
            logits = self.logits
            probs = None
        else:
            logits = None
            probs = self.probs
        if self.grad_method == ReparamGradientStrategy.RelaxedOneHot:
            if self._sparse_mask:
                if probs is not None:
                    probs_extended = torch.full(
                        (*probs.shape[:-1], self.num_samples),
                        0,
                        device=probs.device,
                        dtype=probs.dtype,
                    )
                    probs_extended = torch.scatter(
                        probs_extended, -1, self._mask, probs
                    )
                    logits_extended = None
                else:
                    probs_extended = torch.full(
                        (*logits.shape[:-1], self.num_samples),
                        self.neg_inf,
                        device=logits.device,
                        dtype=logits.dtype,
                    )
                    logits_extended = torch.scatter(
                        probs_extended, -1, self._mask, logits
                    )
                    probs_extended = None
            else:
                probs_extended = probs
                logits_extended = logits

            d = D.relaxed_categorical.RelaxedOneHotCategorical(
                1.0, probs=probs_extended, logits=logits_extended
            )
            out = d.rsample(sample_shape)
            out.data.copy_((out == out.max(-1)[0].unsqueeze(-1)).to(out.dtype))
            return out
        elif self.grad_method == ReparamGradientStrategy.PassThrough:
            if logits is not None:
                probs = self.probs
            else:
                probs = torch.softmax(self.logits, dim=-1)
            if self._sparse_mask:
                probs_extended = torch.full(
                    (*probs.shape[:-1], self.num_samples),
                    0,
                    device=probs.device,
                    dtype=probs.dtype,
                )
                probs_extended = torch.scatter(probs_extended, -1, self._mask, probs)
            else:
                probs_extended = probs

            out = self.sample(sample_shape)
            out = out + probs_extended - probs_extended.detach()
            return out
        else:
            raise ValueError(
                f"Unknown reparameterization strategy {self.grad_method}."
            )


class Ordinal(D.Categorical):
    """A discrete distribution for learning to sample from finite ordered sets.

    It is defined in contrast with the `Categorical` distribution, which does
    not impose any notion of proximity or ordering over its support's atoms.
    The `Ordinal` distribution explicitly encodes those concepts, which is
    useful for learning discrete sampling from continuous sets. See §5 of
    `Tang & Agrawal, 2020 <https://arxiv.org/pdf/1901.10500.pdf>`_ for details.

    .. note::
        This class is mostly useful when you want to learn a distribution over
        a finite set which is obtained by discretising a continuous set.

    Args:
        scores (torch.Tensor): a tensor of shape [..., N] where N is the size of the set which supports the distributions.
            Typically, the output of a neural network parametrising the distribution.

    Examples:
        >>> num_atoms, num_samples = 5, 20
        >>> mean = (num_atoms - 1) / 2  # Target mean for samples, centered around the middle atom
        >>> torch.manual_seed(42)
        >>> logits = torch.ones((num_atoms), requires_grad=True)
        >>> optimizer = torch.optim.Adam([logits], lr=0.1)
        >>>
        >>> # Perform optimisation loop to minimise deviation from `mean`
        >>> for _ in range(20):
        >>>     sampler = Ordinal(scores=logits)
        >>>     samples = sampler.sample((num_samples,))
        >>>     # Define loss to encourage samples around the mean by penalising deviation from mean
        >>>     loss = torch.mean((samples - mean) ** 2 * sampler.log_prob(samples))
        >>>     loss.backward()
        >>>     optimizer.step()
        >>>     optimizer.zero_grad()
        >>>
        >>> sampler.probs
        tensor([0.0308, 0.1586, 0.4727, 0.2260, 0.1120], ...)
        >>> # Print histogram to observe sample distribution frequency across 5 bins (0, 1, 2, 3, and 4)
        >>> torch.histogram(sampler.sample((1000,)).reshape(-1).float(), bins=num_atoms)
        torch.return_types.histogram(
            hist=tensor([ 24., 158., 478., 228., 112.]),
            bin_edges=tensor([0.0000, 0.8000, 1.6000, 2.4000, 3.2000, 4.0000]))
    """

    def __init__(self, scores: torch.Tensor):
        logits = _generate_ordinal_logits(scores)
        super().__init__(logits=logits)


class OneHotOrdinal(OneHotCategorical):
    """The one-hot version of the :class:`~tensordict.nn.distributions.Ordinal` distribution.

    Args:
        scores (torch.Tensor): a tensor of shape [..., N] where N is the size of the set which supports the distributions.
            Typically, the output of a neural network parametrising the distribution.
    """

    def __init__(self, scores: torch.Tensor):
        logits = _generate_ordinal_logits(scores)
        super().__init__(logits=logits)


def _generate_ordinal_logits(scores: torch.Tensor) -> torch.Tensor:
    """Implements Eq. 4 of `Tang & Agrawal, 2020 <https://arxiv.org/pdf/1901.10500.pdf>`__."""
    # Assigns Bernoulli-like probabilities for each class in the set
    log_probs = F.logsigmoid(scores)
    complementary_log_probs = F.logsigmoid(-scores)

    # Total log-probability for being "larger than k"
    larger_than_log_probs = log_probs.cumsum(dim=-1)

    # Total log-probability for being "smaller than k"
    smaller_than_log_probs = (
        complementary_log_probs.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
        - complementary_log_probs
    )

    return larger_than_log_probs + smaller_than_log_probs


class LLMMaskedCategorical(D.Distribution):
    """LLM-optimized masked categorical distribution.

    This class provides a more memory-efficient approach for LLM training by:
    1. Using ignore_index=-100 for log_prob computation (no masking overhead)
    2. Using traditional masking for sampling operations

    This is particularly beneficial for large vocabulary sizes where masking
    all logits can be memory-intensive.

    Args:
        logits (torch.Tensor): Event log probabilities (unnormalized), shape [B, T, C].
            - *B*: batch size (optional)
            - T: sequence length
            - C: vocabulary size (number of classes)
        mask (torch.Tensor): Boolean mask indicating valid positions/tokens.
            - If shape [*B, T]: position-level masking. True means the position is valid (all tokens allowed).
            - If shape [*B, T, C]: token-level masking. True means the token is valid at that position.

            .. warning:: Token-level masking is considerably more memory-intensive than position-level masking.
                Only use this if you need to mask tokens.

        ignore_index (int, optional): Index to ignore in log_prob computation. Defaults to -100.

    Input shapes:
        - logits: [*B, T, C] (required)
        - mask: [*B, T] (position-level) or [*B, T, C] (token-level)
        - tokens (for log_prob): [*B, T] (token indices, with ignore_index for masked positions)

    Use cases:
        1. **Position-level masking**
            >>> logits = torch.randn(2, 10, 50000)  # [B=2, T=10, C=50000]
            >>> mask = torch.ones(2, 10, dtype=torch.bool)  # [B, T]
            >>> mask[0, :5] = False  # mask first 5 positions of first sequence
            >>> dist = LLMMaskedCategorical(logits=logits, mask=mask)
            >>> tokens = torch.randint(0, 50000, (2, 10))  # [B, T]
            >>> tokens[0, :5] = -100  # set masked positions to ignore_index
            >>> log_probs = dist.log_prob(tokens)
            >>> samples = dist.sample()  # [B, T]

        2. **Token-level masking**
            >>> logits = torch.randn(2, 10, 50000)
            >>> mask = torch.ones(2, 10, 50000, dtype=torch.bool)  # [B, T, C]
            >>> mask[0, :5, :1000] = False  # mask first 1000 tokens for first 5 positions
            >>> dist = LLMMaskedCategorical(logits=logits, mask=mask)
            >>> tokens = torch.randint(0, 50000, (2, 10))
            >>> # Optionally, set tokens at fully-masked positions to ignore_index
            >>> log_probs = dist.log_prob(tokens)
            >>> samples = dist.sample()  # [B, T]

    Notes:
        - For log_prob, tokens must be of shape [B, T] and contain valid token indices (0 <= token < C), or ignore_index for masked/ignored positions.
        - For token-level masking, if a token is masked at a given position, log_prob will return -inf for that entry.
        - For position-level masking, if a position is masked (ignore_index), log_prob will return 0.0 for that entry (correct for cross-entropy loss).
        - Sampling always respects the mask (masked tokens/positions are never sampled).

    All documented use cases are covered by tests in test_distributions.py.
    """

    def __init__(
        self,
        logits: torch.Tensor,
        mask: torch.Tensor,
        ignore_index: int = -100,
    ) -> None:
        # Validate shapes
        if logits.shape[:-1] != mask.shape and logits.shape != mask.shape:
            raise ValueError(
                f"Mask shape {mask.shape} must be either logits batch shape {logits.shape[:-1]} "
                f"(for position-level masking) or logits shape {logits.shape} "
                f"(for token-level masking)"
            )

        self._original_logits = logits
        self._mask = mask
        self.ignore_index = ignore_index
        self._position_level_masking = mask.shape == logits.shape[:-1]

        # Create masked logits for sampling (only when needed)
        self._masked_logits = None
        self._masked_dist = None

        # Set up distribution properties
        batch_shape = logits.shape[:-1]
        event_shape = logits.shape[-1:]
        super().__init__(batch_shape=batch_shape, event_shape=event_shape)

    @property
    def _sampling_logits(self):
        """Get masked logits for sampling operations."""
        if self._masked_logits is None:
            # Only create masked logits when needed for sampling
            large_neg = torch.finfo(self._original_logits.dtype).min

            if self._position_level_masking:
                # Position-level masking: expand mask to match logits shape
                mask_expanded = expand_as_right(self._mask, self._original_logits)
                self._masked_logits = self._original_logits.masked_fill(
                    ~mask_expanded, large_neg
                )
            else:
                # Token-level masking: direct masking
                self._masked_logits = self._original_logits.masked_fill(
                    ~self._mask, large_neg
                )
        return self._masked_logits

    @property
    def _sampling_dist(self):
        """Get masked distribution for sampling operations."""
        if self._masked_dist is None:
            self._masked_dist = D.Categorical(logits=self._sampling_logits)
        return self._masked_dist

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """Compute log probabilities using the ignore_index approach.

        This is memory-efficient as it doesn't require masking the logits.
        The value tensor should use ignore_index for masked positions.
        """
        if not self._position_level_masking:
            logits = self.masked_logits
        else:
            # Use cross_entropy with ignore_index for efficiency

            # For position-level masking, keep the default behavior (0.0 for ignore_index)
            # This is correct for cross-entropy loss computation
            # For token-level masking, we need to check if specific tokens are masked

            logits = self._original_logits
            value = value.masked_fill(~self._mask, self.ignore_index)
        if value.ndim > 1:
            # Reshape for cross_entropy: (batch, seq_len, vocab) -> (batch*seq_len, vocab)
            logits_flat = logits.reshape(-1, logits.size(-1))
            value_flat = value.reshape(-1)

            # Compute cross_entropy with ignore_index
            log_probs_flat = -F.cross_entropy(
                logits_flat, value_flat, reduce=False, ignore_index=self.ignore_index
            )

            # Reshape back
            log_probs = log_probs_flat.reshape_as(value)
        else:
            log_probs = -F.cross_entropy(
                logits,
                value,
                reduce=False,
                ignore_index=self.ignore_index,
            )
        return log_probs

    def sample(
        self, sample_shape: torch.Size | Sequence[int] | None = None
    ) -> torch.Tensor:
        """Sample from the distribution using masked logits."""
        if sample_shape is None:
            sample_shape = torch.Size()
        return self._sampling_dist.sample(sample_shape)

    def rsample(
        self, sample_shape: torch.Size | Sequence[int] | None = None
    ) -> torch.Tensor:
        """Reparameterized sampling using masked logits."""
        # This would need to be implemented based on the specific reparameterization strategy
        # For now, fall back to regular sampling
        return self.sample(sample_shape)

    @property
    def mode(self) -> torch.Tensor:
        """Get the mode using masked logits."""
        masked_logits = self._sampling_logits
        return masked_logits.argmax(dim=-1)

    def entropy(self) -> torch.Tensor:
        """Compute entropy using masked logits."""
        return self._sampling_dist.entropy()

    def clear_cache(self):
        """Clear cached masked tensors to free memory."""
        self._masked_logits = None
        self._masked_dist = None

    @property
    def mask(self) -> torch.Tensor:
        """Get the mask."""
        return self._mask

    @property
    def logits(self) -> torch.Tensor:
        """Get the original logits."""
        return self._original_logits

    @property
    def probs(self) -> torch.Tensor:
        """Get probabilities from original logits."""
        return torch.softmax(self._original_logits, dim=-1)

    @property
    def masked_logits(self) -> torch.Tensor:
        """Get the masked logits for sampling operations."""
        return self._sampling_logits

    @property
    def masked_dist(self) -> D.Categorical:
        """Get the masked distribution for sampling operations."""
        return self._sampling_dist

    @property
    def position_level_masking(self) -> bool:
        """Whether the mask is position-level (True) or token-level (False)."""
        return self._position_level_masking
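
The hunk above is the full source of torchrl/modules/distributions/discrete.py. As a quick orientation, the sketch below shows how the two core classes are typically used. It is not part of the wheel contents, and it assumes OneHotCategorical and MaskedCategorical are re-exported from torchrl.modules as listed in __all__; adjust the import path if your install differs.

import torch
from torchrl.modules import MaskedCategorical, OneHotCategorical

torch.manual_seed(0)

# Pass-through reparameterization: rsample() returns one-hot samples through
# which gradients flow back to the logits (ReparamGradientStrategy.PassThrough).
logits = torch.randn(4, requires_grad=True)
dist = OneHotCategorical(logits=logits)
sample = dist.rsample((3,))                   # shape [3, 4], one-hot rows
(sample * torch.arange(4.0)).sum().backward()
assert logits.grad is not None

# Dense boolean mask: index 1 can never be drawn and gets -inf log-probability.
mask = torch.tensor([True, False, True, True])
mdist = MaskedCategorical(logits=torch.zeros(4), mask=mask)
draws = mdist.sample((8,))                    # never contains 1
log_p = mdist.log_prob(torch.tensor([0, 1]))  # approximately [log(1/3), -inf]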