torchrl-0.11.0-cp314-cp314t-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,458 @@
|
|
|
1
|
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
//
|
|
3
|
+
// This source code is licensed under the MIT license found in the
|
|
4
|
+
// LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
#pragma once
|
|
7
|
+
|
|
8
|
+
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <torch/extension.h>
#include <torch/torch.h>

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <functional>
#include <limits>
#include <stdexcept>
#include <string>
#include <vector>

#include "numpy_utils.h"
#include "torch_utils.h"
|
|
21
|
+
|
|
22
|
+
namespace py = pybind11;
|
|
23
|
+
|
|
24
|
+
namespace torchrl {
|
|
25
|
+
|
|
26
|
+
// SegmentTree is a tree data structure to maintain statistics of intervals.
|
|
27
|
+
// https://en.wikipedia.org/wiki/Segment_tree
|
|
28
|
+
// Here is the implementation of a non-recursive SegmentTree for single-point
|
|
29
|
+
// update and interval query. The time complexities of both Update and Query are
|
|
30
|
+
// O(logN).
|
|
31
|
+
// One example of a SegmentTree is shown below.
|
|
32
|
+
//
|
|
33
|
+
//                            1: [0, 8)
//                      /                   \
//            2: [0, 4)                       3: [4, 8)
//           /         \                     /         \
//    4: [0, 2)       5: [2, 4)       6: [4, 6)       7: [6, 8)
//      /   \           /   \           /   \           /   \
//   8: 0    9: 1   10: 2   11: 3   12: 4   13: 5   14: 6   15: 7
|
|
40
|
+
|
|
41
|
+
template <typename T, class Operator>
|
|
42
|
+
class SegmentTree {
|
|
43
|
+
public:
|
|
44
|
+
SegmentTree(int64_t size, const T& identity_element)
|
|
45
|
+
: size_(size), identity_element_(identity_element) {
|
|
46
|
+
for (capacity_ = 1; capacity_ <= size; capacity_ <<= 1);
|
|
47
|
+
values_.assign(2 * capacity_, identity_element_);
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
int64_t size() const { return size_; }
|
|
51
|
+
|
|
52
|
+
int64_t capacity() const { return capacity_; }
|
|
53
|
+
|
|
54
|
+
const T& identity_element() const { return identity_element_; }
|
|
55
|
+
|
|
56
|
+
const T& At(int64_t index) const { return values_[index | capacity_]; }
|
|
57
|
+
|
|
58
|
+
std::vector<T> At(const std::vector<int64_t>& index) const {
|
|
59
|
+
const int64_t n = index.size();
|
|
60
|
+
std::vector<T> value(n);
|
|
61
|
+
BatchAtImpl(n, index.data(), value.data());
|
|
62
|
+
return value;
|
|
63
|
+
}
|
|
64
|
+
|
|
65
|
+
py::array_t<T> At(const py::array_t<int64_t>& index) const {
|
|
66
|
+
py::array_t<T> value = utils::NumpyEmptyLike<int64_t, T>(index);
|
|
67
|
+
BatchAtImpl(index.size(), index.data(), value.mutable_data());
|
|
68
|
+
return value;
|
|
69
|
+
}
|
|
70
|
+
|
|
71
|
+
torch::Tensor At(const torch::Tensor& index) const {
|
|
72
|
+
assert(index.dtype() == torch::kInt64);
|
|
73
|
+
const torch::Tensor index_contiguous = index.contiguous();
|
|
74
|
+
const int64_t n = index_contiguous.numel();
|
|
75
|
+
torch::Tensor value =
|
|
76
|
+
torch::empty_like(index_contiguous, utils::TorchDataType<T>::value);
|
|
77
|
+
BatchAtImpl(n, index_contiguous.data_ptr<int64_t>(), value.data_ptr<T>());
|
|
78
|
+
return value;
|
|
79
|
+
}
|
|
80
|
+
|
|
81
|
+
// Update the item at index to value.
|
|
82
|
+
// Time complexity: O(logN).
|
|
83
|
+
void Update(int64_t index, const T& value) {
|
|
84
|
+
index |= capacity_;
|
|
85
|
+
for (values_[index] = value; index > 1; index >>= 1) {
|
|
86
|
+
values_[index >> 1] = op_(values_[index], values_[index ^ 1]);
|
|
87
|
+
}
|
|
88
|
+
}
|
|
89
|
+
|
|
90
|
+
void Update(const std::vector<int64_t>& index, const T& value) {
|
|
91
|
+
BatchUpdateImpl(index.size(), index.data(), value);
|
|
92
|
+
}
|
|
93
|
+
|
|
94
|
+
void Update(const std::vector<int64_t>& index, const std::vector<T>& value) {
|
|
95
|
+
assert(value.size() == 1 || index.size() == value.size());
|
|
96
|
+
const int64_t n = index.size();
|
|
97
|
+
if (value.size() == 1) {
|
|
98
|
+
BatchUpdateImpl(n, index.data(), value[0]);
|
|
99
|
+
} else {
|
|
100
|
+
BatchUpdateImpl(n, index.data(), value.data());
|
|
101
|
+
}
|
|
102
|
+
}
|
|
103
|
+
|
|
104
|
+
void Update(const py::array_t<int64_t>& index, const T& value) {
|
|
105
|
+
BatchUpdateImpl(index.size(), index.data(), value);
|
|
106
|
+
}
|
|
107
|
+
|
|
108
|
+
void Update(const py::array_t<int64_t>& index, const py::array_t<T>& value) {
|
|
109
|
+
assert(value.size() == 1 || index.size() == value.size());
|
|
110
|
+
const int64_t n = index.size();
|
|
111
|
+
if (value.size() == 1) {
|
|
112
|
+
BatchUpdateImpl(n, index.data(), *(value.data()));
|
|
113
|
+
} else {
|
|
114
|
+
BatchUpdateImpl(n, index.data(), value.data());
|
|
115
|
+
}
|
|
116
|
+
}
|
|
117
|
+
|
|
118
|
+
void Update(const torch::Tensor& index, const T& value) {
|
|
119
|
+
assert(index.dtype() == torch::kInt64);
|
|
120
|
+
const torch::Tensor index_contiguous = index.contiguous();
|
|
121
|
+
const int64_t n = index_contiguous.numel();
|
|
122
|
+
BatchUpdateImpl(n, index_contiguous.data_ptr<int64_t>(), value);
|
|
123
|
+
}
|
|
124
|
+
|
|
125
|
+
void Update(const torch::Tensor& index, const torch::Tensor& value) {
|
|
126
|
+
assert(index.dtype() == torch::kInt64);
|
|
127
|
+
assert(value.dtype() == utils::TorchDataType<T>::value);
|
|
128
|
+
assert(value.numel() == 1 || index.sizes() == value.sizes());
|
|
129
|
+
const torch::Tensor index_contiguous = index.contiguous();
|
|
130
|
+
const torch::Tensor value_contiguous = value.contiguous();
|
|
131
|
+
const int64_t n = index_contiguous.numel();
|
|
132
|
+
if (value_contiguous.numel() == 1) {
|
|
133
|
+
BatchUpdateImpl(n, index_contiguous.data_ptr<int64_t>(),
|
|
134
|
+
*(value_contiguous.data_ptr<T>()));
|
|
135
|
+
} else {
|
|
136
|
+
BatchUpdateImpl(n, index_contiguous.data_ptr<int64_t>(),
|
|
137
|
+
value_contiguous.data_ptr<T>());
|
|
138
|
+
}
|
|
139
|
+
}
|
|
140
|
+
|
|
141
|
+
// Reduce the range of [l, r) by Operator.
|
|
142
|
+
// Time complexity: O(logN)
|
|
143
|
+
T Query(int64_t l, int64_t r) const {
|
|
144
|
+
assert(l < r);
|
|
145
|
+
if (l <= 0 && r >= size_) {
|
|
146
|
+
return values_[1];
|
|
147
|
+
}
|
|
148
|
+
T ret = identity_element_;
|
|
149
|
+
l |= capacity_;
|
|
150
|
+
r |= capacity_;
|
|
151
|
+
while (l < r) {
|
|
152
|
+
if (l & 1) {
|
|
153
|
+
ret = op_(ret, values_[l++]);
|
|
154
|
+
}
|
|
155
|
+
if (r & 1) {
|
|
156
|
+
ret = op_(ret, values_[--r]);
|
|
157
|
+
}
|
|
158
|
+
l >>= 1;
|
|
159
|
+
r >>= 1;
|
|
160
|
+
}
|
|
161
|
+
return ret;
|
|
162
|
+
}
|
|
163
|
+
|
|
164
|
+
std::vector<T> Query(const std::vector<int64_t>& l,
|
|
165
|
+
const std::vector<int64_t>& r) const {
|
|
166
|
+
assert(l.size() == r.size());
|
|
167
|
+
std::vector<T> ret(l.size());
|
|
168
|
+
const int64_t n = l.size();
|
|
169
|
+
BatchQueryImpl(n, l.data(), r.data(), ret.data());
|
|
170
|
+
return ret;
|
|
171
|
+
}
|
|
172
|
+
|
|
173
|
+
py::array_t<T> Query(const py::array_t<int64_t>& l,
|
|
174
|
+
const py::array_t<int64_t>& r) const {
|
|
175
|
+
py::array_t<T> ret = utils::NumpyEmptyLike<int64_t, T>(l);
|
|
176
|
+
BatchQueryImpl(l.size(), l.data(), r.data(), ret.mutable_data());
|
|
177
|
+
return ret;
|
|
178
|
+
}
|
|
179
|
+
|
|
180
|
+
torch::Tensor Query(const torch::Tensor& l, const torch::Tensor& r) const {
|
|
181
|
+
assert(l.dtype() == torch::kInt64);
|
|
182
|
+
assert(r.dtype() == torch::kInt64);
|
|
183
|
+
assert(l.sizes() == r.sizes());
|
|
184
|
+
const torch::Tensor l_contiguous = l.contiguous();
|
|
185
|
+
const torch::Tensor r_contiguous = r.contiguous();
|
|
186
|
+
torch::Tensor ret =
|
|
187
|
+
torch::empty_like(l_contiguous, utils::TorchDataType<T>::value);
|
|
188
|
+
const int64_t n = l_contiguous.numel();
|
|
189
|
+
BatchQueryImpl(n, l_contiguous.data_ptr<int64_t>(),
|
|
190
|
+
r_contiguous.data_ptr<int64_t>(), ret.data_ptr<T>());
|
|
191
|
+
return ret;
|
|
192
|
+
}
|
|
193
|
+
|
|
194
|
+
py::array_t<T> DumpValues() const {
|
|
195
|
+
py::array_t<T> ret(size_);
|
|
196
|
+
std::memcpy(ret.mutable_data(), values_.data() + capacity_,
|
|
197
|
+
size_ * sizeof(T));
|
|
198
|
+
return ret;
|
|
199
|
+
}
|
|
200
|
+
|
|
201
|
+
void LoadValues(const py::array_t<T>& values) {
|
|
202
|
+
assert(values.size() == size_);
|
|
203
|
+
std::memcpy(values_.data() + capacity_, values.data(), size_ * sizeof(T));
|
|
204
|
+
for (int64_t i = capacity_ - 1; i > 0; --i) {
|
|
205
|
+
values_[i] = op_(values_[(i << 1)], values_[(i << 1) | 1]);
|
|
206
|
+
}
|
|
207
|
+
}
|
|
208
|
+
|
|
209
|
+
protected:
|
|
210
|
+
void BatchAtImpl(int64_t n, const int64_t* index, T* value) const {
|
|
211
|
+
for (int64_t i = 0; i < n; ++i) {
|
|
212
|
+
value[i] = values_[index[i] | capacity_];
|
|
213
|
+
}
|
|
214
|
+
}
|
|
215
|
+
|
|
216
|
+
void BatchUpdateImpl(int64_t n, const int64_t* index, const T& value) {
|
|
217
|
+
for (int64_t i = 0; i < n; ++i) {
|
|
218
|
+
Update(index[i], value);
|
|
219
|
+
}
|
|
220
|
+
}
|
|
221
|
+
|
|
222
|
+
void BatchUpdateImpl(int64_t n, const int64_t* index, const T* value) {
|
|
223
|
+
for (int64_t i = 0; i < n; ++i) {
|
|
224
|
+
Update(index[i], value[i]);
|
|
225
|
+
}
|
|
226
|
+
}
|
|
227
|
+
|
|
228
|
+
void BatchQueryImpl(int64_t n, const int64_t* l, const int64_t* r,
|
|
229
|
+
T* result) const {
|
|
230
|
+
for (int64_t i = 0; i < n; ++i) {
|
|
231
|
+
result[i] = Query(l[i], r[i]);
|
|
232
|
+
}
|
|
233
|
+
}
|
|
234
|
+
|
|
235
|
+
const Operator op_{};
|
|
236
|
+
const int64_t size_;
|
|
237
|
+
int64_t capacity_;
|
|
238
|
+
const T identity_element_;
|
|
239
|
+
std::vector<T> values_;
|
|
240
|
+
};
|
|
241
|
+
|
|
242
|
+
template <typename T>
|
|
243
|
+
class SumSegmentTree final : public SegmentTree<T, std::plus<T>> {
|
|
244
|
+
public:
|
|
245
|
+
SumSegmentTree(int64_t size) : SegmentTree<T, std::plus<T>>(size, T(0)) {}
|
|
246
|
+
|
|
247
|
+
// Get the 1st index where the scan (prefix sum) is not less than value.
|
|
248
|
+
// Time complexity: O(logN)
|
|
249
|
+
int64_t ScanLowerBound(const T& value) const {
|
|
250
|
+
if (value > this->values_[1]) {
|
|
251
|
+
return this->size_;
|
|
252
|
+
}
|
|
253
|
+
int64_t index = 1;
|
|
254
|
+
T current_value = value;
|
|
255
|
+
while (index < this->capacity_) {
|
|
256
|
+
index <<= 1;
|
|
257
|
+
const T& lvalue = this->values_[index];
|
|
258
|
+
if (current_value > lvalue) {
|
|
259
|
+
current_value -= lvalue;
|
|
260
|
+
index |= 1;
|
|
261
|
+
}
|
|
262
|
+
}
|
|
263
|
+
return index ^ this->capacity_;
|
|
264
|
+
}
|
|
265
|
+
|
|
266
|
+
std::vector<int64_t> ScanLowerBound(const std::vector<T>& value) const {
|
|
267
|
+
std::vector<int64_t> index(value.size());
|
|
268
|
+
BatchScanLowerBoundImpl(value.size(), value.data(), index.data());
|
|
269
|
+
return index;
|
|
270
|
+
}
|
|
271
|
+
|
|
272
|
+
py::array_t<int64_t> ScanLowerBound(const py::array_t<T>& value) const {
|
|
273
|
+
py::array_t<int64_t> index = utils::NumpyEmptyLike<T, int64_t>(value);
|
|
274
|
+
BatchScanLowerBoundImpl(value.size(), value.data(), index.mutable_data());
|
|
275
|
+
return index;
|
|
276
|
+
}
|
|
277
|
+
|
|
278
|
+
torch::Tensor ScanLowerBound(const torch::Tensor& value) const {
|
|
279
|
+
assert(value.dtype() == utils::TorchDataType<T>::value);
|
|
280
|
+
const torch::Tensor value_contiguous = value.contiguous();
|
|
281
|
+
torch::Tensor index = torch::empty_like(value_contiguous, torch::kInt64);
|
|
282
|
+
const int64_t n = value_contiguous.numel();
|
|
283
|
+
BatchScanLowerBoundImpl(n, value_contiguous.data_ptr<T>(),
|
|
284
|
+
index.data_ptr<int64_t>());
|
|
285
|
+
return index;
|
|
286
|
+
}
|
|
287
|
+
|
|
288
|
+
protected:
|
|
289
|
+
void BatchScanLowerBoundImpl(int64_t n, const T* value,
|
|
290
|
+
int64_t* index) const {
|
|
291
|
+
for (int64_t i = 0; i < n; ++i) {
|
|
292
|
+
index[i] = ScanLowerBound(value[i]);
|
|
293
|
+
}
|
|
294
|
+
}
|
|
295
|
+
};
|
|
296
|
+
|
|
297
|
+
template <typename T>
struct MinOp {
  // Binary functor yielding the smaller operand. Whenever `rhs < lhs` is
  // false (ties, or an unordered NaN comparison) it yields `lhs`, which is
  // exactly how std::min is specified.
  T operator()(const T& lhs, const T& rhs) const {
    return (rhs < lhs) ? rhs : lhs;
  }
};
|
|
301
|
+
|
|
302
|
+
template <typename T>
|
|
303
|
+
class MinSegmentTree final : public SegmentTree<T, MinOp<T>> {
|
|
304
|
+
public:
|
|
305
|
+
MinSegmentTree(int64_t size)
|
|
306
|
+
: SegmentTree<T, MinOp<T>>(size, std::numeric_limits<T>::max()) {}
|
|
307
|
+
};
|
|
308
|
+
|
|
309
|
+
template <typename T>
|
|
310
|
+
void DefineSumSegmentTree(const std::string& type, py::module& m) {
|
|
311
|
+
const std::string pyclass = "SumSegmentTree" + type;
|
|
312
|
+
py::class_<SumSegmentTree<T>, std::shared_ptr<SumSegmentTree<T>>>(
|
|
313
|
+
m, pyclass.c_str())
|
|
314
|
+
.def(py::init<int64_t>())
|
|
315
|
+
.def_property_readonly("size", &SumSegmentTree<T>::size)
|
|
316
|
+
.def_property_readonly("capacity", &SumSegmentTree<T>::capacity)
|
|
317
|
+
.def_property_readonly("identity_element",
|
|
318
|
+
&SumSegmentTree<T>::identity_element)
|
|
319
|
+
.def("__len__", &SumSegmentTree<T>::size)
|
|
320
|
+
.def("__getitem__",
|
|
321
|
+
py::overload_cast<int64_t>(&SumSegmentTree<T>::At, py::const_))
|
|
322
|
+
.def("__getitem__", py::overload_cast<const py::array_t<int64_t>&>(
|
|
323
|
+
&SumSegmentTree<T>::At, py::const_))
|
|
324
|
+
.def("__getitem__", py::overload_cast<const torch::Tensor&>(
|
|
325
|
+
&SumSegmentTree<T>::At, py::const_))
|
|
326
|
+
.def("at", py::overload_cast<int64_t>(&SumSegmentTree<T>::At, py::const_))
|
|
327
|
+
.def("at", py::overload_cast<const py::array_t<int64_t>&>(
|
|
328
|
+
&SumSegmentTree<T>::At, py::const_))
|
|
329
|
+
.def("at", py::overload_cast<const torch::Tensor&>(&SumSegmentTree<T>::At,
|
|
330
|
+
py::const_))
|
|
331
|
+
.def("__setitem__",
|
|
332
|
+
py::overload_cast<int64_t, const T&>(&SumSegmentTree<T>::Update))
|
|
333
|
+
.def("__setitem__",
|
|
334
|
+
py::overload_cast<const py::array_t<int64_t>&, const T&>(
|
|
335
|
+
&SumSegmentTree<T>::Update))
|
|
336
|
+
.def(
|
|
337
|
+
"__setitem__",
|
|
338
|
+
py::overload_cast<const py::array_t<int64_t>&, const py::array_t<T>&>(
|
|
339
|
+
&SumSegmentTree<T>::Update))
|
|
340
|
+
.def("__setitem__", py::overload_cast<const torch::Tensor&, const T&>(
|
|
341
|
+
&SumSegmentTree<T>::Update))
|
|
342
|
+
.def("__setitem__",
|
|
343
|
+
py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
|
|
344
|
+
&SumSegmentTree<T>::Update))
|
|
345
|
+
.def("update",
|
|
346
|
+
py::overload_cast<int64_t, const T&>(&SumSegmentTree<T>::Update))
|
|
347
|
+
.def("update", py::overload_cast<const py::array_t<int64_t>&, const T&>(
|
|
348
|
+
&SumSegmentTree<T>::Update))
|
|
349
|
+
.def(
|
|
350
|
+
"update",
|
|
351
|
+
py::overload_cast<const py::array_t<int64_t>&, const py::array_t<T>&>(
|
|
352
|
+
&SumSegmentTree<T>::Update))
|
|
353
|
+
.def("update", py::overload_cast<const torch::Tensor&, const T&>(
|
|
354
|
+
&SumSegmentTree<T>::Update))
|
|
355
|
+
.def("update",
|
|
356
|
+
py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
|
|
357
|
+
&SumSegmentTree<T>::Update))
|
|
358
|
+
.def("query", py::overload_cast<int64_t, int64_t>(
|
|
359
|
+
&SumSegmentTree<T>::Query, py::const_))
|
|
360
|
+
.def("query", py::overload_cast<const py::array_t<int64_t>&,
|
|
361
|
+
const py::array_t<int64_t>&>(
|
|
362
|
+
&SumSegmentTree<T>::Query, py::const_))
|
|
363
|
+
.def("query",
|
|
364
|
+
py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
|
|
365
|
+
&SumSegmentTree<T>::Query, py::const_))
|
|
366
|
+
.def("scan_lower_bound",
|
|
367
|
+
py::overload_cast<const T&>(&SumSegmentTree<T>::ScanLowerBound,
|
|
368
|
+
py::const_))
|
|
369
|
+
.def("scan_lower_bound",
|
|
370
|
+
py::overload_cast<const py::array_t<T>&>(
|
|
371
|
+
&SumSegmentTree<T>::ScanLowerBound, py::const_))
|
|
372
|
+
.def("scan_lower_bound",
|
|
373
|
+
py::overload_cast<const torch::Tensor&>(
|
|
374
|
+
&SumSegmentTree<T>::ScanLowerBound, py::const_))
|
|
375
|
+
.def(py::pickle(
|
|
376
|
+
[](const SumSegmentTree<T>& s) {
|
|
377
|
+
return py::make_tuple(s.DumpValues());
|
|
378
|
+
},
|
|
379
|
+
[](const py::tuple& t) {
|
|
380
|
+
assert(t.size() == 1);
|
|
381
|
+
const py::array_t<T>& arr = t[0].cast<py::array_t<T>>();
|
|
382
|
+
SumSegmentTree<T> s(arr.size());
|
|
383
|
+
s.LoadValues(arr);
|
|
384
|
+
return s;
|
|
385
|
+
}));
|
|
386
|
+
}
|
|
387
|
+
|
|
388
|
+
template <typename T>
|
|
389
|
+
void DefineMinSegmentTree(const std::string& type, py::module& m) {
|
|
390
|
+
const std::string pyclass = "MinSegmentTree" + type;
|
|
391
|
+
py::class_<MinSegmentTree<T>, std::shared_ptr<MinSegmentTree<T>>>(
|
|
392
|
+
m, pyclass.c_str())
|
|
393
|
+
.def(py::init<int64_t>())
|
|
394
|
+
.def_property_readonly("size", &MinSegmentTree<T>::size)
|
|
395
|
+
.def_property_readonly("capacity", &MinSegmentTree<T>::capacity)
|
|
396
|
+
.def_property_readonly("identity_element",
|
|
397
|
+
&MinSegmentTree<T>::identity_element)
|
|
398
|
+
.def("__len__", &MinSegmentTree<T>::size)
|
|
399
|
+
.def("__getitem__",
|
|
400
|
+
py::overload_cast<int64_t>(&MinSegmentTree<T>::At, py::const_))
|
|
401
|
+
.def("__getitem__", py::overload_cast<const py::array_t<int64_t>&>(
|
|
402
|
+
&MinSegmentTree<T>::At, py::const_))
|
|
403
|
+
.def("__getitem__", py::overload_cast<const torch::Tensor&>(
|
|
404
|
+
&MinSegmentTree<T>::At, py::const_))
|
|
405
|
+
.def("at", py::overload_cast<int64_t>(&MinSegmentTree<T>::At, py::const_))
|
|
406
|
+
.def("at", py::overload_cast<const py::array_t<int64_t>&>(
|
|
407
|
+
&MinSegmentTree<T>::At, py::const_))
|
|
408
|
+
.def("at", py::overload_cast<const torch::Tensor&>(&MinSegmentTree<T>::At,
|
|
409
|
+
py::const_))
|
|
410
|
+
.def("__setitem__",
|
|
411
|
+
py::overload_cast<int64_t, const T&>(&MinSegmentTree<T>::Update))
|
|
412
|
+
.def("__setitem__",
|
|
413
|
+
py::overload_cast<const py::array_t<int64_t>&, const T&>(
|
|
414
|
+
&MinSegmentTree<T>::Update))
|
|
415
|
+
.def(
|
|
416
|
+
"__setitem__",
|
|
417
|
+
py::overload_cast<const py::array_t<int64_t>&, const py::array_t<T>&>(
|
|
418
|
+
&MinSegmentTree<T>::Update))
|
|
419
|
+
.def("__setitem__", py::overload_cast<const torch::Tensor&, const T&>(
|
|
420
|
+
&MinSegmentTree<T>::Update))
|
|
421
|
+
.def("__setitem__",
|
|
422
|
+
py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
|
|
423
|
+
&MinSegmentTree<T>::Update))
|
|
424
|
+
.def("update",
|
|
425
|
+
py::overload_cast<int64_t, const T&>(&MinSegmentTree<T>::Update))
|
|
426
|
+
.def("update", py::overload_cast<const py::array_t<int64_t>&, const T&>(
|
|
427
|
+
&MinSegmentTree<T>::Update))
|
|
428
|
+
.def(
|
|
429
|
+
"update",
|
|
430
|
+
py::overload_cast<const py::array_t<int64_t>&, const py::array_t<T>&>(
|
|
431
|
+
&MinSegmentTree<T>::Update))
|
|
432
|
+
.def("update", py::overload_cast<const torch::Tensor&, const T&>(
|
|
433
|
+
&MinSegmentTree<T>::Update))
|
|
434
|
+
.def("update",
|
|
435
|
+
py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
|
|
436
|
+
&MinSegmentTree<T>::Update))
|
|
437
|
+
.def("query", py::overload_cast<int64_t, int64_t>(
|
|
438
|
+
&MinSegmentTree<T>::Query, py::const_))
|
|
439
|
+
.def("query", py::overload_cast<const py::array_t<int64_t>&,
|
|
440
|
+
const py::array_t<int64_t>&>(
|
|
441
|
+
&MinSegmentTree<T>::Query, py::const_))
|
|
442
|
+
.def("query",
|
|
443
|
+
py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
|
|
444
|
+
&MinSegmentTree<T>::Query, py::const_))
|
|
445
|
+
.def(py::pickle(
|
|
446
|
+
[](const MinSegmentTree<T>& s) {
|
|
447
|
+
return py::make_tuple(s.DumpValues());
|
|
448
|
+
},
|
|
449
|
+
[](const py::tuple& t) {
|
|
450
|
+
assert(t.size() == 1);
|
|
451
|
+
const py::array_t<T>& arr = t[0].cast<py::array_t<T>>();
|
|
452
|
+
MinSegmentTree<T> s(arr.size());
|
|
453
|
+
s.LoadValues(arr);
|
|
454
|
+
return s;
|
|
455
|
+
}));
|
|
456
|
+
}
|
|
457
|
+
|
|
458
|
+
} // namespace torchrl
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
//
|
|
3
|
+
// This source code is licensed under the MIT license found in the
|
|
4
|
+
// LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
#pragma once
|
|
7
|
+
|
|
8
|
+
#include <torch/torch.h>
|
|
9
|
+
|
|
10
|
+
#include <cstdint>
|
|
11
|
+
|
|
12
|
+
namespace torchrl {
|
|
13
|
+
namespace utils {
|
|
14
|
+
|
|
15
|
+
template <typename T>
|
|
16
|
+
struct TorchDataType;
|
|
17
|
+
|
|
18
|
+
template <>
|
|
19
|
+
struct TorchDataType<int64_t> {
|
|
20
|
+
static constexpr torch::ScalarType value = torch::kInt64;
|
|
21
|
+
};
|
|
22
|
+
|
|
23
|
+
template <>
|
|
24
|
+
struct TorchDataType<float> {
|
|
25
|
+
static constexpr torch::ScalarType value = torch::kFloat;
|
|
26
|
+
};
|
|
27
|
+
|
|
28
|
+
template <>
|
|
29
|
+
struct TorchDataType<double> {
|
|
30
|
+
static constexpr torch::ScalarType value = torch::kDouble;
|
|
31
|
+
};
|
|
32
|
+
|
|
33
|
+
} // namespace utils
|
|
34
|
+
} // namespace torchrl
|
torchrl/csrc/utils.cpp
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
//
|
|
3
|
+
// This source code is licensed under the MIT license found in the
|
|
4
|
+
// LICENSE file in the root directory of this source tree.
|
|
5
|
+
// utils.cpp
|
|
6
|
+
#include "utils.h"
|
|
7
|
+
|
|
8
|
+
#include <iostream>
|
|
9
|
+
// tanh(input) clamped to [-(1 - eps), 1 - eps].
// It is computed through the SafeTanh autograd function so that its custom
// backward is used.
torch::Tensor safetanh(torch::Tensor input, float eps) {
  torch::Tensor result = SafeTanh::apply(input, eps);
  return result;
}
|
|
12
|
+
// atanh of `input` after clamping it to [-(1 - eps), 1 - eps].
// It is computed through the SafeInvTanh autograd function so that its
// custom backward is used.
torch::Tensor safeatanh(torch::Tensor input, float eps) {
  torch::Tensor result = SafeInvTanh::apply(input, eps);
  return result;
}
|
|
15
|
+
// Forward pass: tanh(input), clamped so that |result| <= 1 - eps.
// The clamped result is saved because backward computes the tanh derivative
// from it.
torch::Tensor SafeTanh::forward(torch::autograd::AutogradContext* ctx,
                                torch::Tensor input, float eps) {
  const double bound = 1.0 - eps;
  torch::Tensor squashed = torch::tanh(input).clamp(-bound, bound);
  ctx->save_for_backward({squashed});
  return squashed;
}
|
|
23
|
+
// Backward pass: d/dx tanh(x) = 1 - tanh(x)^2, evaluated on the saved
// (clamped) forward output. The clamp itself is not differentiated through.
// The `eps` argument is a float, so its gradient slot is an undefined tensor.
torch::autograd::tensor_list SafeTanh::backward(
    torch::autograd::AutogradContext* ctx,
    torch::autograd::tensor_list grad_outputs) {
  const torch::Tensor out = ctx->get_saved_variables()[0];
  torch::Tensor grad_input = grad_outputs[0] * (1 - out * out);
  return {grad_input, torch::Tensor()};
}
|
|
32
|
+
// Forward pass: clamp `input` to [-(1 - eps), 1 - eps], then take atanh.
// Clamping first keeps the argument strictly inside (-1, 1) for eps > 0.
// The clamped tensor is saved for backward.
torch::Tensor SafeInvTanh::forward(torch::autograd::AutogradContext* ctx,
                                   torch::Tensor input, float eps) {
  const double bound = 1.0 - eps;
  torch::Tensor clamped = input.clamp(-bound, bound);
  ctx->save_for_backward({clamped});
  return torch::atanh(clamped);
}
|
|
40
|
+
// Backward pass: d/dx atanh(x) = 1 / (1 - x^2), evaluated on the saved
// clamped input. The `eps` argument is a float, so its gradient slot is an
// undefined tensor.
torch::autograd::tensor_list SafeInvTanh::backward(
    torch::autograd::AutogradContext* ctx,
    torch::autograd::tensor_list grad_outputs) {
  const torch::Tensor clamped = ctx->get_saved_variables()[0];
  torch::Tensor grad_input = grad_outputs[0] / (1 - clamped * clamped);
  return {grad_input, torch::Tensor()};
}
|
torchrl/csrc/utils.h
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
//
|
|
3
|
+
// This source code is licensed under the MIT license found in the
|
|
4
|
+
// LICENSE file in the root directory of this source tree.
|
|
5
|
+
// utils.h
|
|
6
|
+
|
|
7
|
+
#pragma once
|
|
8
|
+
|
|
9
|
+
#include <torch/extension.h>
|
|
10
|
+
#include <torch/torch.h>
|
|
11
|
+
|
|
12
|
+
// Returns tanh(input) clamped to [-(1 - eps), 1 - eps], differentiated via
// the SafeTanh autograd function.
torch::Tensor safetanh(torch::Tensor input, float eps = 1e-6);
// Returns atanh(input) after clamping input to [-(1 - eps), 1 - eps],
// differentiated via the SafeInvTanh autograd function.
torch::Tensor safeatanh(torch::Tensor input, float eps = 1e-6);
|
|
14
|
+
|
|
15
|
+
// Custom autograd function behind safetanh().
//
// forward:  computes tanh(input), clamps it to [-(1 - eps), 1 - eps], saves
//           that clamped output, and returns it.
// backward: returns grad_outputs[0] * (1 - out * out), where `out` is the
//           saved clamped output. The clamp is not differentiated through.
//           The slot for the float `eps` argument gets an undefined tensor.
class SafeTanh : public torch::autograd::Function<SafeTanh> {
 public:
  static torch::Tensor forward(torch::autograd::AutogradContext* ctx,
                               torch::Tensor input, float eps);
  static torch::autograd::tensor_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::tensor_list grad_outputs);
};
|
|
23
|
+
|
|
24
|
+
// Custom autograd function behind safeatanh().
//
// forward:  clamps `input` to [-(1 - eps), 1 - eps], saves the clamped
//           tensor, and returns its atanh.
// backward: returns grad_outputs[0] / (1 - x * x), where `x` is the saved
//           clamped input. The slot for the float `eps` argument gets an
//           undefined tensor.
class SafeInvTanh : public torch::autograd::Function<SafeInvTanh> {
 public:
  static torch::Tensor forward(torch::autograd::AutogradContext* ctx,
                               torch::Tensor input, float eps);
  static torch::autograd::tensor_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::tensor_list grad_outputs);
};
|