torchrl 0.11.0__cp314-cp314-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,1677 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ # Original LICENSE:
7
+ # Copyright 2024 The Google Research Authors.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ from __future__ import annotations
22
+
23
+ import functools
24
+ import random
25
+ import re
26
+
27
+ import immutabledict
28
+ import nltk
29
+
30
+ WORD_LIST = [
31
+ "western",
32
+ "sentence",
33
+ "signal",
34
+ "dump",
35
+ "spot",
36
+ "opposite",
37
+ "bottom",
38
+ "potato",
39
+ "administration",
40
+ "working",
41
+ "welcome",
42
+ "morning",
43
+ "good",
44
+ "agency",
45
+ "primary",
46
+ "wish",
47
+ "responsibility",
48
+ "press",
49
+ "problem",
50
+ "president",
51
+ "steal",
52
+ "brush",
53
+ "read",
54
+ "type",
55
+ "beat",
56
+ "trainer",
57
+ "growth",
58
+ "lock",
59
+ "bone",
60
+ "case",
61
+ "equal",
62
+ "comfortable",
63
+ "region",
64
+ "replacement",
65
+ "performance",
66
+ "mate",
67
+ "walk",
68
+ "medicine",
69
+ "film",
70
+ "thing",
71
+ "rock",
72
+ "tap",
73
+ "total",
74
+ "competition",
75
+ "ease",
76
+ "south",
77
+ "establishment",
78
+ "gather",
79
+ "parking",
80
+ "world",
81
+ "plenty",
82
+ "breath",
83
+ "claim",
84
+ "alcohol",
85
+ "trade",
86
+ "dear",
87
+ "highlight",
88
+ "street",
89
+ "matter",
90
+ "decision",
91
+ "mess",
92
+ "agreement",
93
+ "studio",
94
+ "coach",
95
+ "assist",
96
+ "brain",
97
+ "wing",
98
+ "style",
99
+ "private",
100
+ "top",
101
+ "brown",
102
+ "leg",
103
+ "buy",
104
+ "procedure",
105
+ "method",
106
+ "speed",
107
+ "high",
108
+ "company",
109
+ "valuable",
110
+ "pie",
111
+ "analyst",
112
+ "session",
113
+ "pattern",
114
+ "district",
115
+ "pleasure",
116
+ "dinner",
117
+ "swimming",
118
+ "joke",
119
+ "order",
120
+ "plate",
121
+ "department",
122
+ "motor",
123
+ "cell",
124
+ "spend",
125
+ "cabinet",
126
+ "difference",
127
+ "power",
128
+ "examination",
129
+ "engine",
130
+ "horse",
131
+ "dimension",
132
+ "pay",
133
+ "toe",
134
+ "curve",
135
+ "literature",
136
+ "bother",
137
+ "fire",
138
+ "possibility",
139
+ "debate",
140
+ "activity",
141
+ "passage",
142
+ "hello",
143
+ "cycle",
144
+ "background",
145
+ "quiet",
146
+ "author",
147
+ "effect",
148
+ "actor",
149
+ "page",
150
+ "bicycle",
151
+ "error",
152
+ "throat",
153
+ "attack",
154
+ "character",
155
+ "phone",
156
+ "tea",
157
+ "increase",
158
+ "outcome",
159
+ "file",
160
+ "specific",
161
+ "inspector",
162
+ "internal",
163
+ "potential",
164
+ "staff",
165
+ "building",
166
+ "employer",
167
+ "shoe",
168
+ "hand",
169
+ "direction",
170
+ "garden",
171
+ "purchase",
172
+ "interview",
173
+ "study",
174
+ "recognition",
175
+ "member",
176
+ "spiritual",
177
+ "oven",
178
+ "sandwich",
179
+ "weird",
180
+ "passenger",
181
+ "particular",
182
+ "response",
183
+ "reaction",
184
+ "size",
185
+ "variation",
186
+ "a",
187
+ "cancel",
188
+ "candy",
189
+ "exit",
190
+ "guest",
191
+ "condition",
192
+ "fly",
193
+ "price",
194
+ "weakness",
195
+ "convert",
196
+ "hotel",
197
+ "great",
198
+ "mouth",
199
+ "mind",
200
+ "song",
201
+ "sugar",
202
+ "suspect",
203
+ "telephone",
204
+ "ear",
205
+ "roof",
206
+ "paint",
207
+ "refrigerator",
208
+ "organization",
209
+ "jury",
210
+ "reward",
211
+ "engineering",
212
+ "day",
213
+ "possession",
214
+ "crew",
215
+ "bar",
216
+ "road",
217
+ "description",
218
+ "celebration",
219
+ "score",
220
+ "mark",
221
+ "letter",
222
+ "shower",
223
+ "suggestion",
224
+ "sir",
225
+ "luck",
226
+ "national",
227
+ "progress",
228
+ "hall",
229
+ "stroke",
230
+ "theory",
231
+ "offer",
232
+ "story",
233
+ "tax",
234
+ "definition",
235
+ "history",
236
+ "ride",
237
+ "medium",
238
+ "opening",
239
+ "glass",
240
+ "elevator",
241
+ "stomach",
242
+ "question",
243
+ "ability",
244
+ "leading",
245
+ "village",
246
+ "computer",
247
+ "city",
248
+ "grand",
249
+ "confidence",
250
+ "candle",
251
+ "priest",
252
+ "recommendation",
253
+ "point",
254
+ "necessary",
255
+ "body",
256
+ "desk",
257
+ "secret",
258
+ "horror",
259
+ "noise",
260
+ "culture",
261
+ "warning",
262
+ "water",
263
+ "round",
264
+ "diet",
265
+ "flower",
266
+ "bus",
267
+ "tough",
268
+ "permission",
269
+ "week",
270
+ "prompt",
271
+ "connection",
272
+ "abuse",
273
+ "height",
274
+ "save",
275
+ "corner",
276
+ "border",
277
+ "stress",
278
+ "drive",
279
+ "stop",
280
+ "rip",
281
+ "meal",
282
+ "listen",
283
+ "confusion",
284
+ "girlfriend",
285
+ "living",
286
+ "relation",
287
+ "significance",
288
+ "plan",
289
+ "creative",
290
+ "atmosphere",
291
+ "blame",
292
+ "invite",
293
+ "housing",
294
+ "paper",
295
+ "drink",
296
+ "roll",
297
+ "silver",
298
+ "drunk",
299
+ "age",
300
+ "damage",
301
+ "smoke",
302
+ "environment",
303
+ "pack",
304
+ "savings",
305
+ "influence",
306
+ "tourist",
307
+ "rain",
308
+ "post",
309
+ "sign",
310
+ "grandmother",
311
+ "run",
312
+ "profit",
313
+ "push",
314
+ "clerk",
315
+ "final",
316
+ "wine",
317
+ "swim",
318
+ "pause",
319
+ "stuff",
320
+ "singer",
321
+ "funeral",
322
+ "average",
323
+ "source",
324
+ "scene",
325
+ "tradition",
326
+ "personal",
327
+ "snow",
328
+ "nobody",
329
+ "distance",
330
+ "sort",
331
+ "sensitive",
332
+ "animal",
333
+ "major",
334
+ "negotiation",
335
+ "click",
336
+ "mood",
337
+ "period",
338
+ "arrival",
339
+ "expression",
340
+ "holiday",
341
+ "repeat",
342
+ "dust",
343
+ "closet",
344
+ "gold",
345
+ "bad",
346
+ "sail",
347
+ "combination",
348
+ "clothes",
349
+ "emphasis",
350
+ "duty",
351
+ "black",
352
+ "step",
353
+ "school",
354
+ "jump",
355
+ "document",
356
+ "professional",
357
+ "lip",
358
+ "chemical",
359
+ "front",
360
+ "wake",
361
+ "while",
362
+ "inside",
363
+ "watch",
364
+ "row",
365
+ "subject",
366
+ "penalty",
367
+ "balance",
368
+ "possible",
369
+ "adult",
370
+ "aside",
371
+ "sample",
372
+ "appeal",
373
+ "wedding",
374
+ "depth",
375
+ "king",
376
+ "award",
377
+ "wife",
378
+ "blow",
379
+ "site",
380
+ "camp",
381
+ "music",
382
+ "safe",
383
+ "gift",
384
+ "fault",
385
+ "guess",
386
+ "act",
387
+ "shame",
388
+ "drama",
389
+ "capital",
390
+ "exam",
391
+ "stupid",
392
+ "record",
393
+ "sound",
394
+ "swing",
395
+ "novel",
396
+ "minimum",
397
+ "ratio",
398
+ "machine",
399
+ "shape",
400
+ "lead",
401
+ "operation",
402
+ "salary",
403
+ "cloud",
404
+ "affair",
405
+ "hit",
406
+ "chapter",
407
+ "stage",
408
+ "quantity",
409
+ "access",
410
+ "army",
411
+ "chain",
412
+ "traffic",
413
+ "kick",
414
+ "analysis",
415
+ "airport",
416
+ "time",
417
+ "vacation",
418
+ "philosophy",
419
+ "ball",
420
+ "chest",
421
+ "thanks",
422
+ "place",
423
+ "mountain",
424
+ "advertising",
425
+ "red",
426
+ "past",
427
+ "rent",
428
+ "return",
429
+ "tour",
430
+ "house",
431
+ "construction",
432
+ "net",
433
+ "native",
434
+ "war",
435
+ "figure",
436
+ "fee",
437
+ "spray",
438
+ "user",
439
+ "dirt",
440
+ "shot",
441
+ "task",
442
+ "stick",
443
+ "friend",
444
+ "software",
445
+ "promotion",
446
+ "interaction",
447
+ "surround",
448
+ "block",
449
+ "purpose",
450
+ "practice",
451
+ "conflict",
452
+ "routine",
453
+ "requirement",
454
+ "bonus",
455
+ "hole",
456
+ "state",
457
+ "junior",
458
+ "sweet",
459
+ "catch",
460
+ "tear",
461
+ "fold",
462
+ "wall",
463
+ "editor",
464
+ "life",
465
+ "position",
466
+ "pound",
467
+ "respect",
468
+ "bathroom",
469
+ "coat",
470
+ "script",
471
+ "job",
472
+ "teach",
473
+ "birth",
474
+ "view",
475
+ "resolve",
476
+ "theme",
477
+ "employee",
478
+ "doubt",
479
+ "market",
480
+ "education",
481
+ "serve",
482
+ "recover",
483
+ "tone",
484
+ "harm",
485
+ "miss",
486
+ "union",
487
+ "understanding",
488
+ "cow",
489
+ "river",
490
+ "association",
491
+ "concept",
492
+ "training",
493
+ "recipe",
494
+ "relationship",
495
+ "reserve",
496
+ "depression",
497
+ "proof",
498
+ "hair",
499
+ "revenue",
500
+ "independent",
501
+ "lift",
502
+ "assignment",
503
+ "temporary",
504
+ "amount",
505
+ "loss",
506
+ "edge",
507
+ "track",
508
+ "check",
509
+ "rope",
510
+ "estimate",
511
+ "pollution",
512
+ "stable",
513
+ "message",
514
+ "delivery",
515
+ "perspective",
516
+ "mirror",
517
+ "assistant",
518
+ "representative",
519
+ "witness",
520
+ "nature",
521
+ "judge",
522
+ "fruit",
523
+ "tip",
524
+ "devil",
525
+ "town",
526
+ "emergency",
527
+ "upper",
528
+ "drop",
529
+ "stay",
530
+ "human",
531
+ "neck",
532
+ "speaker",
533
+ "network",
534
+ "sing",
535
+ "resist",
536
+ "league",
537
+ "trip",
538
+ "signature",
539
+ "lawyer",
540
+ "importance",
541
+ "gas",
542
+ "choice",
543
+ "engineer",
544
+ "success",
545
+ "part",
546
+ "external",
547
+ "worker",
548
+ "simple",
549
+ "quarter",
550
+ "student",
551
+ "heart",
552
+ "pass",
553
+ "spite",
554
+ "shift",
555
+ "rough",
556
+ "lady",
557
+ "grass",
558
+ "community",
559
+ "garage",
560
+ "youth",
561
+ "standard",
562
+ "skirt",
563
+ "promise",
564
+ "blind",
565
+ "television",
566
+ "disease",
567
+ "commission",
568
+ "positive",
569
+ "energy",
570
+ "calm",
571
+ "presence",
572
+ "tune",
573
+ "basis",
574
+ "preference",
575
+ "head",
576
+ "common",
577
+ "cut",
578
+ "somewhere",
579
+ "presentation",
580
+ "current",
581
+ "thought",
582
+ "revolution",
583
+ "effort",
584
+ "master",
585
+ "implement",
586
+ "republic",
587
+ "floor",
588
+ "principle",
589
+ "stranger",
590
+ "shoulder",
591
+ "grade",
592
+ "button",
593
+ "tennis",
594
+ "police",
595
+ "collection",
596
+ "account",
597
+ "register",
598
+ "glove",
599
+ "divide",
600
+ "professor",
601
+ "chair",
602
+ "priority",
603
+ "combine",
604
+ "peace",
605
+ "extension",
606
+ "maybe",
607
+ "evening",
608
+ "frame",
609
+ "sister",
610
+ "wave",
611
+ "code",
612
+ "application",
613
+ "mouse",
614
+ "match",
615
+ "counter",
616
+ "bottle",
617
+ "half",
618
+ "cheek",
619
+ "resolution",
620
+ "back",
621
+ "knowledge",
622
+ "make",
623
+ "discussion",
624
+ "screw",
625
+ "length",
626
+ "accident",
627
+ "battle",
628
+ "dress",
629
+ "knee",
630
+ "log",
631
+ "package",
632
+ "it",
633
+ "turn",
634
+ "hearing",
635
+ "newspaper",
636
+ "layer",
637
+ "wealth",
638
+ "profile",
639
+ "imagination",
640
+ "answer",
641
+ "weekend",
642
+ "teacher",
643
+ "appearance",
644
+ "meet",
645
+ "bike",
646
+ "rise",
647
+ "belt",
648
+ "crash",
649
+ "bowl",
650
+ "equivalent",
651
+ "support",
652
+ "image",
653
+ "poem",
654
+ "risk",
655
+ "excitement",
656
+ "remote",
657
+ "secretary",
658
+ "public",
659
+ "produce",
660
+ "plane",
661
+ "display",
662
+ "money",
663
+ "sand",
664
+ "situation",
665
+ "punch",
666
+ "customer",
667
+ "title",
668
+ "shake",
669
+ "mortgage",
670
+ "option",
671
+ "number",
672
+ "pop",
673
+ "window",
674
+ "extent",
675
+ "nothing",
676
+ "experience",
677
+ "opinion",
678
+ "departure",
679
+ "dance",
680
+ "indication",
681
+ "boy",
682
+ "material",
683
+ "band",
684
+ "leader",
685
+ "sun",
686
+ "beautiful",
687
+ "muscle",
688
+ "farmer",
689
+ "variety",
690
+ "fat",
691
+ "handle",
692
+ "director",
693
+ "opportunity",
694
+ "calendar",
695
+ "outside",
696
+ "pace",
697
+ "bath",
698
+ "fish",
699
+ "consequence",
700
+ "put",
701
+ "owner",
702
+ "go",
703
+ "doctor",
704
+ "information",
705
+ "share",
706
+ "hurt",
707
+ "protection",
708
+ "career",
709
+ "finance",
710
+ "force",
711
+ "golf",
712
+ "garbage",
713
+ "aspect",
714
+ "kid",
715
+ "food",
716
+ "boot",
717
+ "milk",
718
+ "respond",
719
+ "objective",
720
+ "reality",
721
+ "raw",
722
+ "ring",
723
+ "mall",
724
+ "one",
725
+ "impact",
726
+ "area",
727
+ "news",
728
+ "international",
729
+ "series",
730
+ "impress",
731
+ "mother",
732
+ "shelter",
733
+ "strike",
734
+ "loan",
735
+ "month",
736
+ "seat",
737
+ "anything",
738
+ "entertainment",
739
+ "familiar",
740
+ "clue",
741
+ "year",
742
+ "glad",
743
+ "supermarket",
744
+ "natural",
745
+ "god",
746
+ "cost",
747
+ "conversation",
748
+ "tie",
749
+ "ruin",
750
+ "comfort",
751
+ "earth",
752
+ "storm",
753
+ "percentage",
754
+ "assistance",
755
+ "budget",
756
+ "strength",
757
+ "beginning",
758
+ "sleep",
759
+ "other",
760
+ "young",
761
+ "unit",
762
+ "fill",
763
+ "store",
764
+ "desire",
765
+ "hide",
766
+ "value",
767
+ "cup",
768
+ "maintenance",
769
+ "nurse",
770
+ "function",
771
+ "tower",
772
+ "role",
773
+ "class",
774
+ "camera",
775
+ "database",
776
+ "panic",
777
+ "nation",
778
+ "basket",
779
+ "ice",
780
+ "art",
781
+ "spirit",
782
+ "chart",
783
+ "exchange",
784
+ "feedback",
785
+ "statement",
786
+ "reputation",
787
+ "search",
788
+ "hunt",
789
+ "exercise",
790
+ "nasty",
791
+ "notice",
792
+ "male",
793
+ "yard",
794
+ "annual",
795
+ "collar",
796
+ "date",
797
+ "platform",
798
+ "plant",
799
+ "fortune",
800
+ "passion",
801
+ "friendship",
802
+ "spread",
803
+ "cancer",
804
+ "ticket",
805
+ "attitude",
806
+ "island",
807
+ "active",
808
+ "object",
809
+ "service",
810
+ "buyer",
811
+ "bite",
812
+ "card",
813
+ "face",
814
+ "steak",
815
+ "proposal",
816
+ "patient",
817
+ "heat",
818
+ "rule",
819
+ "resident",
820
+ "broad",
821
+ "politics",
822
+ "west",
823
+ "knife",
824
+ "expert",
825
+ "girl",
826
+ "design",
827
+ "salt",
828
+ "baseball",
829
+ "grab",
830
+ "inspection",
831
+ "cousin",
832
+ "couple",
833
+ "magazine",
834
+ "cook",
835
+ "dependent",
836
+ "security",
837
+ "chicken",
838
+ "version",
839
+ "currency",
840
+ "ladder",
841
+ "scheme",
842
+ "kitchen",
843
+ "employment",
844
+ "local",
845
+ "attention",
846
+ "manager",
847
+ "fact",
848
+ "cover",
849
+ "sad",
850
+ "guard",
851
+ "relative",
852
+ "county",
853
+ "rate",
854
+ "lunch",
855
+ "program",
856
+ "initiative",
857
+ "gear",
858
+ "bridge",
859
+ "breast",
860
+ "talk",
861
+ "dish",
862
+ "guarantee",
863
+ "beer",
864
+ "vehicle",
865
+ "reception",
866
+ "woman",
867
+ "substance",
868
+ "copy",
869
+ "lecture",
870
+ "advantage",
871
+ "park",
872
+ "cold",
873
+ "death",
874
+ "mix",
875
+ "hold",
876
+ "scale",
877
+ "tomorrow",
878
+ "blood",
879
+ "request",
880
+ "green",
881
+ "cookie",
882
+ "church",
883
+ "strip",
884
+ "forever",
885
+ "beyond",
886
+ "debt",
887
+ "tackle",
888
+ "wash",
889
+ "following",
890
+ "feel",
891
+ "maximum",
892
+ "sector",
893
+ "sea",
894
+ "property",
895
+ "economics",
896
+ "menu",
897
+ "bench",
898
+ "try",
899
+ "language",
900
+ "start",
901
+ "call",
902
+ "solid",
903
+ "address",
904
+ "income",
905
+ "foot",
906
+ "senior",
907
+ "honey",
908
+ "few",
909
+ "mixture",
910
+ "cash",
911
+ "grocery",
912
+ "link",
913
+ "map",
914
+ "form",
915
+ "factor",
916
+ "pot",
917
+ "model",
918
+ "writer",
919
+ "farm",
920
+ "winter",
921
+ "skill",
922
+ "anywhere",
923
+ "birthday",
924
+ "policy",
925
+ "release",
926
+ "husband",
927
+ "lab",
928
+ "hurry",
929
+ "mail",
930
+ "equipment",
931
+ "sink",
932
+ "pair",
933
+ "driver",
934
+ "consideration",
935
+ "leather",
936
+ "skin",
937
+ "blue",
938
+ "boat",
939
+ "sale",
940
+ "brick",
941
+ "two",
942
+ "feed",
943
+ "square",
944
+ "dot",
945
+ "rush",
946
+ "dream",
947
+ "location",
948
+ "afternoon",
949
+ "manufacturer",
950
+ "control",
951
+ "occasion",
952
+ "trouble",
953
+ "introduction",
954
+ "advice",
955
+ "bet",
956
+ "eat",
957
+ "kill",
958
+ "category",
959
+ "manner",
960
+ "office",
961
+ "estate",
962
+ "pride",
963
+ "awareness",
964
+ "slip",
965
+ "crack",
966
+ "client",
967
+ "nail",
968
+ "shoot",
969
+ "membership",
970
+ "soft",
971
+ "anybody",
972
+ "web",
973
+ "official",
974
+ "individual",
975
+ "pizza",
976
+ "interest",
977
+ "bag",
978
+ "spell",
979
+ "profession",
980
+ "queen",
981
+ "deal",
982
+ "resource",
983
+ "ship",
984
+ "guy",
985
+ "chocolate",
986
+ "joint",
987
+ "formal",
988
+ "upstairs",
989
+ "car",
990
+ "resort",
991
+ "abroad",
992
+ "dealer",
993
+ "associate",
994
+ "finger",
995
+ "surgery",
996
+ "comment",
997
+ "team",
998
+ "detail",
999
+ "crazy",
1000
+ "path",
1001
+ "tale",
1002
+ "initial",
1003
+ "arm",
1004
+ "radio",
1005
+ "demand",
1006
+ "single",
1007
+ "draw",
1008
+ "yellow",
1009
+ "contest",
1010
+ "piece",
1011
+ "quote",
1012
+ "pull",
1013
+ "commercial",
1014
+ "shirt",
1015
+ "contribution",
1016
+ "cream",
1017
+ "channel",
1018
+ "suit",
1019
+ "discipline",
1020
+ "instruction",
1021
+ "concert",
1022
+ "speech",
1023
+ "low",
1024
+ "effective",
1025
+ "hang",
1026
+ "scratch",
1027
+ "industry",
1028
+ "breakfast",
1029
+ "lay",
1030
+ "join",
1031
+ "metal",
1032
+ "bedroom",
1033
+ "minute",
1034
+ "product",
1035
+ "rest",
1036
+ "temperature",
1037
+ "many",
1038
+ "give",
1039
+ "argument",
1040
+ "print",
1041
+ "purple",
1042
+ "laugh",
1043
+ "health",
1044
+ "credit",
1045
+ "investment",
1046
+ "sell",
1047
+ "setting",
1048
+ "lesson",
1049
+ "egg",
1050
+ "middle",
1051
+ "marriage",
1052
+ "level",
1053
+ "evidence",
1054
+ "phrase",
1055
+ "love",
1056
+ "self",
1057
+ "benefit",
1058
+ "guidance",
1059
+ "affect",
1060
+ "you",
1061
+ "dad",
1062
+ "anxiety",
1063
+ "special",
1064
+ "boyfriend",
1065
+ "test",
1066
+ "blank",
1067
+ "payment",
1068
+ "soup",
1069
+ "obligation",
1070
+ "reply",
1071
+ "smile",
1072
+ "deep",
1073
+ "complaint",
1074
+ "addition",
1075
+ "review",
1076
+ "box",
1077
+ "towel",
1078
+ "minor",
1079
+ "fun",
1080
+ "soil",
1081
+ "issue",
1082
+ "cigarette",
1083
+ "internet",
1084
+ "gain",
1085
+ "tell",
1086
+ "entry",
1087
+ "spare",
1088
+ "incident",
1089
+ "family",
1090
+ "refuse",
1091
+ "branch",
1092
+ "can",
1093
+ "pen",
1094
+ "grandfather",
1095
+ "constant",
1096
+ "tank",
1097
+ "uncle",
1098
+ "climate",
1099
+ "ground",
1100
+ "volume",
1101
+ "communication",
1102
+ "kind",
1103
+ "poet",
1104
+ "child",
1105
+ "screen",
1106
+ "mine",
1107
+ "quit",
1108
+ "gene",
1109
+ "lack",
1110
+ "charity",
1111
+ "memory",
1112
+ "tooth",
1113
+ "fear",
1114
+ "mention",
1115
+ "marketing",
1116
+ "reveal",
1117
+ "reason",
1118
+ "court",
1119
+ "season",
1120
+ "freedom",
1121
+ "land",
1122
+ "sport",
1123
+ "audience",
1124
+ "classroom",
1125
+ "law",
1126
+ "hook",
1127
+ "win",
1128
+ "carry",
1129
+ "eye",
1130
+ "smell",
1131
+ "distribution",
1132
+ "research",
1133
+ "country",
1134
+ "dare",
1135
+ "hope",
1136
+ "whereas",
1137
+ "stretch",
1138
+ "library",
1139
+ "if",
1140
+ "delay",
1141
+ "college",
1142
+ "plastic",
1143
+ "book",
1144
+ "present",
1145
+ "use",
1146
+ "worry",
1147
+ "champion",
1148
+ "goal",
1149
+ "economy",
1150
+ "march",
1151
+ "election",
1152
+ "reflection",
1153
+ "midnight",
1154
+ "slide",
1155
+ "inflation",
1156
+ "action",
1157
+ "challenge",
1158
+ "guitar",
1159
+ "coast",
1160
+ "apple",
1161
+ "campaign",
1162
+ "field",
1163
+ "jacket",
1164
+ "sense",
1165
+ "way",
1166
+ "visual",
1167
+ "remove",
1168
+ "weather",
1169
+ "trash",
1170
+ "cable",
1171
+ "regret",
1172
+ "buddy",
1173
+ "beach",
1174
+ "historian",
1175
+ "courage",
1176
+ "sympathy",
1177
+ "truck",
1178
+ "tension",
1179
+ "permit",
1180
+ "nose",
1181
+ "bed",
1182
+ "son",
1183
+ "person",
1184
+ "base",
1185
+ "meat",
1186
+ "usual",
1187
+ "air",
1188
+ "meeting",
1189
+ "worth",
1190
+ "game",
1191
+ "independence",
1192
+ "physical",
1193
+ "brief",
1194
+ "play",
1195
+ "raise",
1196
+ "board",
1197
+ "she",
1198
+ "key",
1199
+ "writing",
1200
+ "pick",
1201
+ "command",
1202
+ "party",
1203
+ "yesterday",
1204
+ "spring",
1205
+ "candidate",
1206
+ "physics",
1207
+ "university",
1208
+ "concern",
1209
+ "development",
1210
+ "change",
1211
+ "string",
1212
+ "target",
1213
+ "instance",
1214
+ "room",
1215
+ "bitter",
1216
+ "bird",
1217
+ "football",
1218
+ "normal",
1219
+ "split",
1220
+ "impression",
1221
+ "wood",
1222
+ "long",
1223
+ "meaning",
1224
+ "stock",
1225
+ "cap",
1226
+ "leadership",
1227
+ "media",
1228
+ "ambition",
1229
+ "fishing",
1230
+ "essay",
1231
+ "salad",
1232
+ "repair",
1233
+ "today",
1234
+ "designer",
1235
+ "night",
1236
+ "bank",
1237
+ "drawing",
1238
+ "inevitable",
1239
+ "phase",
1240
+ "vast",
1241
+ "chip",
1242
+ "anger",
1243
+ "switch",
1244
+ "cry",
1245
+ "twist",
1246
+ "personality",
1247
+ "attempt",
1248
+ "storage",
1249
+ "being",
1250
+ "preparation",
1251
+ "bat",
1252
+ "selection",
1253
+ "white",
1254
+ "technology",
1255
+ "contract",
1256
+ "side",
1257
+ "section",
1258
+ "station",
1259
+ "till",
1260
+ "structure",
1261
+ "tongue",
1262
+ "taste",
1263
+ "truth",
1264
+ "difficulty",
1265
+ "group",
1266
+ "limit",
1267
+ "main",
1268
+ "move",
1269
+ "feeling",
1270
+ "light",
1271
+ "example",
1272
+ "mission",
1273
+ "might",
1274
+ "wait",
1275
+ "wheel",
1276
+ "shop",
1277
+ "host",
1278
+ "classic",
1279
+ "alternative",
1280
+ "cause",
1281
+ "agent",
1282
+ "consist",
1283
+ "table",
1284
+ "airline",
1285
+ "text",
1286
+ "pool",
1287
+ "craft",
1288
+ "range",
1289
+ "fuel",
1290
+ "tool",
1291
+ "partner",
1292
+ "load",
1293
+ "entrance",
1294
+ "deposit",
1295
+ "hate",
1296
+ "article",
1297
+ "video",
1298
+ "summer",
1299
+ "feature",
1300
+ "extreme",
1301
+ "mobile",
1302
+ "hospital",
1303
+ "flight",
1304
+ "fall",
1305
+ "pension",
1306
+ "piano",
1307
+ "fail",
1308
+ "result",
1309
+ "rub",
1310
+ "gap",
1311
+ "system",
1312
+ "report",
1313
+ "suck",
1314
+ "ordinary",
1315
+ "wind",
1316
+ "nerve",
1317
+ "ask",
1318
+ "shine",
1319
+ "note",
1320
+ "line",
1321
+ "mom",
1322
+ "perception",
1323
+ "brother",
1324
+ "reference",
1325
+ "bend",
1326
+ "charge",
1327
+ "treat",
1328
+ "trick",
1329
+ "term",
1330
+ "homework",
1331
+ "bake",
1332
+ "bid",
1333
+ "status",
1334
+ "project",
1335
+ "strategy",
1336
+ "orange",
1337
+ "let",
1338
+ "enthusiasm",
1339
+ "parent",
1340
+ "concentrate",
1341
+ "device",
1342
+ "travel",
1343
+ "poetry",
1344
+ "business",
1345
+ "society",
1346
+ "kiss",
1347
+ "end",
1348
+ "vegetable",
1349
+ "employ",
1350
+ "schedule",
1351
+ "hour",
1352
+ "brave",
1353
+ "focus",
1354
+ "process",
1355
+ "movie",
1356
+ "illegal",
1357
+ "general",
1358
+ "coffee",
1359
+ "ad",
1360
+ "highway",
1361
+ "chemistry",
1362
+ "psychology",
1363
+ "hire",
1364
+ "bell",
1365
+ "conference",
1366
+ "relief",
1367
+ "show",
1368
+ "neat",
1369
+ "funny",
1370
+ "weight",
1371
+ "quality",
1372
+ "club",
1373
+ "daughter",
1374
+ "zone",
1375
+ "touch",
1376
+ "tonight",
1377
+ "shock",
1378
+ "burn",
1379
+ "excuse",
1380
+ "name",
1381
+ "survey",
1382
+ "landscape",
1383
+ "advance",
1384
+ "satisfaction",
1385
+ "bread",
1386
+ "disaster",
1387
+ "item",
1388
+ "hat",
1389
+ "prior",
1390
+ "shopping",
1391
+ "visit",
1392
+ "east",
1393
+ "photo",
1394
+ "home",
1395
+ "idea",
1396
+ "father",
1397
+ "comparison",
1398
+ "cat",
1399
+ "pipe",
1400
+ "winner",
1401
+ "count",
1402
+ "lake",
1403
+ "fight",
1404
+ "prize",
1405
+ "foundation",
1406
+ "dog",
1407
+ "keep",
1408
+ "ideal",
1409
+ "fan",
1410
+ "struggle",
1411
+ "peak",
1412
+ "safety",
1413
+ "solution",
1414
+ "hell",
1415
+ "conclusion",
1416
+ "population",
1417
+ "strain",
1418
+ "alarm",
1419
+ "measurement",
1420
+ "second",
1421
+ "train",
1422
+ "race",
1423
+ "due",
1424
+ "insurance",
1425
+ "boss",
1426
+ "tree",
1427
+ "monitor",
1428
+ "sick",
1429
+ "course",
1430
+ "drag",
1431
+ "appointment",
1432
+ "slice",
1433
+ "still",
1434
+ "care",
1435
+ "patience",
1436
+ "rich",
1437
+ "escape",
1438
+ "emotion",
1439
+ "royal",
1440
+ "female",
1441
+ "childhood",
1442
+ "government",
1443
+ "picture",
1444
+ "will",
1445
+ "sock",
1446
+ "big",
1447
+ "gate",
1448
+ "oil",
1449
+ "cross",
1450
+ "pin",
1451
+ "improvement",
1452
+ "championship",
1453
+ "silly",
1454
+ "help",
1455
+ "sky",
1456
+ "pitch",
1457
+ "man",
1458
+ "diamond",
1459
+ "most",
1460
+ "transition",
1461
+ "work",
1462
+ "science",
1463
+ "committee",
1464
+ "moment",
1465
+ "fix",
1466
+ "teaching",
1467
+ "dig",
1468
+ "specialist",
1469
+ "complex",
1470
+ "guide",
1471
+ "people",
1472
+ "dead",
1473
+ "voice",
1474
+ "original",
1475
+ "break",
1476
+ "topic",
1477
+ "data",
1478
+ "degree",
1479
+ "reading",
1480
+ "recording",
1481
+ "bunch",
1482
+ "reach",
1483
+ "judgment",
1484
+ "lie",
1485
+ "regular",
1486
+ "set",
1487
+ "painting",
1488
+ "mode",
1489
+ "list",
1490
+ "player",
1491
+ "bear",
1492
+ "north",
1493
+ "wonder",
1494
+ "carpet",
1495
+ "heavy",
1496
+ "officer",
1497
+ "negative",
1498
+ "clock",
1499
+ "unique",
1500
+ "baby",
1501
+ "pain",
1502
+ "assumption",
1503
+ "disk",
1504
+ "iron",
1505
+ "bill",
1506
+ "drawer",
1507
+ "look",
1508
+ "double",
1509
+ "mistake",
1510
+ "finish",
1511
+ "future",
1512
+ "brilliant",
1513
+ "contact",
1514
+ "math",
1515
+ "rice",
1516
+ "leave",
1517
+ "restaurant",
1518
+ "discount",
1519
+ "sex",
1520
+ "virus",
1521
+ "bit",
1522
+ "trust",
1523
+ "event",
1524
+ "wear",
1525
+ "juice",
1526
+ "failure",
1527
+ "bug",
1528
+ "context",
1529
+ "mud",
1530
+ "whole",
1531
+ "wrap",
1532
+ "intention",
1533
+ "draft",
1534
+ "pressure",
1535
+ "cake",
1536
+ "dark",
1537
+ "explanation",
1538
+ "space",
1539
+ "angle",
1540
+ "word",
1541
+ "efficiency",
1542
+ "management",
1543
+ "habit",
1544
+ "star",
1545
+ "chance",
1546
+ "finding",
1547
+ "transportation",
1548
+ "stand",
1549
+ "criticism",
1550
+ "flow",
1551
+ "door",
1552
+ "injury",
1553
+ "insect",
1554
+ "surprise",
1555
+ "apartment",
1556
+ ] # pylint: disable=line-too-long
1557
+
1558
# Mapping from ISO 639-1 language codes to English language names, wrapped in
# an ``immutabledict``. Insertion order is the order listed below.
LANGUAGE_CODES = immutabledict.immutabledict(
    dict(
        en="English",
        es="Spanish",
        pt="Portuguese",
        ar="Arabic",
        hi="Hindi",
        fr="French",
        ru="Russian",
        de="German",
        ja="Japanese",
        it="Italian",
        bn="Bengali",
        uk="Ukrainian",
        th="Thai",
        ur="Urdu",
        ta="Tamil",
        te="Telugu",
        bg="Bulgarian",
        ko="Korean",
        pl="Polish",
        he="Hebrew",
        fa="Persian",
        vi="Vietnamese",
        ne="Nepali",
        sw="Swahili",
        kn="Kannada",
        mr="Marathi",
        gu="Gujarati",
        pa="Punjabi",
        ml="Malayalam",
        fi="Finnish",
    )
)
1593
+
1594
# Regex fragments used by ``split_into_sentences`` to tell sentence-ending
# periods apart from periods inside abbreviations, numbers and domains.
_ALPHABETS = r"([A-Za-z])"
_PREFIXES = r"(Mr|St|Mrs|Ms|Dr)[.]"
_SUFFIXES = r"(Inc|Ltd|Jr|Sr|Co)"
_STARTERS = r"(Mr|Mrs|Ms|Dr|Prof|Capt|Cpt|Lt|He\s|She\s|It\s|They\s|Their\s|Our\s|We\s|But\s|However\s|That\s|This\s|Wherever)"
_ACRONYMS = r"([A-Z][.][A-Z][.](?:[A-Z][.])?)"
_WEBSITES = r"[.](com|net|org|io|gov|edu|me)"
_DIGITS = r"([0-9])"
_MULTIPLE_DOTS = r"\.{2,}"


