traceplane 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- traceplane/__init__.py +16 -0
- traceplane/dataset.py +93 -0
- traceplane/embeddings.py +207 -0
- traceplane/jax.py +108 -0
- traceplane/lerobot_reader.py +362 -0
- traceplane/query.py +643 -0
- traceplane/sim/__init__.py +31 -0
- traceplane/sim/_compat.py +16 -0
- traceplane/sim/cli.py +115 -0
- traceplane/sim/config.py +58 -0
- traceplane/sim/controller.py +77 -0
- traceplane/sim/evaluator.py +230 -0
- traceplane/sim/metrics.py +110 -0
- traceplane/sim/robot.py +123 -0
- traceplane/sim/scene.py +91 -0
- traceplane/sim/visualizer.py +188 -0
- traceplane/tf.py +86 -0
- traceplane/torch.py +126 -0
- traceplane/training/__init__.py +23 -0
- traceplane/training/cli.py +92 -0
- traceplane/training/config.py +82 -0
- traceplane/training/diffusion_policy.py +314 -0
- traceplane/training/eval.py +85 -0
- traceplane/training/normalization.py +262 -0
- traceplane/training/trainer.py +318 -0
- traceplane/windowing.py +112 -0
- traceplane-0.1.0.dist-info/METADATA +129 -0
- traceplane-0.1.0.dist-info/RECORD +30 -0
- traceplane-0.1.0.dist-info/WHEEL +4 -0
- traceplane-0.1.0.dist-info/entry_points.txt +5 -0
traceplane/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""Traceplane — streaming dataloader for robotics trajectory datasets."""
|
|
2
|
+
|
|
3
|
+
from traceplane.dataset import Episode, TrajectoryDataset
|
|
4
|
+
from traceplane.lerobot_reader import LeRobotReader
|
|
5
|
+
from traceplane.windowing import WindowedDataset
|
|
6
|
+
from traceplane.query import TraceplaneClient
|
|
7
|
+
|
|
8
|
+
__version__ = "0.1.0"
|
|
9
|
+
|
|
10
|
+
__all__ = [
|
|
11
|
+
"Episode",
|
|
12
|
+
"TrajectoryDataset",
|
|
13
|
+
"LeRobotReader",
|
|
14
|
+
"WindowedDataset",
|
|
15
|
+
"TraceplaneClient",
|
|
16
|
+
]
|
traceplane/dataset.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
"""Core dataset abstractions."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Any, Iterator, Sequence
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class Episode:
    """A single trajectory episode.

    Attributes:
        episode_id: Unique identifier (e.g. "episode_000042").
        observations: Dict of observation arrays keyed by modality.
            Common keys: "state" (proprioception), "action", image camera names.
            Each value is shape (T, D) for vectors or (T, H, W, C) for images.
        actions: Action array, shape (T, action_dim).
        timestamps: Monotonic timestamps in seconds, shape (T,).
        metadata: Arbitrary episode metadata (task label, fps, success, etc.).
    """

    episode_id: str
    observations: dict[str, np.ndarray] = field(default_factory=dict)
    actions: np.ndarray = field(default_factory=lambda: np.empty((0,)))
    timestamps: np.ndarray = field(default_factory=lambda: np.empty((0,)))
    metadata: dict[str, Any] = field(default_factory=dict)

    @property
    def length(self) -> int:
        """Number of timesteps.

        Inferred from ``actions``, then ``timestamps``, then the first
        non-empty observation array; 0 when nothing is populated.
        """
        if self.actions.ndim >= 1 and self.actions.shape[0] > 0:
            return self.actions.shape[0]
        if self.timestamps.ndim >= 1 and self.timestamps.shape[0] > 0:
            return self.timestamps.shape[0]
        for v in self.observations.values():
            # Guard the rank as well as the attribute: a 0-d array has
            # shape == (), and v.shape[0] would raise IndexError.
            if hasattr(v, "shape") and len(v.shape) >= 1 and v.shape[0] > 0:
                return v.shape[0]
        return 0

    @property
    def action_dim(self) -> int:
        """Action dimensionality, or 0 when actions are absent/non-2D."""
        if self.actions.ndim == 2:
            return self.actions.shape[1]
        return 0
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class TrajectoryDataset:
    """Abstract base class for trajectory datasets.

    Concrete datasets must provide ``__len__`` and ``__getitem__``;
    iteration and filtering are derived from those two.
    """

    def __len__(self) -> int:
        raise NotImplementedError

    def __getitem__(self, idx: int) -> Episode:
        raise NotImplementedError

    def __iter__(self) -> Iterator[Episode]:
        # Derived iteration: walk indices sequentially via __getitem__.
        idx = 0
        total = len(self)
        while idx < total:
            yield self[idx]
            idx += 1

    def episode_ids(self) -> list[str]:
        """Return all episode IDs in order."""
        raise NotImplementedError

    def filter(self, episode_ids: Sequence[str]) -> "FilteredDataset":
        """Return a view restricted to the given episode IDs."""
        selected = list(episode_ids)
        return FilteredDataset(self, selected)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class FilteredDataset(TrajectoryDataset):
    """A filtered view over another dataset.

    Episode IDs that are not present in the parent dataset are silently
    dropped, so ``episode_ids()`` always corresponds one-to-one with the
    episodes returned by ``__getitem__``.
    """

    def __init__(self, parent: TrajectoryDataset, episode_ids: list[str]):
        self._parent = parent
        # Map: episode_id -> parent index.
        parent_ids = parent.episode_ids()
        self._id_to_idx = {eid: i for i, eid in enumerate(parent_ids)}
        # Keep only IDs that exist in the parent so that _ids and _indices
        # stay aligned. (Previously _ids kept every requested ID while
        # _indices dropped unknown ones, so episode_ids() could return IDs
        # that did not match the episodes actually served.)
        self._ids = [eid for eid in episode_ids if eid in self._id_to_idx]
        self._indices = [self._id_to_idx[eid] for eid in self._ids]

    def __len__(self) -> int:
        return len(self._indices)

    def __getitem__(self, idx: int) -> Episode:
        return self._parent[self._indices[idx]]

    def episode_ids(self) -> list[str]:
        # Copy so callers cannot mutate the view's internal state.
        return list(self._ids)
