prpl_utils 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prpl_utils-0.0.1/LICENSE +21 -0
- prpl_utils-0.0.1/PKG-INFO +49 -0
- prpl_utils-0.0.1/README.md +26 -0
- prpl_utils-0.0.1/pyproject.toml +65 -0
- prpl_utils-0.0.1/setup.cfg +4 -0
- prpl_utils-0.0.1/src/prpl_utils/__init__.py +0 -0
- prpl_utils-0.0.1/src/prpl_utils/gym_agent.py +77 -0
- prpl_utils-0.0.1/src/prpl_utils/gym_utils.py +105 -0
- prpl_utils-0.0.1/src/prpl_utils/motion_planning.py +222 -0
- prpl_utils-0.0.1/src/prpl_utils/pddl_planning.py +124 -0
- prpl_utils-0.0.1/src/prpl_utils/py.typed +0 -0
- prpl_utils-0.0.1/src/prpl_utils/search.py +354 -0
- prpl_utils-0.0.1/src/prpl_utils/spaces.py +71 -0
- prpl_utils-0.0.1/src/prpl_utils/structs.py +8 -0
- prpl_utils-0.0.1/src/prpl_utils/utils.py +208 -0
- prpl_utils-0.0.1/src/prpl_utils.egg-info/PKG-INFO +49 -0
- prpl_utils-0.0.1/src/prpl_utils.egg-info/SOURCES.txt +24 -0
- prpl_utils-0.0.1/src/prpl_utils.egg-info/dependency_links.txt +1 -0
- prpl_utils-0.0.1/src/prpl_utils.egg-info/requires.txt +15 -0
- prpl_utils-0.0.1/src/prpl_utils.egg-info/top_level.txt +1 -0
- prpl_utils-0.0.1/tests/test_gym_utils.py +60 -0
- prpl_utils-0.0.1/tests/test_motion_planning.py +30 -0
- prpl_utils-0.0.1/tests/test_pddl_planning.py +69 -0
- prpl_utils-0.0.1/tests/test_search.py +607 -0
- prpl_utils-0.0.1/tests/test_spaces.py +31 -0
- prpl_utils-0.0.1/tests/test_utils.py +108 -0
prpl_utils-0.0.1/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2024 Tom Silver
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: prpl_utils
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: Common Python utilities for the Princeton Robot Planning and Learning lab.
|
|
5
|
+
Requires-Python: >=3.10
|
|
6
|
+
Description-Content-Type: text/markdown
|
|
7
|
+
License-File: LICENSE
|
|
8
|
+
Requires-Dist: matplotlib
|
|
9
|
+
Requires-Dist: numpy
|
|
10
|
+
Requires-Dist: graphviz
|
|
11
|
+
Requires-Dist: pyperplan
|
|
12
|
+
Requires-Dist: gymnasium
|
|
13
|
+
Requires-Dist: moviepy
|
|
14
|
+
Provides-Extra: develop
|
|
15
|
+
Requires-Dist: black; extra == "develop"
|
|
16
|
+
Requires-Dist: docformatter; extra == "develop"
|
|
17
|
+
Requires-Dist: isort; extra == "develop"
|
|
18
|
+
Requires-Dist: mypy; extra == "develop"
|
|
19
|
+
Requires-Dist: pylint>=2.14.5; extra == "develop"
|
|
20
|
+
Requires-Dist: pytest-pylint>=0.18.0; extra == "develop"
|
|
21
|
+
Requires-Dist: pytest>=7.2.2; extra == "develop"
|
|
22
|
+
Dynamic: license-file
|
|
23
|
+
|
|
24
|
+
# PRPL Utils
|
|
25
|
+
|
|
26
|
+
Miscellaneous Python utilities from the Princeton Robot Planning and Learning group.
|
|
27
|
+
- **Motion planning**: RRT, BiRRT
|
|
28
|
+
- **PDDL planning**: interfaces to pyperplan, Fast Downward
|
|
29
|
+
- **Heuristic search**: A*, GBFS, hill-climbing, policy-guided A*
|
|
30
|
+
- **Gymnasium**: agent interface, helper spaces
|
|
31
|
+
- **Other**: a few other miscellaneous utils
|
|
32
|
+
|
|
33
|
+
## Requirements
|
|
34
|
+
|
|
35
|
+
- Python 3.10+
|
|
36
|
+
- Tested on MacOS Monterey and Ubuntu 22.04
|
|
37
|
+
|
|
38
|
+
## Installation
|
|
39
|
+
|
|
40
|
+
1. Recommended: create and source a virtualenv.
|
|
41
|
+
2. `pip install -e ".[develop]"`
|
|
42
|
+
|
|
43
|
+
## Check Installation
|
|
44
|
+
|
|
45
|
+
Run `./run_ci_checks.sh`. It should complete with all green successes in 5-10 seconds.
|
|
46
|
+
|
|
47
|
+
## Contributing
|
|
48
|
+
|
|
49
|
+
Pull requests welcome. Note that this is meant to be a **lightweight** package. For example, it should be safe to use in homework assignments with minimal assumptions about the user's setup. Do not add heavy-duty dependencies.
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# PRPL Utils
|
|
2
|
+
|
|
3
|
+
Miscellaneous Python utilities from the Princeton Robot Planning and Learning group.
|
|
4
|
+
- **Motion planning**: RRT, BiRRT
|
|
5
|
+
- **PDDL planning**: interfaces to pyperplan, Fast Downward
|
|
6
|
+
- **Heuristic search**: A*, GBFS, hill-climbing, policy-guided A*
|
|
7
|
+
- **Gymnasium**: agent interface, helper spaces
|
|
8
|
+
- **Other**: a few other miscellaneous utils
|
|
9
|
+
|
|
10
|
+
## Requirements
|
|
11
|
+
|
|
12
|
+
- Python 3.10+
|
|
13
|
+
- Tested on macOS Monterey and Ubuntu 22.04
|
|
14
|
+
|
|
15
|
+
## Installation
|
|
16
|
+
|
|
17
|
+
1. Recommended: create and source a virtualenv.
|
|
18
|
+
2. `pip install -e ".[develop]"`
|
|
19
|
+
|
|
20
|
+
## Check Installation
|
|
21
|
+
|
|
22
|
+
Run `./run_ci_checks.sh`. It should complete with all green successes in 5-10 seconds.
|
|
23
|
+
|
|
24
|
+
## Contributing
|
|
25
|
+
|
|
26
|
+
Pull requests are welcome. Note that this is meant to be a **lightweight** package. For example, it should be safe to use in homework assignments with minimal assumptions about the user's setup. Do not add heavy-duty dependencies.
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "prpl_utils"
|
|
7
|
+
version = "0.0.1"
|
|
8
|
+
description = "Common Python utilities for the Princeton Robot Planning and Learning lab."
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.10"
|
|
11
|
+
dependencies = [
|
|
12
|
+
"matplotlib",
|
|
13
|
+
"numpy",
|
|
14
|
+
"graphviz",
|
|
15
|
+
"pyperplan",
|
|
16
|
+
"gymnasium",
|
|
17
|
+
"moviepy",
|
|
18
|
+
]
|
|
19
|
+
|
|
20
|
+
[project.optional-dependencies]
|
|
21
|
+
develop = [
|
|
22
|
+
"black",
|
|
23
|
+
"docformatter",
|
|
24
|
+
"isort",
|
|
25
|
+
"mypy",
|
|
26
|
+
"pylint>=2.14.5",
|
|
27
|
+
"pytest-pylint>=0.18.0",
|
|
28
|
+
"pytest>=7.2.2",
|
|
29
|
+
]
|
|
30
|
+
|
|
31
|
+
[tool.setuptools.packages.find]
|
|
32
|
+
where = ["src"]
|
|
33
|
+
|
|
34
|
+
[tool.setuptools.package-data]
|
|
35
|
+
prpl_utils = ["py.typed"]
|
|
36
|
+
|
|
37
|
+
[tool.black]
|
|
38
|
+
line-length = 88
|
|
39
|
+
target-version = ["py310"]
|
|
40
|
+
|
|
41
|
+
[tool.isort]
|
|
42
|
+
py_version = 310
|
|
43
|
+
profile = "black"
|
|
44
|
+
multi_line_output = 2
|
|
45
|
+
skip_glob = ["venv/*", ".venv/*", "build", "dist", "third_party", "third-party", "__pycache__/"]
|
|
46
|
+
split_on_trailing_comma = true
|
|
47
|
+
|
|
48
|
+
[tool.docformatter]
|
|
49
|
+
line-length = 88
|
|
50
|
+
wrap-summaries = 88
|
|
51
|
+
wrap-descriptions = 88
|
|
52
|
+
|
|
53
|
+
[tool.mypy]
|
|
54
|
+
strict_equality = true
|
|
55
|
+
disallow_untyped_calls = true
|
|
56
|
+
warn_unreachable = true
|
|
57
|
+
exclude = ["venv/*", ".venv/*", "build", "dist", "third_party", "third-party", "__pycache__/"]
|
|
58
|
+
|
|
59
|
+
[[tool.mypy.overrides]]
|
|
60
|
+
module = [
|
|
61
|
+
"matplotlib.*",
|
|
62
|
+
"pyperplan.*",
|
|
63
|
+
"graphviz.*",
|
|
64
|
+
]
|
|
65
|
+
ignore_missing_imports = true
|
|
File without changes
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
"""A basic agent interface compatible with gym envs."""
|
|
2
|
+
|
|
3
|
+
import abc
|
|
4
|
+
from typing import Any, Generic, TypeVar
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
_ObsType = TypeVar("_ObsType")
_ActType = TypeVar("_ActType")