def split_into_sentences(text):
    """Split ``text`` into a list of sentences using rule-based heuristics.

    Periods that should not end a sentence (title prefixes, web domains,
    decimals, single-letter initials, acronyms, company suffixes, "Ph.D.")
    are first swapped for a ``<prd>`` placeholder. Real boundaries are then
    tagged with ``<stop>`` after ``.``, ``?`` and ``!``. Placeholders are
    turned back into periods, and the text is cut at every ``<stop>``.

    Args:
      text: A string containing one or more sentences.

    Returns:
      A list of whitespace-stripped sentence strings. A trailing empty
      fragment is dropped, so an empty input yields ``[]``.
    """
    work = " " + text + " "
    work = work.replace("\n", " ")

    # Regex rewrites, applied strictly in this order: later patterns rely on
    # earlier ones having already hidden some periods behind ``<prd>``.
    regex_rewrites = (
        (_PREFIXES, r"\1<prd>"),
        (_WEBSITES, r"<prd>\1"),
        (_DIGITS + "[.]" + _DIGITS, r"\1<prd>\2"),
        # A run of two or more dots keeps its dots but always ends a sentence.
        (_MULTIPLE_DOTS, lambda found: "<prd>" * len(found.group(0)) + "<stop>"),
        (r"Ph\.D\.", "Ph<prd>D<prd>"),
        (r"\s" + _ALPHABETS + "[.] ", r" \1<prd> "),
        (_ACRONYMS + " " + _STARTERS, r"\1<stop> \2"),
        (
            _ALPHABETS + "[.]" + _ALPHABETS + "[.]" + _ALPHABETS + "[.]",
            r"\1<prd>\2<prd>\3<prd>",
        ),
        (_ALPHABETS + "[.]" + _ALPHABETS + "[.]", r"\1<prd>\2<prd>"),
        (" " + _SUFFIXES + "[.] " + _STARTERS, r" \1<stop> \2"),
        (" " + _SUFFIXES + "[.]", r" \1<prd>"),
        (" " + _ALPHABETS + "[.]", r" \1<prd>"),
    )
    for pattern, replacement in regex_rewrites:
        work = re.sub(pattern, replacement, work)

    # Literal rewrites: move terminal punctuation outside closing quotes,
    # tag every remaining terminator as a boundary, then restore periods.
    literal_rewrites = (
        (".”", "”."),
        ('."', '".'),
        ('!"', '"!'),
        ('?"', '"?'),
        (".", ".<stop>"),
        ("?", "?<stop>"),
        ("!", "!<stop>"),
        ("<prd>", "."),
    )
    for old, new in literal_rewrites:
        work = work.replace(old, new)

    pieces = [piece.strip() for piece in work.split("<stop>")]
    if pieces and not pieces[-1]:
        pieces.pop()
    return pieces
