prpl_utils 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Tom Silver
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,49 @@
1
+ Metadata-Version: 2.4
2
+ Name: prpl_utils
3
+ Version: 0.0.1
4
+ Summary: Common Python utilities for the Princeton Robot Planning and Learning lab.
5
+ Requires-Python: >=3.10
6
+ Description-Content-Type: text/markdown
7
+ License-File: LICENSE
8
+ Requires-Dist: matplotlib
9
+ Requires-Dist: numpy
10
+ Requires-Dist: graphviz
11
+ Requires-Dist: pyperplan
12
+ Requires-Dist: gymnasium
13
+ Requires-Dist: moviepy
14
+ Provides-Extra: develop
15
+ Requires-Dist: black; extra == "develop"
16
+ Requires-Dist: docformatter; extra == "develop"
17
+ Requires-Dist: isort; extra == "develop"
18
+ Requires-Dist: mypy; extra == "develop"
19
+ Requires-Dist: pylint>=2.14.5; extra == "develop"
20
+ Requires-Dist: pytest-pylint>=0.18.0; extra == "develop"
21
+ Requires-Dist: pytest>=7.2.2; extra == "develop"
22
+ Dynamic: license-file
23
+
24
+ # PRPL Utils
25
+
26
+ Miscellaneous Python utilities from the Princeton Robot Planning and Learning group.
27
+ - **Motion planning**: RRT, BiRRT
28
+ - **PDDL planning**: interfaces to pyperplan, Fast Downward
29
+ - **Heuristic search**: A*, GBFS, hill-climbing, policy-guided A*
30
+ - **Gymnasium**: agent interface, helper spaces
31
+ - **Other**: a few other miscellaneous utils
32
+
33
+ ## Requirements
34
+
35
+ - Python 3.10+
36
+ - Tested on macOS Monterey and Ubuntu 22.04
37
+
38
+ ## Installation
39
+
40
+ 1. Recommended: create and source a virtualenv.
41
+ 2. `pip install -e ".[develop]"`
42
+
43
+ ## Check Installation
44
+
45
+ Run `./run_ci_checks.sh`. It should complete with all green successes in 5-10 seconds.
46
+
47
+ ## Contributing
48
+
49
+ Pull requests welcome. Note that this is meant to be a **lightweight** package. For example, it should be safe to use in homework assignments with minimal assumptions about the user's setup. Do not add heavy-duty dependencies.
@@ -0,0 +1,26 @@
1
+ # PRPL Utils
2
+
3
+ Miscellaneous Python utilities from the Princeton Robot Planning and Learning group.
4
+ - **Motion planning**: RRT, BiRRT
5
+ - **PDDL planning**: interfaces to pyperplan, Fast Downward
6
+ - **Heuristic search**: A*, GBFS, hill-climbing, policy-guided A*
7
+ - **Gymnasium**: agent interface, helper spaces
8
+ - **Other**: a few other miscellaneous utils
9
+
10
+ ## Requirements
11
+
12
+ - Python 3.10+
13
+ - Tested on macOS Monterey and Ubuntu 22.04
14
+
15
+ ## Installation
16
+
17
+ 1. Recommended: create and source a virtualenv.
18
+ 2. `pip install -e ".[develop]"`
19
+
20
+ ## Check Installation
21
+
22
+ Run `./run_ci_checks.sh`. It should complete with all green successes in 5-10 seconds.
23
+
24
+ ## Contributing
25
+
26
+ Pull requests welcome. Note that this is meant to be a **lightweight** package. For example, it should be safe to use in homework assignments with minimal assumptions about the user's setup. Do not add heavy-duty dependencies.
@@ -0,0 +1,65 @@
1
+ [build-system]
2
+ requires = ["setuptools"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "prpl_utils"
7
+ version = "0.0.1"
8
+ description = "Common Python utilities for the Princeton Robot Planning and Learning lab."
9
+ readme = "README.md"
10
+ requires-python = ">=3.10"
11
+ dependencies = [
12
+ "matplotlib",
13
+ "numpy",
14
+ "graphviz",
15
+ "pyperplan",
16
+ "gymnasium",
17
+ "moviepy",
18
+ ]
19
+
20
+ [project.optional-dependencies]
21
+ develop = [
22
+ "black",
23
+ "docformatter",
24
+ "isort",
25
+ "mypy",
26
+ "pylint>=2.14.5",
27
+ "pytest-pylint>=0.18.0",
28
+ "pytest>=7.2.2",
29
+ ]
30
+
31
+ [tool.setuptools.packages.find]
32
+ where = ["src"]
33
+
34
+ [tool.setuptools.package-data]
35
+ prpl_utils = ["py.typed"]
36
+
37
+ [tool.black]
38
+ line-length = 88
39
+ target-version = ["py310"]
40
+
41
+ [tool.isort]
42
+ py_version = 310
43
+ profile = "black"
44
+ multi_line_output = 2
45
+ skip_glob = ["venv/*", ".venv/*", "build", "dist", "third_party", "third-party", "__pycache__/"]
46
+ split_on_trailing_comma = true
47
+
48
+ [tool.docformatter]
49
+ line-length = 88
50
+ wrap-summaries = 88
51
+ wrap-descriptions = 88
52
+
53
+ [tool.mypy]
54
+ strict_equality = true
55
+ disallow_untyped_calls = true
56
+ warn_unreachable = true
57
+ exclude = ["venv/*", ".venv/*", "build", "dist", "third_party", "third-party", "__pycache__/"]
58
+
59
+ [[tool.mypy.overrides]]
60
+ module = [
61
+ "matplotlib.*",
62
+ "pyperplan.*",
63
+ "graphviz.*",
64
+ ]
65
+ ignore_missing_imports = true
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
File without changes
@@ -0,0 +1,77 @@
1
+ """A basic agent interface compatible with gym envs."""
2
+
3
+ import abc
4
+ from typing import Any, Generic, TypeVar
5
+
6
+ import numpy as np
7
+
8
_ObsType = TypeVar("_ObsType")
_ActType = TypeVar("_ActType")


class Agent(Generic[_ObsType, _ActType]):
    """Base class for a sequential decision-making agent.

    Subclasses implement `_get_action` and may override
    `_learn_from_transition` to learn when in train mode. The expected usage
    per episode is `reset(obs, info)` followed by repeated `step()` /
    `update(...)` pairs.
    """

    def __init__(self, seed: int) -> None:
        # RNG available to stochastic subclasses; reseedable via seed().
        self._rng = np.random.default_rng(seed)
        # Most recent observation/action/info, used to form transitions.
        self._last_observation: _ObsType | None = None
        self._last_action: _ActType | None = None
        self._last_info: dict[str, Any] | None = None
        # Number of actions taken in the current episode.
        self._timestep: int = 0
        # "train" enables learning inside update(); default is eval-only.
        self._train_or_eval = "eval"

    @abc.abstractmethod
    def _get_action(self) -> _ActType:
        """Produce an action to execute now."""

    def _learn_from_transition(
        self,
        obs: _ObsType,
        act: _ActType,
        next_obs: _ObsType,
        reward: float,
        done: bool,
        info: dict[str, Any],
    ) -> None:
        """Update any internal models based on the observed transition.

        No-op by default; learning subclasses override this.
        """

    def reset(
        self,
        obs: _ObsType,
        info: dict[str, Any],
    ) -> None:
        """Start a new episode."""
        self._last_observation = obs
        self._last_info = info
        # Bug fix: clear any stale action from a previous episode so that
        # update() cannot silently pair it with this episode's observations.
        self._last_action = None
        self._timestep = 0

    def step(self) -> _ActType:
        """Get the next action to take."""
        self._last_action = self._get_action()
        self._timestep += 1
        return self._last_action

    def update(
        self, obs: _ObsType, reward: float, done: bool, info: dict[str, Any]
    ) -> None:
        """Record the reward and next observation following an action."""
        assert self._last_observation is not None
        assert self._last_action is not None
        if self._train_or_eval == "train":
            self._learn_from_transition(
                self._last_observation, self._last_action, obs, reward, done, info
            )
        self._last_observation = obs
        self._last_info = info

    def seed(self, seed: int) -> None:
        """Reset the random number generator."""
        self._rng = np.random.default_rng(seed)

    def train(self) -> None:
        """Switch to train mode."""
        self._train_or_eval = "train"

    def eval(self) -> None:
        """Switch to eval mode."""
        self._train_or_eval = "eval"
@@ -0,0 +1,105 @@
1
+ """Utilities for Gym/Gymnasium environment compatibility and wrappers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Callable
6
+
7
+ import numpy as np
8
+ from gymnasium import spaces
9
+
10
+
11
+ class GymToGymnasium:
12
+ """Wrap a legacy Gym environment to present the Gymnasium API.
13
+
14
+ - Maps `reset()` → `(obs, info)`; seeds via `env.seed(seed)` if available.
15
+ - Maps `step()` → `(obs, reward, terminated, truncated, info)`; derives
16
+ `truncated` from `info.get('TimeLimit.truncated', False)`.
17
+ - Forwards `observation_space`, `action_space`, `spec`, `render()`, `close()`.
18
+ """
19
+
20
+ def __init__(self, base_env: Any) -> None:
21
+ """Wrap a legacy gym environment."""
22
+ self._env = base_env
23
+ self.observation_space: spaces.Space | None = getattr(
24
+ base_env, "observation_space", None
25
+ )
26
+ self.action_space: spaces.Space | None = getattr(base_env, "action_space", None)
27
+ self.spec = getattr(base_env, "spec", None)
28
+
29
+ @property
30
+ def unwrapped(self) -> Any:
31
+ """Get the underlying unwrapped environment."""
32
+ return getattr(self._env, "unwrapped", self._env)
33
+
34
+ def reset(self, *, seed: int | None = None) -> tuple[Any, dict[str, Any]]:
35
+ """Reset environment and return (obs, info)."""
36
+ if seed is not None:
37
+ if hasattr(self._env, "seed") and callable(self._env.seed):
38
+ try:
39
+ self._env.seed(seed)
40
+ except (AttributeError, TypeError):
41
+ # Legacy env has broken seed implementation, continue without seeding
42
+ _ = None
43
+ res = self._env.reset()
44
+ if isinstance(res, tuple) and len(res) == 2:
45
+ return res # already (obs, info)
46
+ return res, {}
47
+
48
+ def step(self, action: Any) -> tuple[Any, float, bool, bool, dict[str, Any]]:
49
+ """Step environment and return (obs, reward, terminated, truncated, info)."""
50
+ result = self._env.step(action)
51
+ if isinstance(result, tuple):
52
+ if len(result) == 4:
53
+ obs, reward, done, info = result
54
+ terminated = bool(done)
55
+ truncated = bool(info.get("TimeLimit.truncated", False))
56
+ return obs, float(reward), terminated, truncated, info
57
+ if len(result) == 5:
58
+ obs, reward, terminated, truncated, info = result
59
+ return obs, float(reward), bool(terminated), bool(truncated), info
60
+ raise ValueError("Unexpected number of values returned from env.step")
61
+
62
+ def render(self) -> Any:
63
+ """Render the environment."""
64
+ if hasattr(self._env, "render"):
65
+ try:
66
+ return self._env.render()
67
+ except TypeError:
68
+ return self._env.render(mode="human")
69
+ return None
70
+
71
+ def close(self) -> None:
72
+ """Close the environment."""
73
+ getattr(self._env, "close", lambda: None)()
74
+
75
+
76
def patch_box_float32() -> Callable[..., None]:
    """Monkey-patch ``gymnasium.spaces.Box.__init__`` to default to float32.

    Avoids Gymnasium's "precision lowered" warnings by building Box spaces
    with float32 bounds and dtype up front, instead of casting at runtime.
    (Some environment libraries, e.g. KinDER, create Box spaces with float64
    bounds.)

    Returns:
        The original ``Box.__init__`` so callers can restore it later.
    """

    saved_init = spaces.Box.__init__

    def _float32_box_init(self, low, high, shape=None, dtype=np.float32, seed=None):
        # Downcast float64 array bounds so Box does not warn about precision.
        if getattr(low, "dtype", None) == np.float64:
            low = low.astype(np.float32)
        if getattr(high, "dtype", None) == np.float64:
            high = high.astype(np.float32)

        # Likewise force an explicitly-requested float64 dtype down to float32.
        if dtype == np.float64:
            dtype = np.float32

        return saved_init(self, low, high, shape, dtype, seed)

    spaces.Box.__init__ = _float32_box_init  # type: ignore
    return saved_init
103
+
104
+
105
# Explicit public API for `from <module> import *`.
__all__ = ["GymToGymnasium", "patch_box_float32"]
@@ -0,0 +1,222 @@
1
+ """Motion planning utilities."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import functools
6
+ from typing import Callable, Generic, Iterable, Optional, TypeVar
7
+
8
+ import numpy as np
9
+
10
+ _RRTState = TypeVar("_RRTState")
11
+
12
+
13
class RRT(Generic[_RRTState]):
    """Rapidly-exploring random tree.

    Generic over the state type: the caller supplies sampling, extension,
    collision, and distance functions, so no assumptions are made about the
    underlying space.
    """

    def __init__(
        self,
        sample_fn: Callable[[_RRTState], _RRTState],
        extend_fn: Callable[[_RRTState, _RRTState], Iterable[_RRTState]],
        collision_fn: Callable[[_RRTState], bool],
        distance_fn: Callable[[_RRTState, _RRTState], float],
        rng: np.random.Generator,
        num_attempts: int,
        num_iters: int,
        smooth_amt: int,
    ):
        # sample_fn: given a reference state, sample a new candidate state.
        self._sample_fn = sample_fn
        # extend_fn: yield intermediate states between two states.
        self._extend_fn = extend_fn
        # collision_fn: True if a state is in collision (invalid).
        self._collision_fn = collision_fn
        # distance_fn: metric used for nearest-node lookups.
        self._distance_fn = distance_fn
        self._rng = rng
        # Number of independent tree-growing attempts per query.
        self._num_attempts = num_attempts
        # Number of extension iterations per attempt.
        self._num_iters = num_iters
        # Number of random shortcut attempts when smoothing a found path.
        self._smooth_amt = smooth_amt

    def query(
        self, pt1: _RRTState, pt2: _RRTState, sample_goal_eps: float = 0.0
    ) -> Optional[list[_RRTState]]:
        """Query the RRT, to get a collision-free path from pt1 to pt2.

        If none is found, returns None.
        """
        # Fail fast if either endpoint is already in collision.
        if self._collision_fn(pt1) or self._collision_fn(pt2):
            return None
        # Cheap first try: connect the endpoints directly.
        direct_path = self.try_direct_path(pt1, pt2)
        if direct_path is not None:
            return direct_path
        for _ in range(self._num_attempts):
            path = self._rrt_connect(
                pt1, goal_sampler=lambda: pt2, sample_goal_eps=sample_goal_eps
            )
            if path is not None:
                return self._smooth_path(path)
        return None

    def query_to_goal_fn(
        self,
        start: _RRTState,
        goal_fn: Callable[[_RRTState], bool],
        goal_sampler: Callable[[], _RRTState] | None = None,
        sample_goal_eps: float = 0.0,
    ) -> Optional[list[_RRTState]]:
        """Query the RRT, to get a collision-free path from start to a point such that
        goal_fn(point) is True. Uses goal_sampler to sample a target for a direct path
        or with probability sample_goal_eps.

        If none is found, returns None.
        """
        # Sampling the goal with nonzero probability requires a goal sampler.
        assert sample_goal_eps == 0.0 or goal_sampler is not None
        if self._collision_fn(start):
            return None
        if goal_sampler:
            # Cheap first try: direct path to one sampled goal state.
            direct_path = self.try_direct_path(start, goal_sampler())
            if direct_path is not None:
                return direct_path
        for _ in range(self._num_attempts):
            path = self._rrt_connect(
                start, goal_sampler, goal_fn, sample_goal_eps=sample_goal_eps
            )
            if path is not None:
                return self._smooth_path(path)
        return None

    def try_direct_path(
        self, pt1: _RRTState, pt2: _RRTState
    ) -> Optional[list[_RRTState]]:
        """Attempt to plan a direct path from pt1 to pt2, returning None if no
        collision-free direct path can be found."""
        path = [pt1]
        for newpt in self._extend_fn(pt1, pt2):
            if self._collision_fn(newpt):
                return None
            path.append(newpt)
        return path

    def _rrt_connect(
        self,
        pt1: _RRTState,
        goal_sampler: Callable[[], _RRTState] | None = None,
        goal_fn: Callable[[_RRTState], bool] | None = None,
        sample_goal_eps: float = 0.0,
    ) -> Optional[list[_RRTState]]:
        # Grow a single tree rooted at pt1 for up to num_iters extensions.
        root = _RRTNode(pt1)
        nodes = [root]

        for _ in range(self._num_iters):
            # Sample the goal with a small probability, otherwise randomly
            # choose a point.
            sample_goal = self._rng.random() < sample_goal_eps
            if sample_goal:
                assert goal_sampler is not None
                samp = goal_sampler()
            else:
                samp = self._sample_fn(pt1)
            min_key = functools.partial(self._get_pt_dist_to_node, samp)
            nearest = min(nodes, key=min_key)
            reached_goal = False
            # Extend from the nearest node toward the sample, stopping at the
            # first collision. The for/else runs the else branch only when no
            # collision interrupted the extension (i.e. the sample was reached).
            for newpt in self._extend_fn(nearest.data, samp):
                if self._collision_fn(newpt):
                    break
                nearest = _RRTNode(newpt, parent=nearest)
                nodes.append(nearest)
            else:
                reached_goal = sample_goal
            # Check goal_fn if defined
            if reached_goal or goal_fn is not None and goal_fn(nearest.data):
                path = nearest.path_from_root()
                return [node.data for node in path]
        return None

    def _get_pt_dist_to_node(self, pt: _RRTState, node: _RRTNode[_RRTState]) -> float:
        # Distance from a free point to the state stored at a tree node.
        return self._distance_fn(pt, node.data)

    def _smooth_path(self, path: list[_RRTState]) -> list[_RRTState]:
        """Shorten a path by repeatedly attempting random shortcuts."""
        # NOTE(review): assumes any path reaching here has length > 2; a goal
        # reached within the first extension step would violate this — confirm
        # upstream guarantees.
        assert len(path) > 2
        for _ in range(self._smooth_amt):
            # NOTE(review): rng.integers has an exclusive upper bound, so the
            # final waypoint is never chosen as a shortcut endpoint — confirm
            # this is intended.
            i = self._rng.integers(0, len(path) - 1)
            j = self._rng.integers(0, len(path) - 1)
            if abs(i - j) <= 1:
                continue
            if j < i:
                i, j = j, i
            shortcut = list(self._extend_fn(path[i], path[j]))
            # Accept only strictly shorter, collision-free shortcuts.
            if len(shortcut) < j - i and all(
                not self._collision_fn(pt) for pt in shortcut
            ):
                path = path[: i + 1] + shortcut + path[j + 1 :]
        return path
149
+
150
+
151
class BiRRT(RRT[_RRTState]):
    """Bidirectional rapidly-exploring random tree.

    Grows one tree from the start and one from a sampled goal and tries to
    connect them; inherits the sampling/extension/collision/distance functions
    from RRT.
    """

    def query_to_goal_fn(
        self,
        start: _RRTState,
        goal_fn: Callable[[_RRTState], bool],
        goal_sampler: Callable[[], _RRTState] | None = None,
        sample_goal_eps: float = 0.0,
    ) -> Optional[list[_RRTState]]:
        # The second tree needs a concrete goal state, so querying against an
        # implicit goal function is unsupported for BiRRT.
        raise NotImplementedError("Can't query to goal function using BiRRT")

    def _rrt_connect(
        self,
        pt1: _RRTState,
        goal_sampler: Callable[[], _RRTState] | None = None,
        goal_fn: Callable[[_RRTState], bool] | None = None,
        sample_goal_eps: float = 0.0,
    ) -> Optional[list[_RRTState]]:
        # goal_fn and sample_goal_eps are unused
        assert goal_sampler is not None
        pt2 = goal_sampler()
        # One tree rooted at the start, one at the sampled goal.
        root1, root2 = _RRTNode(pt1), _RRTNode(pt2)
        nodes1, nodes2 = [root1], [root2]

        for _ in range(self._num_iters):
            # Always grow the smaller tree first (classic balancing trick).
            if len(nodes1) > len(nodes2):
                nodes1, nodes2 = nodes2, nodes1
            # NOTE(review): sampling is always conditioned on pt1, even after
            # the trees are swapped — confirm sample_fn ignores its argument
            # or that this asymmetry is intended.
            samp = self._sample_fn(pt1)
            min_key1 = functools.partial(self._get_pt_dist_to_node, samp)
            nearest1 = min(nodes1, key=min_key1)
            # Extend tree 1 toward the sample, stopping at the first collision.
            for newpt in self._extend_fn(nearest1.data, samp):
                if self._collision_fn(newpt):
                    break
                nearest1 = _RRTNode(newpt, parent=nearest1)
                nodes1.append(nearest1)
            min_key2 = functools.partial(self._get_pt_dist_to_node, nearest1.data)
            nearest2 = min(nodes2, key=min_key2)
            # Extend tree 2 toward tree 1's frontier; the for/else runs the
            # else branch only when the extension completes without collision,
            # i.e. the trees have connected.
            for newpt in self._extend_fn(nearest2.data, nearest1.data):
                if self._collision_fn(newpt):
                    break
                nearest2 = _RRTNode(newpt, parent=nearest2)
                nodes2.append(nearest2)
            else:
                path1 = nearest1.path_from_root()
                path2 = nearest2.path_from_root()
                # Ensure path1 is the half rooted at the start tree.
                # This is a tricky case to cover.
                if path1[0] != root1:  # pragma: no cover
                    path1, path2 = path2, path1
                assert path1[0] == root1
                # Join the halves: drop path1's final node (the connection
                # point) and append the second half reversed (goal-rooted).
                path = path1[:-1] + path2[::-1]
                return [node.data for node in path]
        return None
204
+
205
+
206
class _RRTNode(Generic[_RRTState]):
    """Tree node storing one RRT state and a link to its parent."""

    def __init__(
        self, data: _RRTState, parent: Optional[_RRTNode[_RRTState]] = None
    ) -> None:
        # The stored state; `parent` is None only for the root node.
        self.data = data
        self.parent = parent

    def path_from_root(self) -> list[_RRTNode[_RRTState]]:
        """Return the path from the root to this node."""
        chain: list[_RRTNode[_RRTState]] = []
        current: Optional[_RRTNode[_RRTState]] = self
        # Walk parent links up to the root, then flip into root-first order.
        while current is not None:
            chain.append(current)
            current = current.parent
        chain.reverse()
        return chain