torchrl-0.11.0-cp314-cp314t-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,23 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from __future__ import annotations
+
+ import contextlib
+ import os
+
+ import torch
+
+
+ @contextlib.contextmanager
+ def _cuda_visible_devices(devices: list[torch.device | int]):
+     devices = [torch.device(d).index if not isinstance(d, int) else d for d in devices]
+     CUDA_VISIBLE_DEVICES = os.getenv("CUDA_VISIBLE_DEVICES")
+     os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, devices))
+     yield
+     if CUDA_VISIBLE_DEVICES:
+         os.environ["CUDA_VISIBLE_DEVICES"] = CUDA_VISIBLE_DEVICES
+     else:
+         os.unsetenv("CUDA_VISIBLE_DEVICES")
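The hunk above adds a private context manager that temporarily rewrites `CUDA_VISIBLE_DEVICES` and restores the previous value on exit; by line count it most likely corresponds to `torchrl/modules/llm/utils.py` (+23), the only new 23-line file in the listing. A minimal usage sketch under that assumption — the helper is underscore-prefixed rather than documented public API, and the device indices are illustrative only:

```python
# Sketch only: the import path is inferred from the file listing above, and
# _cuda_visible_devices is a private helper, not documented public API.
import os

import torch
from torchrl.modules.llm.utils import _cuda_visible_devices

# Expose only GPUs 0 and 2 to code running inside the block.
with _cuda_visible_devices([0, torch.device("cuda:2")]):
    print(os.environ["CUDA_VISIBLE_DEVICES"])  # "0,2"
# On exit the previous CUDA_VISIBLE_DEVICES value is restored, or unset if it
# was not set before entering the block.
```

Note that the restore step is not wrapped in `try`/`finally`, so an exception raised inside the `with` block leaves the temporary value in place.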
torchrl/modules/mcts/__init__.py
@@ -0,0 +1,21 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from .scores import (
+     EXP3Score,
+     MCTSScore,
+     MCTSScores,
+     PUCTScore,
+     UCB1TunedScore,
+     UCBScore,
+ )
+
+ __all__ = [
+     "EXP3Score",
+     "MCTSScore",
+     "MCTSScores",
+     "PUCTScore",
+     "UCB1TunedScore",
+     "UCBScore",
+ ]
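This `__init__.py` simply re-exports the score classes from the `scores` submodule. A short import sketch, assuming an installed wheel:

```python
# The package __init__ re-exports the score classes, so both paths resolve to
# the same objects.
from torchrl.modules.mcts import MCTSScores, PUCTScore
from torchrl.modules.mcts.scores import PUCTScore as PUCTScoreFromSubmodule

assert PUCTScore is PUCTScoreFromSubmodule
print([member.name for member in MCTSScores])  # ['PUCT', 'UCB', 'UCB1_TUNED', 'EXP3']
```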
torchrl/modules/mcts/scores.py
@@ -0,0 +1,579 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import functools
+ import math
+ import warnings
+ from abc import abstractmethod
+ from enum import Enum
+
+ import torch
+
+ from tensordict import NestedKey, TensorDictBase
+ from tensordict.nn import TensorDictModuleBase
+
+
+ class MCTSScore(TensorDictModuleBase):
+     """Abstract base class for MCTS score computation modules."""
+
+     @abstractmethod
+     def forward(self, node: TensorDictBase) -> TensorDictBase:
+         pass
+
+
+ class PUCTScore(MCTSScore):
+     """Computes the PUCT (Polynomial Upper Confidence Trees) score for MCTS.
+
+     PUCT is a widely used score in MCTS algorithms, notably in AlphaGo and AlphaZero,
+     to balance exploration and exploitation. It incorporates prior probabilities from a
+     policy network, encouraging exploration of actions deemed promising by the policy,
+     while also considering visit counts and accumulated rewards.
+
+     The formula used is:
+     `score = (win_count / visits) + c * prior_prob * sqrt(total_visits) / (1 + visits)`
+
+     Where:
+     - `win_count`: Sum of rewards (or win counts) for the action.
+     - `visits`: Visit count for the action.
+     - `total_visits`: Visit count of the parent node (N).
+     - `prior_prob`: Prior probability of selecting the action (e.g., from a policy network).
+     - `c`: The exploration constant, controlling the trade-off between exploitation
+       (first term) and exploration (second term).
+
+     Args:
+         c (float): The exploration constant.
+         win_count_key (NestedKey, optional): Key for the tensor in the input `TensorDictBase`
+             containing the sum of rewards (or win counts) for each action.
+             Defaults to "win_count".
+         visits_key (NestedKey, optional): Key for the tensor containing the visit
+             count for each action. Defaults to "visits".
+         total_visits_key (NestedKey, optional): Key for the tensor (or scalar)
+             representing the visit count of the parent node (N). Defaults to "total_visits".
+         prior_prob_key (NestedKey, optional): Key for the tensor containing the
+             prior probabilities for each action. Defaults to "prior_prob".
+         score_key (NestedKey, optional): Key where the calculated PUCT scores
+             will be stored in the output `TensorDictBase`. Defaults to "score".
+
+     Input Keys:
+         - `win_count_key` (torch.Tensor): Tensor of shape (..., num_actions)
+           or matching `visits_key`.
+         - `visits_key` (torch.Tensor): Tensor of shape (..., num_actions). If an action
+           has zero visits, its exploitation term (win_count / visits) will result in NaN
+           if win_count is also zero, or +/-inf if win_count is non-zero. The exploration
+           term will still be valid due to `(1 + visits)`.
+         - `total_visits_key` (torch.Tensor): Scalar or tensor broadcastable to other inputs,
+           representing the parent node's visit count.
+         - `prior_prob_key` (torch.Tensor): Tensor of shape (..., num_actions) containing
+           prior probabilities.
+
+     Output Keys:
+         - `score_key` (torch.Tensor): Tensor of the same shape as `visits_key`, containing
+           the calculated PUCT scores.
+
+     Example:
+         ```python
+         from tensordict import TensorDict
+         from torchrl.modules.mcts.scores import PUCTScore
+
+         # Create a PUCTScore instance
+         puct = PUCTScore(c=1.5)
+
+         # Define a TensorDict with required keys
+         node = TensorDict(
+             {
+                 "win_count": torch.tensor([10.0, 20.0]),
+                 "visits": torch.tensor([5.0, 10.0]),
+                 "total_visits": torch.tensor(50.0),
+                 "prior_prob": torch.tensor([0.6, 0.4]),
+             },
+             batch_size=[],
+         )
+
+         # Compute the PUCT scores
+         result = puct(node)
+         print(result["score"]) # Output: Tensor with PUCT scores
+         ```
+     """
+
+     c: float
+
+     def __init__(
+         self,
+         *,
+         c: float,
+         win_count_key: NestedKey = "win_count",
+         visits_key: NestedKey = "visits",
+         total_visits_key: NestedKey = "total_visits",
+         prior_prob_key: NestedKey = "prior_prob",
+         score_key: NestedKey = "score",
+     ):
+         super().__init__()
+         self.c = c
+         self.win_count_key = win_count_key
+         self.visits_key = visits_key
+         self.total_visits_key = total_visits_key
+         self.prior_prob_key = prior_prob_key
+         self.score_key = score_key
+         self.in_keys = [
+             self.win_count_key,
+             self.prior_prob_key,
+             self.total_visits_key,
+             self.visits_key,
+         ]
+         self.out_keys = [self.score_key]
+
+     def forward(self, node: TensorDictBase) -> TensorDictBase:
+         win_count = node.get(self.win_count_key)
+         visits = node.get(self.visits_key)
+         n_total = node.get(self.total_visits_key)
+         prior_prob = node.get(self.prior_prob_key)
+         # Handle broadcasting for batched inputs
+         if n_total.ndim > 0 and n_total.ndim < visits.ndim:
+             n_total = n_total.unsqueeze(-1)
+         node.set(
+             self.score_key,
+             (win_count / visits) + self.c * prior_prob * n_total.sqrt() / (1 + visits),
+         )
+         return node
+
+
+ class UCBScore(MCTSScore):
+     """Computes the UCB (Upper Confidence Bound) score, specifically UCB1, for MCTS.
+
+     UCB1 is a classic algorithm for the multi-armed bandit problem that balances
+     exploration and exploitation. In MCTS, it's used to select which action to
+     explore from a given node. The score encourages trying actions with high
+     empirical rewards and actions that have been visited less frequently.
+
+     The formula used is:
+     `score = (win_count / visits) + c * sqrt(total_visits) / (1 + visits)`
+
+     Args:
+         c (float): The exploration constant. A common value is `sqrt(2)`.
+         win_count_key (NestedKey, optional): Key for the tensor in the input `TensorDictBase`
+             containing the sum of rewards (or win counts) for each action.
+             Defaults to "win_count".
+         visits_key (NestedKey, optional): Key for the tensor containing the visit
+             count for each action. Defaults to "visits".
+         total_visits_key (NestedKey, optional): Key for the tensor (or scalar)
+             representing the visit count of the parent node (N). This is used in the
+             exploration term. Defaults to "total_visits".
+         score_key (NestedKey, optional): Key where the calculated UCB scores
+             will be stored in the output `TensorDictBase`. Defaults to "score".
+
+     Input Keys:
+         - `win_count_key` (torch.Tensor): Tensor of shape (..., num_actions).
+         - `visits_key` (torch.Tensor): Tensor of shape (..., num_actions).
+         - `total_visits_key` (torch.Tensor): Scalar or tensor broadcastable to other inputs.
+
+     Output Keys:
+         - `score_key` (torch.Tensor): Tensor of the same shape as `visits_key`, containing
+           the calculated UCB scores.
+
+     Example:
+         ```python
+         from tensordict import TensorDict
+         from torchrl.modules.mcts.scores import UCBScore
+
+         # Create a UCBScore instance
+         ucb = UCBScore(c=1.414)
+
+         # Define a TensorDict with required keys
+         node = TensorDict(
+             {
+                 "win_count": torch.tensor([15.0, 25.0]),
+                 "visits": torch.tensor([10.0, 20.0]),
+                 "total_visits": torch.tensor(100.0),
+             },
+             batch_size=[],
+         )
+
+         # Compute the UCB scores
+         result = ucb(node)
+         print(result["score"]) # Output: Tensor with UCB scores
+         ```
+     """
+
+     c: float
+
+     def __init__(
+         self,
+         *,
+         c: float,
+         win_count_key: NestedKey = "win_count",
+         visits_key: NestedKey = "visits",
+         total_visits_key: NestedKey = "total_visits",
+         score_key: NestedKey = "score",
+     ):
+         super().__init__()
+         self.c = c
+         self.win_count_key = win_count_key
+         self.visits_key = visits_key
+         self.total_visits_key = total_visits_key
+         self.score_key = score_key
+         self.in_keys = [self.win_count_key, self.total_visits_key, self.visits_key]
+         self.out_keys = [self.score_key]
+
+     def forward(self, node: TensorDictBase) -> TensorDictBase:
+         win_count = node.get(self.win_count_key)
+         visits = node.get(self.visits_key)
+         n_total = node.get(self.total_visits_key)
+         # Handle broadcasting for batched inputs
+         if n_total.ndim > 0 and n_total.ndim < visits.ndim:
+             n_total = n_total.unsqueeze(-1)
+         node.set(
+             self.score_key,
+             (win_count / visits) + self.c * n_total.sqrt() / (1 + visits),
+         )
+         return node
+
+
+ class EXP3Score(MCTSScore):
+     """Computes action selection probabilities for the EXP3 algorithm in MCTS.
+
+     EXP3 (Exponential-weight algorithm for Exploration and Exploitation) is a bandit
+     algorithm that performs well in adversarial or non-stationary environments.
+     It maintains weights for each action and adjusts them based on received rewards.
+
+     Args:
+         gamma (float, optional): Exploration factor, balancing uniform exploration
+             and exploitation of current weights. Must be in [0, 1]. Defaults to 0.1.
+         weights_key (NestedKey, optional): Key in the input `TensorDictBase` for
+             the tensor containing current action weights. Defaults to "weights".
+         action_prob_key (NestedKey, optional): Key to store the calculated action
+             probabilities. Defaults to "action_prob".
+         score_key (NestedKey, optional): Key where the calculated action probabilities
+             will be stored. Defaults to "score".
+         num_actions_key (NestedKey, optional): Key for the number of available
+             actions (K). Defaults to "num_actions".
+
+     Input Keys:
+         - `weights_key` (torch.Tensor): Tensor of shape (..., num_actions).
+         - `num_actions_key` (int or torch.Tensor): Scalar representing K, the number of actions.
+
+     Output Keys:
+         - `score_key` (torch.Tensor): Tensor of shape (..., num_actions) containing
+           the calculated action probabilities.
+
+     Example:
+         ```python
+         from tensordict import TensorDict
+         from torchrl.modules.mcts.scores import EXP3Score
+
+         # Create an EXP3Score instance
+         exp3 = EXP3Score(gamma=0.1)
+
+         # Define a TensorDict with required keys
+         node = TensorDict(
+             {
+                 "weights": torch.tensor([1.0, 1.0]),
+                 "num_actions": torch.tensor(2),
+             },
+             batch_size=[],
+         )
+
+         # Compute the action probabilities
+         result = exp3(node)
+         print(result["score"]) # Output: Tensor with action probabilities
+         ```
+     """
+
+     def __init__(
+         self,
+         *,
+         gamma: float = 0.1,
+         weights_key: NestedKey = "weights",
+         action_prob_key: NestedKey = "action_prob",
+         reward_key: NestedKey = "reward",
+         score_key: NestedKey = "score",
+         num_actions_key: NestedKey = "num_actions",
+     ):
+         super().__init__()
+         if not 0 <= gamma <= 1:
+             raise ValueError(f"gamma must be between 0 and 1, got {gamma}")
+         self.gamma = gamma
+         self.weights_key = weights_key
+         self.action_prob_key = action_prob_key
+         self.reward_key = reward_key
+         self.score_key = score_key
+         self.num_actions_key = num_actions_key
+
+         self.in_keys = [self.weights_key, self.num_actions_key]
+         self.out_keys = [self.score_key]
+
+     def forward(self, node: TensorDictBase) -> TensorDictBase:
+         num_actions = node.get(self.num_actions_key)
+
+         # Extract scalar value from num_actions (handles batched tensors too)
+         if isinstance(num_actions, torch.Tensor):
+             # For batched tensors, take the first element (all should be same)
+             k = int(num_actions.flatten()[0].item())
+         elif isinstance(num_actions, int):
+             k = num_actions
+         else:
+             raise ValueError(
+                 f"'{self.num_actions_key}' ('num_actions') must be an integer or a tensor."
+             )
+
+         if self.weights_key not in node.keys(include_nested=True):
+             batch_size = node.batch_size
+             weights_shape = (*batch_size, k)
+             weights = torch.ones(weights_shape, device=node.device)
+             node.set(self.weights_key, weights)
+         else:
+             weights = node.get(self.weights_key)
+
+         k_from_weights = weights.shape[-1]
+         if k_from_weights != k:
+             raise ValueError(
+                 f"Shape of weights {weights.shape} implies {k_from_weights} actions, "
+                 f"but num_actions is {k}."
+             )
+
+         sum_weights = torch.sum(weights, dim=-1, keepdim=True)
+         sum_weights = torch.where(
+             sum_weights == 0, torch.ones_like(sum_weights), sum_weights
+         )
+
+         p_i = (1 - self.gamma) * (weights / sum_weights) + (self.gamma / k)
+         node.set(self.score_key, p_i)
+         if self.action_prob_key != self.score_key:
+             node.set(self.action_prob_key, p_i)
+         return node
+
+     def update_weights(
+         self, node: TensorDictBase, action_idx: int, reward: float
+     ) -> None:
+         """Updates the weight of the chosen action based on the reward.
+
+         This method updates the weight of the selected action using the EXP3 algorithm.
+         The weight update formula is:
+         `w_i(t+1) = w_i(t) * exp((gamma / K) * (reward / p_i(t)))`
+
+         Args:
+             node (TensorDictBase): The node containing the current weights and probabilities.
+                 Must include the keys specified by `weights_key` and `score_key`.
+             action_idx (int): The index of the action that was selected.
+             reward (float): The reward received for the selected action. Must be in the range [0, 1].
+
+         Raises:
+             ValueError: If the reward is not in the range [0, 1].
+             ValueError: If the probability of the selected action is less than or equal to 0.
+
+         Example:
+             ```python
+             from tensordict import TensorDict
+             from torchrl.modules.mcts.scores import EXP3Score
+
+             # Create an EXP3Score instance
+             exp3 = EXP3Score(gamma=0.1)
+
+             # Define a TensorDict with required keys
+             node = TensorDict(
+                 {
+                     "weights": torch.tensor([1.0, 1.0]),
+                     "num_actions": torch.tensor(2),
+                 },
+                 batch_size=[],
+             )
+
+             # Compute the action probabilities
+             result = exp3(node)
+             print(result["score"]) # Output: Tensor with action probabilities
+
+             # Update the weights based on the reward for action 0
+             exp3.update_weights(node, action_idx=0, reward=0.8)
+             print(node["weights"]) # Updated weights
+             ```
+         """
+         if not (0 <= reward <= 1):
+             warnings.warn(
+                 f"Reward {reward} is outside the expected [0,1] range for EXP3.",
+                 UserWarning,
+             )
+
+         weights = node.get(self.weights_key)
+         action_probs = node.get(self.score_key)
+         k = weights.shape[-1]
+
+         if weights.ndim == 1:
+             current_weight = weights[action_idx]
+             prob_i = action_probs[action_idx]
+         elif weights.ndim > 1:
+             current_weight = weights[..., action_idx]
+             prob_i = action_probs[..., action_idx]
+         else:
+             raise ValueError(f"Invalid weights dimensions: {weights.ndim}")
+
+         if torch.any(prob_i <= 0):
+             prob_i_val = prob_i.item() if prob_i.numel() == 1 else prob_i
+             warnings.warn(
+                 f"Probability p_i(t) for action {action_idx} is {prob_i_val}. "
+                 "Weight will not be updated for zero probability actions.",
+                 UserWarning,
+             )
+             # Don't update weights for zero probability - just return
+             return
+
+         reward_tensor = torch.as_tensor(
+             reward, device=current_weight.device, dtype=current_weight.dtype
+         )
+         exponent = (self.gamma / k) * (reward_tensor / prob_i)
+         new_weight = current_weight * torch.exp(exponent)
+
+         if weights.ndim == 1:
+             weights[action_idx] = new_weight
+         else:
+             weights[..., action_idx] = new_weight
+         node.set(self.weights_key, weights)
+
+
+ class UCB1TunedScore(MCTSScore):
+     """Computes the UCB1-Tuned score for MCTS, using variance estimation.
+
+     UCB1-Tuned is an enhancement of the UCB1 algorithm that incorporates an estimate
+     of the variance of rewards for each action. This allows for a more refined
+     balance between exploration and exploitation, potentially leading to better
+     performance, especially when reward variances differ significantly across actions.
+
+     The score for an action `i` is calculated as:
+     `score_i = avg_reward_i + sqrt(log(N) / N_i * min(0.25, V_i))`
+
+     The variance estimate `V_i` for action `i` is calculated as:
+     `V_i = (sum_squared_rewards_i / N_i) - avg_reward_i^2 + sqrt(exploration_constant * log(N) / N_i)`
+
+     Where:
+     - `avg_reward_i`: Average reward obtained from action `i`.
+     - `N_i`: Number of times action `i` has been visited.
+     - `N`: Total number of times the parent node has been visited.
+     - `sum_squared_rewards_i`: Sum of the squares of rewards received from action `i`.
+     - `exploration_constant`: A constant used in the bias correction term of `V_i`.
+       Auer et al. (2002) suggest a value of 2.0 for rewards in the range [0,1].
+     - The term `min(0.25, V_i)` implies that rewards are scaled to `[0, 1]`, as 0.25 is
+       the maximum variance for a distribution in this range (e.g., Bernoulli(0.5)).
+
+     Reference: "Finite-time Analysis of the Multiarmed Bandit Problem"
+     (Auer, Cesa-Bianchi, Fischer, 2002).
+
+     Args:
+         exploration_constant (float, optional): The constant `C` used in the bias
+             correction term for the variance estimate `V_i`. Defaults to `2.0`,
+             as suggested for rewards in `[0,1]`.
+         win_count_key (NestedKey, optional): Key for the tensor in the input `TensorDictBase`
+             containing the sum of rewards for each action (Q_i * N_i). Defaults to "win_count".
+         visits_key (NestedKey, optional): Key for the tensor containing the visit
+             count for each action (N_i). Defaults to "visits".
+         total_visits_key (NestedKey, optional): Key for the tensor (or scalar)
+             representing the visit count of the parent node (N). Defaults to "total_visits".
+         sum_squared_rewards_key (NestedKey, optional): Key for the tensor containing
+             the sum of squared rewards received for each action. This is crucial for
+             calculating the empirical variance. Defaults to "sum_squared_rewards".
+         score_key (NestedKey, optional): Key where the calculated UCB1-Tuned scores
+             will be stored in the output `TensorDictBase`. Defaults to "score".
+
+     Input Keys:
+         - `win_count_key` (torch.Tensor): Sum of rewards for each action.
+         - `visits_key` (torch.Tensor): Visit counts for each action (N_i).
+         - `total_visits_key` (torch.Tensor): Parent node's visit count (N).
+         - `sum_squared_rewards_key` (torch.Tensor): Sum of squared rewards for each action.
+
+     Output Keys:
+         - `score_key` (torch.Tensor): Calculated UCB1-Tuned scores for each action.
+
+     Important Notes:
+         - **Unvisited Nodes**: Actions with zero visits (`visits_key` is 0) are assigned a
+           very large positive score to ensure they are selected for exploration.
+         - **Reward Range**: The `min(0.25, V_i)` term is theoretically most sound when
+           rewards are normalized to the range `[0, 1]`.
+         - **Logarithm of N**: `log(N)` (log of parent visits) is calculated using `torch.log(torch.clamp(N, min=1.0))`
+           to prevent issues with `N=0` or `N` between 0 and 1.
+     """
+
+     def __init__(
+         self,
+         *,
+         win_count_key: NestedKey = "win_count",
+         visits_key: NestedKey = "visits",
+         total_visits_key: NestedKey = "total_visits",
+         sum_squared_rewards_key: NestedKey = "sum_squared_rewards",
+         score_key: NestedKey = "score",
+         exploration_constant: float = 2.0,
+     ):
+         super().__init__()
+         self.win_count_key = win_count_key
+         self.visits_key = visits_key
+         self.total_visits_key = total_visits_key
+         self.sum_squared_rewards_key = sum_squared_rewards_key
+         self.score_key = score_key
+         self.exploration_constant = exploration_constant
+
+         self.in_keys = [
+             self.win_count_key,
+             self.visits_key,
+             self.total_visits_key,
+             self.sum_squared_rewards_key,
+         ]
+         self.out_keys = [self.score_key]
+
+     def forward(self, node: TensorDictBase) -> TensorDictBase:
+         q_sum_i = node.get(self.win_count_key)
+         n_i = node.get(self.visits_key)
+         n_parent = node.get(self.total_visits_key)
+         sum_sq_rewards_i = node.get(self.sum_squared_rewards_key)
+
+         if n_parent.ndim > 0 and n_parent.ndim < q_sum_i.ndim:
+             n_parent_expanded = n_parent.unsqueeze(-1)
+         else:
+             n_parent_expanded = n_parent
+
+         safe_n_parent_for_log = torch.clamp(n_parent_expanded, min=1.0)
+         log_n_parent = torch.log(safe_n_parent_for_log)
+
+         scores = torch.zeros_like(q_sum_i, device=q_sum_i.device)
+
+         visited_mask = n_i > 0
+
+         if torch.any(visited_mask):
+             q_sum_i_v = q_sum_i[visited_mask]
+             n_i_v = n_i[visited_mask]
+             sum_sq_rewards_i_v = sum_sq_rewards_i[visited_mask]
+
+             log_n_parent_v = log_n_parent.expand_as(n_i)[visited_mask]
+
+             avg_reward_i_v = q_sum_i_v / n_i_v
+
+             empirical_variance_v = (sum_sq_rewards_i_v / n_i_v) - avg_reward_i_v.pow(2)
+             bias_correction_v = (
+                 self.exploration_constant * log_n_parent_v / n_i_v
+             ).sqrt()
+
+             v_i_v = empirical_variance_v + bias_correction_v
+             v_i_v = v_i_v.clamp(min=0)
+
+             min_variance_term_v = torch.min(torch.full_like(v_i_v, 0.25), v_i_v)
+             exploration_component_v = (
+                 log_n_parent_v / n_i_v * min_variance_term_v
+             ).sqrt()
+
+             scores[visited_mask] = avg_reward_i_v + exploration_component_v
+
+         unvisited_mask = ~visited_mask
+         if torch.any(unvisited_mask):
+             scores[unvisited_mask] = torch.finfo(scores.dtype).max / 10.0
+
+         node.set(self.score_key, scores)
+         return node
+
+
+ class MCTSScores(Enum):
+     """Enum providing factory functions for common MCTS score configurations."""
+
+     PUCT = functools.partial(PUCTScore, c=5) # AlphaGo default value
+     UCB = functools.partial(UCBScore, c=math.sqrt(2)) # default from Auer et al. 2002
+     UCB1_TUNED = functools.partial(
+         UCB1TunedScore, exploration_constant=2.0
+     ) # Auer et al. (2002) C=2 for rewards in [0,1]
+     EXP3 = functools.partial(EXP3Score, gamma=0.1)
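Unlike the other score modules in this file, the UCB1TunedScore docstring carries no usage example. A sketch in the same style, using the default keys documented above; the per-action statistics are illustrative numbers only, chosen to be consistent with rewards in [0, 1]:

```python
# Illustrative sketch only: key names follow the defaults documented above,
# and the statistics below are made-up values with rewards assumed in [0, 1].
import torch
from tensordict import TensorDict
from torchrl.modules.mcts.scores import MCTSScores, UCB1TunedScore

# Construct directly, or through the enum factory (same default constant).
ucb1_tuned = UCB1TunedScore(exploration_constant=2.0)
ucb1_tuned_from_enum = MCTSScores.UCB1_TUNED.value()
assert isinstance(ucb1_tuned_from_enum, UCB1TunedScore)

# Per-action statistics for a node with three children.
node = TensorDict(
    {
        "win_count": torch.tensor([3.0, 5.0, 0.0]),            # sum of rewards per action
        "visits": torch.tensor([4.0, 10.0, 0.0]),              # N_i; third action unvisited
        "total_visits": torch.tensor(14.0),                    # parent visit count N
        "sum_squared_rewards": torch.tensor([2.5, 3.0, 0.0]),  # sum of squared rewards
    },
    batch_size=[],
)

result = ucb1_tuned(node)
# The unvisited third action receives a very large score so it is explored first.
print(result["score"])
```

The enum path works because the members are `functools.partial` factories: `MCTSScores.UCB1_TUNED.value()` calls `UCB1TunedScore` with `exploration_constant=2.0` already bound, so the two constructions above are equivalent.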