traceplane 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35) hide show
  1. traceplane-0.1.0/.gitignore +29 -0
  2. traceplane-0.1.0/PKG-INFO +129 -0
  3. traceplane-0.1.0/README.md +71 -0
  4. traceplane-0.1.0/pyproject.toml +58 -0
  5. traceplane-0.1.0/scripts/setup_isaac_sim.sh +85 -0
  6. traceplane-0.1.0/src/traceplane/__init__.py +16 -0
  7. traceplane-0.1.0/src/traceplane/dataset.py +93 -0
  8. traceplane-0.1.0/src/traceplane/embeddings.py +207 -0
  9. traceplane-0.1.0/src/traceplane/jax.py +108 -0
  10. traceplane-0.1.0/src/traceplane/lerobot_reader.py +362 -0
  11. traceplane-0.1.0/src/traceplane/query.py +643 -0
  12. traceplane-0.1.0/src/traceplane/sim/__init__.py +31 -0
  13. traceplane-0.1.0/src/traceplane/sim/_compat.py +16 -0
  14. traceplane-0.1.0/src/traceplane/sim/cli.py +115 -0
  15. traceplane-0.1.0/src/traceplane/sim/config.py +58 -0
  16. traceplane-0.1.0/src/traceplane/sim/controller.py +77 -0
  17. traceplane-0.1.0/src/traceplane/sim/evaluator.py +230 -0
  18. traceplane-0.1.0/src/traceplane/sim/metrics.py +110 -0
  19. traceplane-0.1.0/src/traceplane/sim/robot.py +123 -0
  20. traceplane-0.1.0/src/traceplane/sim/scene.py +91 -0
  21. traceplane-0.1.0/src/traceplane/sim/visualizer.py +188 -0
  22. traceplane-0.1.0/src/traceplane/tf.py +86 -0
  23. traceplane-0.1.0/src/traceplane/torch.py +126 -0
  24. traceplane-0.1.0/src/traceplane/training/__init__.py +23 -0
  25. traceplane-0.1.0/src/traceplane/training/cli.py +92 -0
  26. traceplane-0.1.0/src/traceplane/training/config.py +82 -0
  27. traceplane-0.1.0/src/traceplane/training/diffusion_policy.py +314 -0
  28. traceplane-0.1.0/src/traceplane/training/eval.py +85 -0
  29. traceplane-0.1.0/src/traceplane/training/normalization.py +262 -0
  30. traceplane-0.1.0/src/traceplane/training/trainer.py +318 -0
  31. traceplane-0.1.0/src/traceplane/windowing.py +112 -0
  32. traceplane-0.1.0/tests/test_core.py +217 -0
  33. traceplane-0.1.0/tests/test_engine.py +91 -0
  34. traceplane-0.1.0/tests/test_integration.py +426 -0
  35. traceplane-0.1.0/tests/test_torch.py +107 -0