class Agent(abc.ABC, Generic[_ObsType, _ActType]):
    """Base class for a sequential decision-making agent.

    Subclasses must implement `_get_action()`. Inheriting from `abc.ABC` makes
    the `@abc.abstractmethod` marker enforced, so a subclass that forgets it
    cannot be instantiated. Subclasses may also override
    `_learn_from_transition()`, which `update()` calls only in train mode.

    Intended episode loop: `reset(obs, info)`, then alternate `step()` (returns
    the action to execute) and `update(...)` (records the resulting outcome).
    """

    def __init__(self, seed: int) -> None:
        self._rng = np.random.default_rng(seed)
        self._last_observation: _ObsType | None = None
        self._last_action: _ActType | None = None
        self._last_info: dict[str, Any] | None = None
        self._timestep: int = 0
        # Agents start in eval mode; call train() to enable learning in update().
        self._train_or_eval = "eval"

    @abc.abstractmethod
    def _get_action(self) -> _ActType:
        """Produce an action to execute now."""

    def _learn_from_transition(
        self,
        obs: _ObsType,
        act: _ActType,
        next_obs: _ObsType,
        reward: float,
        done: bool,
        info: dict[str, Any],
    ) -> None:
        """Update any internal models based on the observed transition.

        Called by `update()` only in train mode. The default does nothing.
        """

    def reset(
        self,
        obs: _ObsType,
        info: dict[str, Any],
    ) -> None:
        """Start a new episode from its initial observation and info.

        The previous episode's last action is cleared so that `update()` can
        never pair it with an observation from this episode.
        """
        self._last_observation = obs
        self._last_info = info
        self._last_action = None
        self._timestep = 0

    def step(self) -> _ActType:
        """Get the next action to take and advance the timestep counter."""
        self._last_action = self._get_action()
        self._timestep += 1
        return self._last_action

    def update(
        self, obs: _ObsType, reward: float, done: bool, info: dict[str, Any]
    ) -> None:
        """Record the reward and next observation following an action.

        Requires a prior `reset()` and a `step()` in the current episode. In
        train mode the full transition is passed to `_learn_from_transition()`.
        """
        assert self._last_observation is not None
        assert self._last_action is not None
        if self._train_or_eval == "train":
            self._learn_from_transition(
                self._last_observation, self._last_action, obs, reward, done, info
            )
        self._last_observation = obs
        self._last_info = info

    def seed(self, seed: int) -> None:
        """Reset the random number generator."""
        self._rng = np.random.default_rng(seed)

    def train(self) -> None:
        """Switch to train mode."""
        self._train_or_eval = "train"

    def eval(self) -> None:
        """Switch to eval mode."""
        self._train_or_eval = "eval"
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
"""Utilities for Gym/Gymnasium environment compatibility and wrappers."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any, Callable
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from gymnasium import spaces
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class GymToGymnasium:
    """Wrap a legacy Gym environment to present the Gymnasium API.

    - Maps `reset()` → `(obs, info)`; seeds via `env.seed(seed)` if available.
    - Maps `step()` → `(obs, reward, terminated, truncated, info)`. For legacy
      4-tuples, `truncated` is read from `info.get('TimeLimit.truncated', False)`.
      A `done` caused by such a time limit is reported as truncated, not
      terminated, which matches Gymnasium's own step-API converter. 5-tuples
      are passed through with reward coerced to float and flags to bool.
    - Forwards `observation_space`, `action_space`, `spec`, `render()`, `close()`.
    """

    def __init__(self, base_env: Any) -> None:
        """Wrap a legacy gym environment."""
        self._env = base_env
        self.observation_space: spaces.Space | None = getattr(
            base_env, "observation_space", None
        )
        self.action_space: spaces.Space | None = getattr(base_env, "action_space", None)
        self.spec = getattr(base_env, "spec", None)

    @property
    def unwrapped(self) -> Any:
        """Get the underlying unwrapped environment.

        Falls back to the wrapped env itself when it has no `unwrapped`.
        """
        return getattr(self._env, "unwrapped", self._env)

    def reset(self, *, seed: int | None = None) -> tuple[Any, dict[str, Any]]:
        """Reset environment and return (obs, info).

        If `seed` is given and the wrapped env has a callable `seed()`, that is
        called first. AttributeError/TypeError raised by it are ignored
        (best-effort seeding).
        """
        if seed is not None:
            if hasattr(self._env, "seed") and callable(self._env.seed):
                try:
                    self._env.seed(seed)
                except (AttributeError, TypeError):
                    # Legacy env has a broken seed implementation; continue
                    # without seeding.
                    pass
        res = self._env.reset()
        # NOTE(review): a legacy env whose observation is itself a 2-tuple
        # would be misread as (obs, info) here.
        if isinstance(res, tuple) and len(res) == 2:
            return res  # already (obs, info)
        return res, {}

    def step(self, action: Any) -> tuple[Any, float, bool, bool, dict[str, Any]]:
        """Step environment and return (obs, reward, terminated, truncated, info).

        Raises:
            ValueError: if the wrapped env's `step()` does not return a 4- or
                5-tuple.
        """
        result = self._env.step(action)
        if isinstance(result, tuple):
            if len(result) == 4:
                obs, reward, done, info = result
                truncated = bool(info.get("TimeLimit.truncated", False))
                # Legacy gym sets done=True when a TimeLimit expires. That is a
                # truncation, not a true terminal state, so it must not be
                # reported as terminated.
                terminated = bool(done) and not truncated
                return obs, float(reward), terminated, truncated, info
            if len(result) == 5:
                obs, reward, terminated, truncated, info = result
                return obs, float(reward), bool(terminated), bool(truncated), info
        raise ValueError("Unexpected number of values returned from env.step")

    def render(self) -> Any:
        """Render the environment.

        Retries with `mode="human"` for legacy envs whose `render()` requires
        a mode argument. Returns None if the env has no `render`.
        """
        if hasattr(self._env, "render"):
            try:
                return self._env.render()
            except TypeError:
                return self._env.render(mode="human")
        return None

    def close(self) -> None:
        """Close the environment (no-op if it has no `close`)."""
        getattr(self._env, "close", lambda: None)()
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def patch_box_float32() -> Callable[..., None]:
    """Patch ``spaces.Box.__init__`` so Box spaces are created as float32.

    Environment libraries such as KinDER create Box spaces with float64
    bounds, which triggers Gymnasium's "precision lowered" warnings. After
    this patch:

    - float64 ``low``/``high`` arrays are cast to float32, and
    - a float64 ``dtype`` request is replaced by float32 before the original
      constructor runs. Any spelling accepted by ``np.dtype`` counts here,
      e.g. ``np.float64``, ``np.dtype("float64")`` or ``"float64"``.

    The patch is global: it replaces the method on the ``Box`` class itself,
    so it affects every Box created afterwards in this process. Assign the
    returned function back to ``spaces.Box.__init__`` to undo it.

    Returns:
        The original Box.__init__ method for restoration.
    """

    original_box_init = spaces.Box.__init__

    def _is_float64(dtype: Any) -> bool:
        # np.dtype(None) is float64, so None must be excluded explicitly
        # rather than being silently turned into float32.
        if dtype is None:
            return False
        try:
            return np.dtype(dtype) == np.float64
        except (TypeError, ValueError):
            # Leave invalid dtypes untouched so Box reports them itself.
            return False

    def patched_box_init(self, low, high, shape=None, dtype=np.float32, seed=None):
        # Convert bounds to float32 if they're float64 to avoid warnings
        if hasattr(low, "dtype") and low.dtype == np.float64:
            low = low.astype(np.float32)
        if hasattr(high, "dtype") and high.dtype == np.float64:
            high = high.astype(np.float32)

        # Force float64 dtype requests (in any spelling) to float32
        if _is_float64(dtype):
            dtype = np.float32

        return original_box_init(self, low, high, shape, dtype, seed)

    spaces.Box.__init__ = patched_box_init  # type: ignore
    return original_box_init
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
# Names exported by a wildcard import of this module.
__all__ = ["GymToGymnasium", "patch_box_float32"]
|
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
"""Motion planning utilities."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import functools
|
|
6
|
+
from typing import Callable, Generic, Iterable, Optional, TypeVar
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
_RRTState = TypeVar("_RRTState")