|
traceplane/embeddings.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
"""Episode embedding computation for similarity search.
|
|
2
|
+
|
|
3
|
+
Computes per-episode feature vectors and writes them to Parquet for
|
|
4
|
+
use with the DataFusion ``vec_cosine_sim`` UDF.
|
|
5
|
+
|
|
6
|
+
Two strategies:
|
|
7
|
+
- **Trajectory features** (always available): statistical aggregates
|
|
8
|
+
of observation.state and action vectors per episode.
|
|
9
|
+
- **Text embeddings** (optional, requires ``sentence-transformers``):
|
|
10
|
+
encodes task_label strings with a pretrained language model.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import argparse
|
|
16
|
+
import os
|
|
17
|
+
import sys
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
import numpy as np
|
|
22
|
+
import pyarrow as pa
|
|
23
|
+
import pyarrow.parquet as pq
|
|
24
|
+
|
|
25
|
+
from traceplane.dataset import Episode
|
|
26
|
+
from traceplane.lerobot_reader import LeRobotReader
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def compute_trajectory_embedding(episode: Episode) -> np.ndarray:
|
|
30
|
+
"""Compute a feature vector from an episode's state and action statistics.
|
|
31
|
+
|
|
32
|
+
Concatenates [mean, std, min, max] per dimension for both
|
|
33
|
+
observation.state and action, then L2-normalises.
|
|
34
|
+
|
|
35
|
+
Returns:
|
|
36
|
+
1-D float32 array.
|
|
37
|
+
"""
|
|
38
|
+
parts: list[np.ndarray] = []
|
|
39
|
+
|
|
40
|
+
# State features
|
|
41
|
+
state = episode.observations.get("state")
|
|
42
|
+
if state is not None and state.ndim == 2 and state.shape[0] > 0:
|
|
43
|
+
parts.extend(_stat_features(state))
|
|
44
|
+
|
|
45
|
+
# Action features
|
|
46
|
+
if episode.actions.ndim == 2 and episode.actions.shape[0] > 0:
|
|
47
|
+
parts.extend(_stat_features(episode.actions))
|
|
48
|
+
|
|
49
|
+
if not parts:
|
|
50
|
+
# Fallback: zero vector
|
|
51
|
+
return np.zeros(1, dtype=np.float32)
|
|
52
|
+
|
|
53
|
+
vec = np.concatenate(parts).astype(np.float32)
|
|
54
|
+
|
|
55
|
+
# L2 normalise
|
|
56
|
+
norm = np.linalg.norm(vec)
|
|
57
|
+
if norm > 1e-8:
|
|
58
|
+
vec /= norm
|
|
59
|
+
|
|
60
|
+
return vec
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _stat_features(arr: np.ndarray) -> list[np.ndarray]:
|
|
64
|
+
"""Compute [mean, std, min, max] per dimension."""
|
|
65
|
+
return [
|
|
66
|
+
arr.mean(axis=0),
|
|
67
|
+
arr.std(axis=0),
|
|
68
|
+
arr.min(axis=0),
|
|
69
|
+
arr.max(axis=0),
|
|
70
|
+
]
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def compute_text_embeddings(
    labels: list[str],
    model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
) -> np.ndarray:
    """Encode text labels with a sentence-transformer model.

    Args:
        labels: Text strings to encode.
        model_name: Pretrained model identifier passed to
            ``SentenceTransformer``.

    Returns:
        (N, D) float32 array of L2-normalised embeddings.

    Raises:
        ImportError: If ``sentence-transformers`` is not installed.
    """
    try:
        from sentence_transformers import SentenceTransformer
    except ImportError as exc:
        # Chain the original error so the real cause (e.g. a broken
        # transitive dependency, not just a missing package) stays visible.
        raise ImportError(
            "sentence-transformers is required for text embeddings. "
            "Install with: pip install traceplane[embeddings]"
        ) from exc
    model = SentenceTransformer(model_name)
    embeddings = model.encode(labels, show_progress_bar=False, normalize_embeddings=True)
    return np.asarray(embeddings, dtype=np.float32)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def compute_embeddings(
    dataset_path: str,
    output_dir: str | None = None,
    include_text: bool = False,
    episode_indices: list[int] | None = None,
    storage_options: dict[str, Any] | None = None,
) -> str:
    """Compute per-episode embeddings and write them to a Parquet file.

    Args:
        dataset_path: Path to a LeRobot dataset.
        output_dir: Output directory. Defaults to ``{dataset_path}/embeddings``.
        include_text: Also compute text embeddings from task labels
            (requires ``sentence-transformers``).
        episode_indices: Subset of episodes. None = all.
        storage_options: fsspec options for remote datasets.

    Returns:
        Path to the written Parquet file.
    """
    reader = LeRobotReader(
        dataset_path,
        storage_options=storage_options,
        episode_indices=episode_indices,
    )

    n = len(reader)
    if n == 0:
        raise ValueError("Dataset has no episodes")

    print(f"Computing embeddings for {n} episodes...", file=sys.stderr)

    record_indices: list[int] = []
    record_labels: list[str] = []
    record_vectors: list[np.ndarray] = []

    for i in range(n):
        episode = reader[i]
        record_vectors.append(compute_trajectory_embedding(episode))

        # Resolve the numeric episode index: prefer metadata, fall back to
        # the reader position; string IDs like "episode_000042" are parsed.
        raw_idx = episode.metadata.get("episode_index", i)
        if isinstance(raw_idx, str):
            try:
                raw_idx = int(raw_idx.split("_")[-1])
            except (ValueError, IndexError):
                raw_idx = i
        record_indices.append(int(raw_idx))
        record_labels.append(episode.metadata.get("task_label", ""))

        # Progress every 50 episodes and at the end.
        if (i + 1) % 50 == 0 or i == n - 1:
            print(f"  {i + 1}/{n}", file=sys.stderr)

    traj_dim = record_vectors[0].shape[0]

    columns: dict[str, Any] = {
        "episode_index": pa.array(record_indices, type=pa.int64()),
        "task_label": pa.array(record_labels, type=pa.utf8()),
        "trajectory_embedding": pa.array(
            [vec.tolist() for vec in record_vectors],
            type=pa.list_(pa.float32()),
        ),
    }

    # Optional text embeddings: encode each unique label once, then fan
    # the shared vectors back out per episode.
    if include_text:
        unique_labels = list(set(record_labels))
        print(f"Computing text embeddings for {len(unique_labels)} unique labels...", file=sys.stderr)
        text_embs = compute_text_embeddings(unique_labels)
        by_label = {lbl: text_embs[j] for j, lbl in enumerate(unique_labels)}
        columns["text_embedding"] = pa.array(
            [by_label[lbl].tolist() for lbl in record_labels],
            type=pa.list_(pa.float32()),
        )

    table = pa.table(columns)

    if output_dir is None:
        output_dir = os.path.join(dataset_path, "embeddings")
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "episode_embeddings.parquet")
    pq.write_table(table, output_path)

    print(f"Written {n} embeddings ({traj_dim}-dim) to {output_path}", file=sys.stderr)
    return output_path
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def main(argv: list[str] | None = None) -> None:
    """CLI entry point for embedding computation.

    Parses arguments, runs :func:`compute_embeddings`, and prints the
    resulting Parquet path to stdout.
    """
    parser = argparse.ArgumentParser(
        description="Compute episode embeddings for similarity search",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("dataset_path", help="Path to LeRobot dataset")
    parser.add_argument("--output-dir", help="Output directory (default: {dataset_path}/embeddings)")
    parser.add_argument("--include-text", action="store_true", help="Compute text embeddings (requires sentence-transformers)")
    parser.add_argument("--episodes", type=int, nargs="+", help="Specific episode indices")
    opts = parser.parse_args(argv)

    result = compute_embeddings(
        opts.dataset_path,
        output_dir=opts.output_dir,
        include_text=opts.include_text,
        episode_indices=opts.episodes,
    )
    # Only the output path goes to stdout; progress goes to stderr.
    print(result)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
# Allow direct script execution (python -m traceplane.embeddings <dataset_path>);
# the packaged console entry point calls main() as well.
if __name__ == "__main__":
    main()
|
traceplane/jax.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
"""JAX integration — convert episodes and windows to JAX arrays."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any, Iterator
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
try:
|
|
10
|
+
import jax
|
|
11
|
+
import jax.numpy as jnp
|
|
12
|
+
except ImportError:
|
|
13
|
+
raise ImportError(
|
|
14
|
+
"JAX is required for traceplane.jax. "
|
|
15
|
+
"Install it with: pip install traceplane[jax]"
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
from traceplane.dataset import TrajectoryDataset
|
|
19
|
+
from traceplane.windowing import WindowedDataset
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def episodes_to_jax(
    dataset: TrajectoryDataset,
) -> list[dict[str, Any]]:
    """Materialise every episode in *dataset* as a dict of JAX arrays.

    Returns a list of dicts, each containing:
    - ``actions``: jax array (T, action_dim)
    - ``timestamps``: jax array (T,)
    - observation keys: jax arrays
    - ``episode_id``: str
    - ``metadata``: dict
    """
    converted: list[dict[str, Any]] = []
    for episode in dataset:
        entry: dict[str, Any] = {
            "episode_id": episode.episode_id,
            "metadata": episode.metadata,
            "actions": jnp.array(episode.actions, dtype=jnp.float32),
            "timestamps": jnp.array(episode.timestamps, dtype=jnp.float64),
        }
        # Observations are flattened into the same dict, one key per modality.
        entry.update(
            (name, jnp.array(values, dtype=jnp.float32))
            for name, values in episode.observations.items()
        )
        converted.append(entry)
    return converted
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def windowed_iterator(
    dataset: TrajectoryDataset,
    obs_horizon: int = 2,
    action_horizon: int = 16,
    stride: int = 1,
    batch_size: int = 64,
    shuffle: bool = True,
    seed: int = 0,
) -> Iterator[dict[str, Any]]:
    """Yield batched training windows as dicts of JAX arrays.

    Args:
        dataset: Source trajectory dataset.
        obs_horizon: Observation history length.
        action_horizon: Action prediction horizon.
        stride: Window stride.
        batch_size: Batch size. The final batch may be smaller when the
            window count is not a multiple of ``batch_size``.
        shuffle: Shuffle sample order.
        seed: Random seed for shuffling (deterministic for a fixed seed).

    Yields:
        Dicts with ``obs_history`` (dict of jax arrays), ``action_chunk``
        (jax array), ``episode_ids`` (list[str]), ``timesteps`` (list[int]).
    """
    windowed = WindowedDataset(
        dataset,
        obs_horizon=obs_horizon,
        action_horizon=action_horizon,
        stride=stride,
    )

    # Shuffle window indices up front, then walk them in fixed-size slices.
    indices = np.arange(len(windowed))
    if shuffle:
        rng = np.random.default_rng(seed)
        rng.shuffle(indices)

    for start in range(0, len(indices), batch_size):
        batch_indices = indices[start : start + batch_size]
        windows = [windowed[int(i)] for i in batch_indices]

        # Collate: take the union of observation keys so a batch can mix
        # windows with different modality sets.
        obs_keys = set()
        for w in windows:
            obs_keys.update(w.obs_history.keys())

        batch: dict[str, Any] = {
            "episode_ids": [w.episode_id for w in windows],
            "timesteps": [w.timestep for w in windows],
            "action_chunk": jnp.stack(
                [jnp.array(w.action_chunk, dtype=jnp.float32) for w in windows]
            ),
            "obs_history": {},
        }
        for key in obs_keys:
            # NOTE(review): windows missing a key are padded with a
            # zeros((obs_horizon, 1)) placeholder; jnp.stack then requires
            # that real arrays for the same key also have trailing dim 1 —
            # presumably missing keys only occur alongside 1-D observations.
            # TODO confirm against WindowedDataset's guarantees.
            arrays = [
                jnp.array(w.obs_history.get(key, np.zeros((obs_horizon, 1))), dtype=jnp.float32)
                for w in windows
            ]
            batch["obs_history"][key] = jnp.stack(arrays)

        yield batch
|