@@ -0,0 +1,29 @@
1
+ # Rust
2
+ backend/target/
3
+
4
+ # Node / Frontend
5
+ frontend/node_modules/
6
+ frontend/dist/
7
+
8
+ # Python
9
+ dataloader/__pycache__/
10
+ dataloader/src/traceplane/__pycache__/
11
+ dataloader/tests/__pycache__/
12
+ dataloader/.pytest_cache/
13
+ dataloader/src/*.egg-info/
14
+ dataloader/src/traceplane.egg-info/
15
+ *.pyc
16
+ *.pyo
17
+ __pycache__/
18
+ *.egg-info/
19
+
20
+ # IDE
21
+ .idea/
22
+ .vscode/
23
+ *.swp
24
+
25
+ # OS
26
+ .DS_Store
27
+
28
+ # Env
29
+ .env
@@ -0,0 +1,129 @@
1
+ Metadata-Version: 2.4
2
+ Name: traceplane
3
+ Version: 0.1.0
4
+ Summary: Streaming dataloader for robotics trajectory datasets
5
+ Project-URL: Homepage, https://traceplane.ai
6
+ Project-URL: Documentation, https://docs.traceplane.ai
7
+ Project-URL: Repository, https://github.com/traceplane/traceplane
8
+ Project-URL: Issues, https://github.com/traceplane/traceplane/issues
9
+ Author-email: Traceplane <hello@traceplane.ai>
10
+ License-Expression: Apache-2.0
11
+ Keywords: datasets,imitation-learning,lerobot,robotics,trajectories
12
+ Classifier: Development Status :: 4 - Beta
13
+ Classifier: Intended Audience :: Science/Research
14
+ Classifier: License :: OSI Approved :: Apache Software License
15
+ Classifier: Programming Language :: Python :: 3
16
+ Classifier: Programming Language :: Python :: 3.10
17
+ Classifier: Programming Language :: Python :: 3.11
18
+ Classifier: Programming Language :: Python :: 3.12
19
+ Classifier: Programming Language :: Python :: 3.13
20
+ Classifier: Topic :: Scientific/Engineering
21
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
22
+ Requires-Python: >=3.10
23
+ Requires-Dist: fsspec>=2023.1
24
+ Requires-Dist: numpy>=1.24
25
+ Requires-Dist: pyarrow>=14.0
26
+ Requires-Dist: requests>=2.28
27
+ Provides-Extra: all
28
+ Requires-Dist: gcsfs>=2023.1; extra == 'all'
29
+ Requires-Dist: jax>=0.4; extra == 'all'
30
+ Requires-Dist: jaxlib>=0.4; extra == 'all'
31
+ Requires-Dist: numpy>=1.24; extra == 'all'
32
+ Requires-Dist: s3fs>=2023.1; extra == 'all'
33
+ Requires-Dist: sentence-transformers>=2.0; extra == 'all'
34
+ Requires-Dist: tensorflow>=2.14; extra == 'all'
35
+ Requires-Dist: torch>=2.0; extra == 'all'
36
+ Provides-Extra: dev
37
+ Requires-Dist: pytest>=7; extra == 'dev'
38
+ Requires-Dist: torch>=2.0; extra == 'dev'
39
+ Provides-Extra: embeddings
40
+ Requires-Dist: sentence-transformers>=2.0; extra == 'embeddings'
41
+ Provides-Extra: gcs
42
+ Requires-Dist: gcsfs>=2023.1; extra == 'gcs'
43
+ Provides-Extra: jax
44
+ Requires-Dist: jax>=0.4; extra == 'jax'
45
+ Requires-Dist: jaxlib>=0.4; extra == 'jax'
46
+ Provides-Extra: s3
47
+ Requires-Dist: s3fs>=2023.1; extra == 's3'
48
+ Provides-Extra: sim
49
+ Requires-Dist: numpy>=1.24; extra == 'sim'
50
+ Requires-Dist: torch>=2.0; extra == 'sim'
51
+ Provides-Extra: tf
52
+ Requires-Dist: tensorflow>=2.14; extra == 'tf'
53
+ Provides-Extra: torch
54
+ Requires-Dist: torch>=2.0; extra == 'torch'
55
+ Provides-Extra: training
56
+ Requires-Dist: torch>=2.0; extra == 'training'
57
+ Description-Content-Type: text/markdown
58
+
59
+ # Traceplane
60
+
61
+ Python SDK for the Traceplane trajectory data platform.
62
+
63
+ ## Installation
64
+
65
+ ```bash
66
+ pip install traceplane
67
+ ```
68
+
69
+ With framework extras:
70
+
71
+ ```bash
72
+ pip install traceplane[torch] # PyTorch DataLoader
73
+ pip install traceplane[jax] # JAX support
74
+ pip install traceplane[training] # Diffusion policy training
75
+ pip install traceplane[all] # Everything
76
+ ```
77
+
78
+ ## Quick Start
79
+
80
+ ```python
81
+ from traceplane import TraceplaneClient
82
+
83
+ client = TraceplaneClient("https://api.traceplane.ai", api_key="tp_live_...")
84
+
85
+ # Register a dataset
86
+ client.register("my_data", "/path/to/dataset", include_data=True)
87
+
88
+ # Query with SQL
89
+ rows = client.sql_rows("SELECT * FROM my_data WHERE frame_count > 100")
90
+
91
+ # Upload data
92
+ client.upload_dataset("my_data", "/path/to/parquet/files/")
93
+
94
+ # Vector search
95
+ results = client.search_similar("my_data", episode_index=0, k=5)
96
+ ```
97
+
98
+ ## Features
99
+
100
+ - **SQL query engine** -- register datasets and query with full SQL, including vector UDFs (`vec_mean`, `vec_norm`, `vec_cosine_sim`, etc.)
101
+ - **Streaming dataloaders** -- PyTorch, JAX, and TensorFlow adapters with windowed sampling
102
+ - **LeRobot format** -- native reader for LeRobot v2/v3 datasets (Parquet + MP4)
103
+ - **Similarity search** -- find related episodes via embedding-based vector search
104
+ - **Dataset upload** -- push local Parquet files to the platform
105
+ - **Retargeting** -- XR hand poses to robot action space via calibration bridge
106
+ - **Training** -- built-in diffusion policy training with `traceplane-train` CLI
107
+
108
+ ## Training Integration
109
+
110
+ ```python
111
+ from traceplane import LeRobotReader
112
+ from traceplane.torch import TorchEpisodeLoader
113
+
114
+ reader = LeRobotReader("/path/to/lerobot/dataset")
115
+ loader = TorchEpisodeLoader(reader, batch_size=32, window_size=16)
116
+
117
+ for batch in loader:
118
+ observations = batch["observation"]
119
+ actions = batch["action"]
120
+ # ... your training loop
121
+ ```
122
+
123
+ ## API Reference
124
+
125
+ Full documentation: [docs.traceplane.ai](https://docs.traceplane.ai)
126
+
127
+ ## License
128
+
129
+ Apache-2.0
@@ -0,0 +1,71 @@
1
+ # Traceplane
2
+
3
+ Python SDK for the Traceplane trajectory data platform.
4
+
5
+ ## Installation
6
+
7
+ ```bash
8
+ pip install traceplane
9
+ ```
10
+
11
+ With framework extras:
12
+
13
+ ```bash
14
+ pip install traceplane[torch] # PyTorch DataLoader
15
+ pip install traceplane[jax] # JAX support
16
+ pip install traceplane[training] # Diffusion policy training
17
+ pip install traceplane[all] # Everything
18
+ ```
19
+
20
+ ## Quick Start
21
+
22
+ ```python
23
+ from traceplane import TraceplaneClient
24
+
25
+ client = TraceplaneClient("https://api.traceplane.ai", api_key="tp_live_...")
26
+
27
+ # Register a dataset
28
+ client.register("my_data", "/path/to/dataset", include_data=True)
29
+
30
+ # Query with SQL
31
+ rows = client.sql_rows("SELECT * FROM my_data WHERE frame_count > 100")
32
+
33
+ # Upload data
34
+ client.upload_dataset("my_data", "/path/to/parquet/files/")
35
+
36
+ # Vector search
37
+ results = client.search_similar("my_data", episode_index=0, k=5)
38
+ ```
39
+
40
+ ## Features
41
+
42
+ - **SQL query engine** -- register datasets and query with full SQL, including vector UDFs (`vec_mean`, `vec_norm`, `vec_cosine_sim`, etc.)
43
+ - **Streaming dataloaders** -- PyTorch, JAX, and TensorFlow adapters with windowed sampling
44
+ - **LeRobot format** -- native reader for LeRobot v2/v3 datasets (Parquet + MP4)
45
+ - **Similarity search** -- find related episodes via embedding-based vector search
46
+ - **Dataset upload** -- push local Parquet files to the platform
47
+ - **Retargeting** -- XR hand poses to robot action space via calibration bridge
48
+ - **Training** -- built-in diffusion policy training with `traceplane-train` CLI
49
+
50
+ ## Training Integration
51
+
52
+ ```python
53
+ from traceplane import LeRobotReader
54
+ from traceplane.torch import TorchEpisodeLoader
55
+
56
+ reader = LeRobotReader("/path/to/lerobot/dataset")
57
+ loader = TorchEpisodeLoader(reader, batch_size=32, window_size=16)
58
+
59
+ for batch in loader:
60
+ observations = batch["observation"]
61
+ actions = batch["action"]
62
+ # ... your training loop
63
+ ```
64
+
65
+ ## API Reference
66
+
67
+ Full documentation: [docs.traceplane.ai](https://docs.traceplane.ai)
68
+
69
+ ## License
70
+
71
+ Apache-2.0
@@ -0,0 +1,58 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "traceplane"
7
+ version = "0.1.0"
8
+ description = "Streaming dataloader for robotics trajectory datasets"
9
+ requires-python = ">=3.10"
10
+ license = "Apache-2.0"
11
+ authors = [{name = "Traceplane", email = "hello@traceplane.ai"}]
12
+ readme = "README.md"
13
+ keywords = ["robotics", "trajectories", "datasets", "imitation-learning", "lerobot"]
14
+ classifiers = [
15
+ "Development Status :: 4 - Beta",
16
+ "Intended Audience :: Science/Research",
17
+ "License :: OSI Approved :: Apache Software License",
18
+ "Programming Language :: Python :: 3",
19
+ "Programming Language :: Python :: 3.10",
20
+ "Programming Language :: Python :: 3.11",
21
+ "Programming Language :: Python :: 3.12",
22
+ "Programming Language :: Python :: 3.13",
23
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
24
+ "Topic :: Scientific/Engineering",
25
+ ]
26
+ dependencies = [
27
+ "numpy>=1.24",
28
+ "pyarrow>=14.0",
29
+ "fsspec>=2023.1",
30
+ "requests>=2.28",
31
+ ]
32
+
33
+ [project.optional-dependencies]
34
+ torch = ["torch>=2.0"]
35
+ jax = ["jax>=0.4", "jaxlib>=0.4"]
36
+ tf = ["tensorflow>=2.14"]
37
+ s3 = ["s3fs>=2023.1"]
38
+ gcs = ["gcsfs>=2023.1"]
39
+ training = ["torch>=2.0"]
40
+ embeddings = ["sentence-transformers>=2.0"]
41
+ sim = ["torch>=2.0", "numpy>=1.24"]
42
+ all = ["traceplane[torch,jax,tf,s3,gcs,training,embeddings,sim]"]
43
+ dev = ["pytest>=7", "traceplane[torch]"]
44
+
45
+ [project.scripts]
46
+ traceplane-train = "traceplane.training.cli:main"
47
+ traceplane-embed = "traceplane.embeddings:main"
48
+ traceplane-sim-viz = "traceplane.sim.cli:main_viz"
49
+ traceplane-sim-eval = "traceplane.sim.cli:main_eval"
50
+
51
+ [project.urls]
52
+ Homepage = "https://traceplane.ai"
53
+ Documentation = "https://docs.traceplane.ai"
54
+ Repository = "https://github.com/traceplane/traceplane"
55
+ Issues = "https://github.com/traceplane/traceplane/issues"
56
+
57
+ [tool.hatch.build.targets.wheel]
58
+ packages = ["src/traceplane"]
@@ -0,0 +1,85 @@
1
#!/usr/bin/env bash
# Traceplane Isaac Sim Setup — Ubuntu 22.04/24.04 + NVIDIA RTX 5080
#
# Prerequisites:
#   - Ubuntu 22.04 or 24.04
#   - NVIDIA driver >= 565 (required for RTX 5080 / Blackwell)
#   - Python 3.10+
#
# Usage:
#   chmod +x scripts/setup_isaac_sim.sh
#   ./scripts/setup_isaac_sim.sh

# Fail fast: abort on any error, unset variable, or failed pipeline stage.
set -euo pipefail

echo "=== Traceplane Isaac Sim Setup ==="
echo ""

# Check OS — warn (do not abort) on other Ubuntu releases.
# NOTE(review): /etc/lsb-release may be absent on some distros; the 2>/dev/null
# makes that case fall into the warning branch rather than erroring.
if ! grep -qE "22\.04|24\.04" /etc/lsb-release 2>/dev/null; then
    echo "WARNING: This script targets Ubuntu 22.04/24.04."
    echo "Current OS: $(lsb_release -ds 2>/dev/null || echo 'unknown')"
    echo ""
fi

# Check NVIDIA driver — hard requirement; Isaac Sim cannot run without it.
if ! command -v nvidia-smi &>/dev/null; then
    echo "ERROR: nvidia-smi not found. Install NVIDIA driver first:"
    echo "  sudo apt update && sudo apt install nvidia-driver-565"
    exit 1
fi

DRIVER_VERSION=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -1)
echo "NVIDIA driver: $DRIVER_VERSION"

# Compare only the major version; warn (do not abort) if it looks too old.
DRIVER_MAJOR=$(echo "$DRIVER_VERSION" | cut -d. -f1)
if [ "$DRIVER_MAJOR" -lt 565 ]; then
    echo "WARNING: Driver $DRIVER_VERSION may be too old for RTX 5080."
    echo "Recommended: >= 565. Install with:"
    echo "  sudo apt install nvidia-driver-565"
    echo ""
fi

# Report the detected GPU (informational only — no model check is enforced).
GPU_NAME=$(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)
echo "GPU: $GPU_NAME"
echo ""

# Check Python; PYTHON may be overridden from the environment (e.g. PYTHON=python3.11).
PYTHON=${PYTHON:-python3}
PY_VERSION=$($PYTHON --version 2>&1)
echo "Python: $PY_VERSION"
echo ""

# Install Isaac Sim from NVIDIA's package index (pinned to the 5.x series).
echo "=== Installing Isaac Sim 5.x ==="
echo "This may take several minutes..."
$PYTHON -m pip install isaacsim==5.* --extra-index-url https://pypi.nvidia.com

# Install the Traceplane sim extra in editable mode from the repo checkout.
echo ""
echo "=== Installing Traceplane sim module ==="
# Resolve the directory containing this script, then its parent (the package root).
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
DATALOADER_DIR="$(dirname "$SCRIPT_DIR")"
$PYTHON -m pip install -e "$DATALOADER_DIR[sim]"

# Smoke test: boot a headless SimulationApp and shut it down cleanly.
echo ""
echo "=== Smoke test ==="
$PYTHON -c "
from isaacsim import SimulationApp
app = SimulationApp({'headless': True})
print('Isaac Sim loaded successfully')
app.close()
print('Smoke test passed!')
"

echo ""
echo "=== Setup complete ==="
echo ""
echo "Quick start:"
echo "  # Replay a dataset episode in sim"
echo "  traceplane-sim-viz --dataset-path /path/to/dataset --episode-index 0"
echo ""
echo "  # Evaluate a policy"
echo "  traceplane-sim-eval /path/to/checkpoint.pt --num-episodes 10 --headless"
@@ -0,0 +1,16 @@
1
+ """Traceplane — streaming dataloader for robotics trajectory datasets."""
2
+
3
+ from traceplane.dataset import Episode, TrajectoryDataset
4
+ from traceplane.lerobot_reader import LeRobotReader
5
+ from traceplane.windowing import WindowedDataset
6
+ from traceplane.query import TraceplaneClient
7
+
8
+ __version__ = "0.1.0"
9
+
10
+ __all__ = [
11
+ "Episode",
12
+ "TrajectoryDataset",
13
+ "LeRobotReader",
14
+ "WindowedDataset",
15
+ "TraceplaneClient",
16
+ ]
@@ -0,0 +1,93 @@
1
+ """Core dataset abstractions."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Any, Iterator, Sequence
7
+
8
+ import numpy as np
9
+
10
+
11
@dataclass
class Episode:
    """A single trajectory episode.

    Attributes:
        episode_id: Unique identifier (e.g. "episode_000042").
        observations: Dict of observation arrays keyed by modality.
            Common keys: "state" (proprioception), "action", image camera names.
            Each value is shape (T, D) for vectors or (T, H, W, C) for images.
        actions: Action array, shape (T, action_dim).
        timestamps: Monotonic timestamps in seconds, shape (T,).
        metadata: Arbitrary episode metadata (task label, fps, success, etc.).
    """

    episode_id: str
    observations: dict[str, np.ndarray] = field(default_factory=dict)
    actions: np.ndarray = field(default_factory=lambda: np.empty((0,)))
    timestamps: np.ndarray = field(default_factory=lambda: np.empty((0,)))
    metadata: dict[str, Any] = field(default_factory=dict)

    @property
    def length(self) -> int:
        """Number of timesteps.

        Inferred from ``actions``, then ``timestamps``, then the first
        observation array with a non-empty leading axis; 0 if none apply.
        """
        if self.actions.ndim >= 1 and self.actions.shape[0] > 0:
            return self.actions.shape[0]
        if self.timestamps.ndim >= 1 and self.timestamps.shape[0] > 0:
            return self.timestamps.shape[0]
        for v in self.observations.values():
            # Require ndim >= 1: a 0-d array (e.g. np.float64) has shape ()
            # and would raise IndexError on shape[0] with only a hasattr check.
            if hasattr(v, "shape") and getattr(v, "ndim", 0) >= 1 and v.shape[0] > 0:
                return v.shape[0]
        return 0

    @property
    def action_dim(self) -> int:
        """Dimensionality of the action space, or 0 if ``actions`` is not 2-D."""
        if self.actions.ndim == 2:
            return self.actions.shape[1]
        return 0
48
+
49
+
50
class TrajectoryDataset:
    """Abstract base class for trajectory datasets.

    Concrete datasets implement ``__len__`` and ``__getitem__``; iteration
    and filtering are provided on top of those two primitives.
    """

    def __len__(self) -> int:
        raise NotImplementedError

    def __getitem__(self, idx: int) -> Episode:
        raise NotImplementedError

    def __iter__(self) -> Iterator[Episode]:
        # Delegate to __getitem__ so subclasses only implement indexing.
        yield from (self[index] for index in range(len(self)))

    def episode_ids(self) -> list[str]:
        """Return all episode IDs in order."""
        raise NotImplementedError

    def filter(self, episode_ids: Sequence[str]) -> "FilteredDataset":
        """Return a view of this dataset restricted to the given episodes."""
        return FilteredDataset(self, list(episode_ids))
73
+
74
+
75
class FilteredDataset(TrajectoryDataset):
    """A filtered view over another dataset.

    Holds parent indices for the requested episode ids; requested ids that do
    not exist in the parent are silently dropped.
    """

    def __init__(self, parent: TrajectoryDataset, episode_ids: list[str]):
        self._parent = parent
        # Build index map once: episode_id -> parent index, for O(1) lookups.
        parent_ids = parent.episode_ids()
        self._id_to_idx = {eid: i for i, eid in enumerate(parent_ids)}
        # Keep only ids present in the parent, so self._ids[k] always names the
        # episode returned by self[k]. (Previously _ids kept ALL requested ids,
        # so episode_ids() could return ids that were dropped — misaligned with
        # indexing whenever a requested id was missing from the parent.)
        self._ids = [eid for eid in episode_ids if eid in self._id_to_idx]
        self._indices = [self._id_to_idx[eid] for eid in self._ids]

    def __len__(self) -> int:
        return len(self._indices)

    def __getitem__(self, idx: int) -> Episode:
        return self._parent[self._indices[idx]]

    def episode_ids(self) -> list[str]:
        """Return the ids of this view's episodes, aligned with ``__getitem__``."""
        return list(self._ids)
@@ -0,0 +1,207 @@
1
+ """Episode embedding computation for similarity search.
2
+
3
+ Computes per-episode feature vectors and writes them to Parquet for
4
+ use with the DataFusion ``vec_cosine_sim`` UDF.
5
+
6
+ Two strategies:
7
+ - **Trajectory features** (always available): statistical aggregates
8
+ of observation.state and action vectors per episode.
9
+ - **Text embeddings** (optional, requires ``sentence-transformers``):
10
+ encodes task_label strings with a pretrained language model.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import argparse
16
+ import os
17
+ import sys
18
+ from pathlib import Path
19
+ from typing import Any
20
+
21
+ import numpy as np
22
+ import pyarrow as pa
23
+ import pyarrow.parquet as pq
24
+
25
+ from traceplane.dataset import Episode
26
+ from traceplane.lerobot_reader import LeRobotReader
27
+
28
+
29
+ def compute_trajectory_embedding(episode: Episode) -> np.ndarray:
30
+ """Compute a feature vector from an episode's state and action statistics.
31
+
32
+ Concatenates [mean, std, min, max] per dimension for both
33
+ observation.state and action, then L2-normalises.
34
+
35
+ Returns:
36
+ 1-D float32 array.
37
+ """
38
+ parts: list[np.ndarray] = []
39
+
40
+ # State features
41
+ state = episode.observations.get("state")
42
+ if state is not None and state.ndim == 2 and state.shape[0] > 0:
43
+ parts.extend(_stat_features(state))
44
+
45
+ # Action features
46
+ if episode.actions.ndim == 2 and episode.actions.shape[0] > 0:
47
+ parts.extend(_stat_features(episode.actions))
48
+
49
+ if not parts:
50
+ # Fallback: zero vector
51
+ return np.zeros(1, dtype=np.float32)
52
+
53
+ vec = np.concatenate(parts).astype(np.float32)
54
+
55
+ # L2 normalise
56
+ norm = np.linalg.norm(vec)
57
+ if norm > 1e-8:
58
+ vec /= norm
59
+
60
+ return vec
61
+
62
+
63
+ def _stat_features(arr: np.ndarray) -> list[np.ndarray]:
64
+ """Compute [mean, std, min, max] per dimension."""
65
+ return [
66
+ arr.mean(axis=0),
67
+ arr.std(axis=0),
68
+ arr.min(axis=0),
69
+ arr.max(axis=0),
70
+ ]
71
+
72
+
73
def compute_text_embeddings(
    labels: list[str],
    model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
) -> np.ndarray:
    """Encode text labels with a sentence-transformer model.

    Args:
        labels: Strings to encode (e.g. task labels).
        model_name: Hugging Face model id of the encoder.

    Returns:
        (N, D) float32 array of L2-normalised embeddings.

    Raises:
        ImportError: If ``sentence-transformers`` is not installed.
    """
    try:
        from sentence_transformers import SentenceTransformer
    except ImportError as exc:
        # Chain the original error so the underlying import failure
        # (e.g. a broken transitive dependency) stays in the traceback.
        raise ImportError(
            "sentence-transformers is required for text embeddings. "
            "Install with: pip install traceplane[embeddings]"
        ) from exc
    model = SentenceTransformer(model_name)
    embeddings = model.encode(labels, show_progress_bar=False, normalize_embeddings=True)
    return np.asarray(embeddings, dtype=np.float32)
92
+
93
+
94
def compute_embeddings(
    dataset_path: str,
    output_dir: str | None = None,
    include_text: bool = False,
    episode_indices: list[int] | None = None,
    storage_options: dict[str, Any] | None = None,
) -> str:
    """Compute episode embeddings and write to Parquet.

    Produces one row per episode with columns ``episode_index``,
    ``task_label``, ``trajectory_embedding`` (list<float32>), and — when
    ``include_text`` is set — ``text_embedding`` (list<float32>).
    Progress is logged to stderr.

    Args:
        dataset_path: Path to a LeRobot dataset.
        output_dir: Output directory. Defaults to ``{dataset_path}/embeddings``.
        include_text: Also compute text embeddings from task labels.
        episode_indices: Subset of episodes. None = all.
        storage_options: fsspec options for remote datasets.

    Returns:
        Path to the written Parquet file.

    Raises:
        ValueError: If the dataset contains no episodes.
        ImportError: If ``include_text`` is set but sentence-transformers
            is not installed (raised by ``compute_text_embeddings``).
    """
    reader = LeRobotReader(
        dataset_path,
        storage_options=storage_options,
        episode_indices=episode_indices,
    )

    n = len(reader)
    if n == 0:
        raise ValueError("Dataset has no episodes")

    print(f"Computing embeddings for {n} episodes...", file=sys.stderr)

    indices: list[int] = []
    labels: list[str] = []
    traj_embeddings: list[np.ndarray] = []

    for i in range(n):
        ep = reader[i]
        traj_emb = compute_trajectory_embedding(ep)
        traj_embeddings.append(traj_emb)

        # Extract episode index from metadata or ID; fall back to the loop
        # position when metadata is missing or unparsable.
        ep_idx = ep.metadata.get("episode_index", i)
        if isinstance(ep_idx, str):
            # Handle string ids such as "episode_000042" -> 42.
            try:
                ep_idx = int(ep_idx.split("_")[-1])
            except (ValueError, IndexError):
                ep_idx = i
        indices.append(int(ep_idx))
        labels.append(ep.metadata.get("task_label", ""))

        # Progress heartbeat every 50 episodes, plus a final line.
        if (i + 1) % 50 == 0 or i == n - 1:
            print(f" {i + 1}/{n}", file=sys.stderr)

    # Build arrow arrays. Embeddings are stored as variable-length
    # list<float32>; traj_dim is reported for logging only.
    traj_dim = traj_embeddings[0].shape[0]
    arrow_traj = pa.list_(pa.float32())

    columns: dict[str, Any] = {
        "episode_index": pa.array(indices, type=pa.int64()),
        "task_label": pa.array(labels, type=pa.utf8()),
        "trajectory_embedding": pa.array(
            [emb.tolist() for emb in traj_embeddings],
            type=arrow_traj,
        ),
    }

    # Optional text embeddings: encode each unique label once, then fan the
    # vectors back out per episode (avoids re-encoding duplicate labels).
    if include_text:
        unique_labels = list(set(labels))
        print(f"Computing text embeddings for {len(unique_labels)} unique labels...", file=sys.stderr)
        text_embs = compute_text_embeddings(unique_labels)
        label_to_emb = {lbl: text_embs[i] for i, lbl in enumerate(unique_labels)}
        text_vecs = [label_to_emb[lbl].tolist() for lbl in labels]
        columns["text_embedding"] = pa.array(
            text_vecs,
            type=pa.list_(pa.float32()),
        )

    table = pa.table(columns)

    # Write the Parquet file, creating the output directory if needed.
    # NOTE(review): os.path.join assumes a local dataset_path; remote
    # (fsspec) dataset paths would need an explicit local output_dir.
    if output_dir is None:
        output_dir = os.path.join(dataset_path, "embeddings")
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "episode_embeddings.parquet")
    pq.write_table(table, output_path)

    print(f"Written {n} embeddings ({traj_dim}-dim) to {output_path}", file=sys.stderr)
    return output_path
183
+
184
+
185
def main(argv: list[str] | None = None) -> None:
    """CLI entry point: parse arguments, compute embeddings, print the path."""
    parser = argparse.ArgumentParser(
        description="Compute episode embeddings for similarity search",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("dataset_path", help="Path to LeRobot dataset")
    parser.add_argument(
        "--output-dir",
        help="Output directory (default: {dataset_path}/embeddings)",
    )
    parser.add_argument(
        "--include-text",
        action="store_true",
        help="Compute text embeddings (requires sentence-transformers)",
    )
    parser.add_argument(
        "--episodes",
        type=int,
        nargs="+",
        help="Specific episode indices",
    )
    opts = parser.parse_args(argv)

    # Run the computation, then emit the output path on stdout so callers
    # can capture it (all progress logging goes to stderr).
    result_path = compute_embeddings(
        opts.dataset_path,
        output_dir=opts.output_dir,
        include_text=opts.include_text,
        episode_indices=opts.episodes,
    )
    print(result_path)
204
+
205
+
206
# Allow direct execution (python -m traceplane.embeddings) in addition to
# the traceplane-embed console script declared in pyproject.toml.
if __name__ == "__main__":
    main()