1653
+
1654
+
1655
# Compiled once at import time; matches maximal runs of word characters
# (Unicode letters, digits and underscore).
_WORD_PATTERN = re.compile(r"\w+")


def count_words(text):
    """Count the words in ``text``.

    A word is a maximal run of word characters. Punctuation therefore splits
    tokens (``"It's"`` counts as two), while underscores join them
    (``"snake_case"`` counts as one).

    Args:
      text: The string to analyse.

    Returns:
      The number of word tokens found.
    """
    # Same result as nltk's RegexpTokenizer(r"\w+").tokenize(text), which is a
    # plain findall. This avoids rebuilding a tokenizer object on every call.
    return len(_WORD_PATTERN.findall(text))
1661
+
1662
+
1663
@functools.cache
def _get_sentence_tokenizer():
    """Load the English punkt sentence tokenizer and reuse it across calls.

    ``functools.cache`` memoizes this zero-argument function, so the resource
    is loaded through ``nltk.data.load`` only on the first call.
    """
    # NOTE(review): this loads a pickled model. Newer NLTK releases appear to
    # steer users toward the "punkt_tab" resource instead. Confirm this path
    # still loads with the installed NLTK / nltk_data version.
    resource = "nltk:tokenizers/punkt/english.pickle"
    return nltk.data.load(resource)
1666
+
1667
+
1668
def count_sentences(text):
    """Return how many sentences the cached English punkt tokenizer finds in ``text``."""
    return len(_get_sentence_tokenizer().tokenize(text))
1673
+
1674
+
1675
def generate_keywords(num_keywords):
    """Return ``num_keywords`` entries drawn from ``WORD_LIST`` without replacement.

    Uses the global ``random`` module state. ``random.sample`` raises
    ``ValueError`` if more keywords are requested than ``WORD_LIST`` holds.
    """
    chosen = random.sample(WORD_LIST, num_keywords)
    return chosen