class RRT(Generic[_RRTState]):
    """Rapidly-exploring random tree.

    States are opaque. The planner interacts with them only through the
    user-supplied callables:

    - `sample_fn(pt)`: draws a random state.
    - `extend_fn(a, b)`: yields intermediate states from `a` toward `b`. The
      planner treats `[a] + list(extend_fn(a, b))` as a path ending at `b`.
    - `collision_fn(s)`: returns True if `s` is in collision.
    - `distance_fn(a, b)`: used to find nearest tree nodes.
    """

    def __init__(
        self,
        sample_fn: Callable[[_RRTState], _RRTState],
        extend_fn: Callable[[_RRTState, _RRTState], Iterable[_RRTState]],
        collision_fn: Callable[[_RRTState], bool],
        distance_fn: Callable[[_RRTState, _RRTState], float],
        rng: np.random.Generator,
        num_attempts: int,
        num_iters: int,
        smooth_amt: int,
    ):
        self._sample_fn = sample_fn
        self._extend_fn = extend_fn
        self._collision_fn = collision_fn
        self._distance_fn = distance_fn
        self._rng = rng
        self._num_attempts = num_attempts
        self._num_iters = num_iters
        self._smooth_amt = smooth_amt

    def query(
        self, pt1: _RRTState, pt2: _RRTState, sample_goal_eps: float = 0.0
    ) -> Optional[list[_RRTState]]:
        """Query the RRT, to get a collision-free path from pt1 to pt2.

        A direct (unsmoothed) path is returned if one exists. Otherwise up to
        `num_attempts` trees are grown and the first path found is smoothed.
        If none is found, returns None.
        """
        if self._collision_fn(pt1) or self._collision_fn(pt2):
            return None
        direct_path = self.try_direct_path(pt1, pt2)
        if direct_path is not None:
            return direct_path
        for _ in range(self._num_attempts):
            path = self._rrt_connect(
                pt1, goal_sampler=lambda: pt2, sample_goal_eps=sample_goal_eps
            )
            if path is not None:
                return self._smooth_path(path)
        return None

    def query_to_goal_fn(
        self,
        start: _RRTState,
        goal_fn: Callable[[_RRTState], bool],
        goal_sampler: Callable[[], _RRTState] | None = None,
        sample_goal_eps: float = 0.0,
    ) -> Optional[list[_RRTState]]:
        """Query the RRT, to get a collision-free path from start to a point such that
        goal_fn(point) is True. Uses goal_sampler to sample a target for a direct path
        or with probability sample_goal_eps.

        If none is found, returns None.
        """
        assert sample_goal_eps == 0.0 or goal_sampler is not None
        if self._collision_fn(start):
            return None
        if goal_sampler is not None:
            direct_path = self.try_direct_path(start, goal_sampler())
            if direct_path is not None:
                return direct_path
        for _ in range(self._num_attempts):
            path = self._rrt_connect(
                start, goal_sampler, goal_fn, sample_goal_eps=sample_goal_eps
            )
            if path is not None:
                return self._smooth_path(path)
        return None

    def try_direct_path(
        self, pt1: _RRTState, pt2: _RRTState
    ) -> Optional[list[_RRTState]]:
        """Attempt to plan a direct path from pt1 to pt2, returning None if no
        collision-free path can be found."""
        path = [pt1]
        for newpt in self._extend_fn(pt1, pt2):
            if self._collision_fn(newpt):
                return None
            path.append(newpt)
        return path

    def _rrt_connect(
        self,
        pt1: _RRTState,
        goal_sampler: Callable[[], _RRTState] | None = None,
        goal_fn: Callable[[_RRTState], bool] | None = None,
        sample_goal_eps: float = 0.0,
    ) -> Optional[list[_RRTState]]:
        """Grow a single tree from pt1 for up to `num_iters` iterations.

        Returns the root-to-node path as soon as either (a) an extension toward
        a goal sample completes without collision, or (b) `goal_fn` holds for
        the last node reached by an extension. Returns None otherwise.
        """
        root = _RRTNode(pt1)
        nodes = [root]

        for _ in range(self._num_iters):
            # Sample the goal with a small probability, otherwise randomly
            # choose a point.
            sample_goal = self._rng.random() < sample_goal_eps
            if sample_goal:
                assert goal_sampler is not None
                samp = goal_sampler()
            else:
                samp = self._sample_fn(pt1)
            min_key = functools.partial(self._get_pt_dist_to_node, samp)
            nearest = min(nodes, key=min_key)
            reached_goal = False
            for newpt in self._extend_fn(nearest.data, samp):
                if self._collision_fn(newpt):
                    break
                nearest = _RRTNode(newpt, parent=nearest)
                nodes.append(nearest)
            else:
                # The extension finished without collision; that only counts
                # as reaching the goal if the target was a goal sample.
                reached_goal = sample_goal
            # Check goal_fn if defined
            if reached_goal or goal_fn is not None and goal_fn(nearest.data):
                path = nearest.path_from_root()
                return [node.data for node in path]
        return None

    def _get_pt_dist_to_node(self, pt: _RRTState, node: _RRTNode[_RRTState]) -> float:
        return self._distance_fn(pt, node.data)

    def _smooth_path(self, path: list[_RRTState]) -> list[_RRTState]:
        """Shortcut-smooth `path` with `smooth_amt` random attempts.

        Each attempt picks waypoint indices i < j at least two apart. It
        replaces path[i+1..j] with `extend_fn(path[i], path[j])` when that
        shortcut has fewer points and is collision-free. This relies on
        extend_fn's output ending at its target, as `try_direct_path` does.
        Paths with at most two waypoints cannot be shortened and are returned
        unchanged.
        """
        if len(path) <= 2:
            return path
        for _ in range(self._smooth_amt):
            # Generator.integers excludes its upper bound; use len(path) so the
            # final waypoint (the goal) can also be a shortcut endpoint.
            i = self._rng.integers(0, len(path))
            j = self._rng.integers(0, len(path))
            if abs(i - j) <= 1:
                continue
            if j < i:
                i, j = j, i
            shortcut = list(self._extend_fn(path[i], path[j]))
            if len(shortcut) < j - i and all(
                not self._collision_fn(pt) for pt in shortcut
            ):
                path = path[: i + 1] + shortcut + path[j + 1 :]
        return path
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class BiRRT(RRT[_RRTState]):
    """Bidirectional rapidly-exploring random tree.

    Grows one tree rooted at the start and one rooted at a goal state, and
    returns a path as soon as the two trees connect.
    """

    def query_to_goal_fn(
        self,
        start: _RRTState,
        goal_fn: Callable[[_RRTState], bool],
        goal_sampler: Callable[[], _RRTState] | None = None,
        sample_goal_eps: float = 0.0,
    ) -> Optional[list[_RRTState]]:
        """Unsupported: the second tree must be rooted at a concrete goal state."""
        raise NotImplementedError("Can't query to goal function using BiRRT")

    def _rrt_connect(
        self,
        pt1: _RRTState,
        goal_sampler: Callable[[], _RRTState] | None = None,
        goal_fn: Callable[[_RRTState], bool] | None = None,
        sample_goal_eps: float = 0.0,
    ) -> Optional[list[_RRTState]]:
        """Grow trees from pt1 and from goal_sampler() until they meet.

        `goal_fn` and `sample_goal_eps` are accepted only for signature
        compatibility with `RRT._rrt_connect`; they are ignored here.
        """
        assert goal_sampler is not None
        goal = goal_sampler()
        start_root = _RRTNode(pt1)
        goal_root = _RRTNode(goal)
        tree_a: list[_RRTNode[_RRTState]] = [start_root]
        tree_b: list[_RRTNode[_RRTState]] = [goal_root]

        def grow(
            tree: list[_RRTNode[_RRTState]],
            frontier: _RRTNode[_RRTState],
            target: _RRTState,
        ) -> tuple[_RRTNode[_RRTState], bool]:
            """Extend from `frontier` toward `target`, appending each
            collision-free point to `tree`.

            Returns the last node reached and whether the extension finished
            without hitting a collision.
            """
            for point in self._extend_fn(frontier.data, target):
                if self._collision_fn(point):
                    return frontier, False
                frontier = _RRTNode(point, parent=frontier)
                tree.append(frontier)
            return frontier, True

        for _ in range(self._num_iters):
            # Always expand whichever tree is currently smaller.
            if len(tree_a) > len(tree_b):
                tree_a, tree_b = tree_b, tree_a
            sample = self._sample_fn(pt1)
            closest_a = min(
                tree_a, key=functools.partial(self._get_pt_dist_to_node, sample)
            )
            tip, _ = grow(tree_a, closest_a, sample)
            # Then try to connect the other tree to the new tip.
            closest_b = min(
                tree_b, key=functools.partial(self._get_pt_dist_to_node, tip.data)
            )
            meet, connected = grow(tree_b, closest_b, tip.data)
            if not connected:
                continue
            first_half = tip.path_from_root()
            second_half = meet.path_from_root()
            # This is a tricky case to cover.
            if first_half[0] is not start_root:  # pragma: no cover
                first_half, second_half = second_half, first_half
            assert first_half[0] is start_root
            # Both halves end at the meeting point (assuming extend_fn's output
            # ends at its target), so one copy of it is dropped.
            joined = first_half[:-1] + second_half[::-1]
            return [node.data for node in joined]
        return None
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
class _RRTNode(Generic[_RRTState]):
    """A tree node storing one state plus a link to its parent (None at a root)."""

    def __init__(
        self, data: _RRTState, parent: Optional[_RRTNode[_RRTState]] = None
    ) -> None:
        self.data = data
        self.parent = parent

    def path_from_root(self) -> list[_RRTNode[_RRTState]]:
        """Return the nodes from the tree root down to, and including, this node."""
        ancestry = [self]
        while (parent := ancestry[-1].parent) is not None:
            ancestry.append(parent)
        return ancestry[::-1]
|