traceplane 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- traceplane-0.1.0/.gitignore +29 -0
- traceplane-0.1.0/PKG-INFO +129 -0
- traceplane-0.1.0/README.md +71 -0
- traceplane-0.1.0/pyproject.toml +58 -0
- traceplane-0.1.0/scripts/setup_isaac_sim.sh +85 -0
- traceplane-0.1.0/src/traceplane/__init__.py +16 -0
- traceplane-0.1.0/src/traceplane/dataset.py +93 -0
- traceplane-0.1.0/src/traceplane/embeddings.py +207 -0
- traceplane-0.1.0/src/traceplane/jax.py +108 -0
- traceplane-0.1.0/src/traceplane/lerobot_reader.py +362 -0
- traceplane-0.1.0/src/traceplane/query.py +643 -0
- traceplane-0.1.0/src/traceplane/sim/__init__.py +31 -0
- traceplane-0.1.0/src/traceplane/sim/_compat.py +16 -0
- traceplane-0.1.0/src/traceplane/sim/cli.py +115 -0
- traceplane-0.1.0/src/traceplane/sim/config.py +58 -0
- traceplane-0.1.0/src/traceplane/sim/controller.py +77 -0
- traceplane-0.1.0/src/traceplane/sim/evaluator.py +230 -0
- traceplane-0.1.0/src/traceplane/sim/metrics.py +110 -0
- traceplane-0.1.0/src/traceplane/sim/robot.py +123 -0
- traceplane-0.1.0/src/traceplane/sim/scene.py +91 -0
- traceplane-0.1.0/src/traceplane/sim/visualizer.py +188 -0
- traceplane-0.1.0/src/traceplane/tf.py +86 -0
- traceplane-0.1.0/src/traceplane/torch.py +126 -0
- traceplane-0.1.0/src/traceplane/training/__init__.py +23 -0
- traceplane-0.1.0/src/traceplane/training/cli.py +92 -0
- traceplane-0.1.0/src/traceplane/training/config.py +82 -0
- traceplane-0.1.0/src/traceplane/training/diffusion_policy.py +314 -0
- traceplane-0.1.0/src/traceplane/training/eval.py +85 -0
- traceplane-0.1.0/src/traceplane/training/normalization.py +262 -0
- traceplane-0.1.0/src/traceplane/training/trainer.py +318 -0
- traceplane-0.1.0/src/traceplane/windowing.py +112 -0
- traceplane-0.1.0/tests/test_core.py +217 -0
- traceplane-0.1.0/tests/test_engine.py +91 -0
- traceplane-0.1.0/tests/test_integration.py +426 -0
- traceplane-0.1.0/tests/test_torch.py +107 -0
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# Rust
|
|
2
|
+
backend/target/
|
|
3
|
+
|
|
4
|
+
# Node / Frontend
|
|
5
|
+
frontend/node_modules/
|
|
6
|
+
frontend/dist/
|
|
7
|
+
|
|
8
|
+
# Python
|
|
9
|
+
dataloader/__pycache__/
|
|
10
|
+
dataloader/src/traceplane/__pycache__/
|
|
11
|
+
dataloader/tests/__pycache__/
|
|
12
|
+
dataloader/.pytest_cache/
|
|
13
|
+
dataloader/src/*.egg-info/
|
|
14
|
+
dataloader/src/traceplane.egg-info/
|
|
15
|
+
*.pyc
|
|
16
|
+
*.pyo
|
|
17
|
+
__pycache__/
|
|
18
|
+
*.egg-info/
|
|
19
|
+
|
|
20
|
+
# IDE
|
|
21
|
+
.idea/
|
|
22
|
+
.vscode/
|
|
23
|
+
*.swp
|
|
24
|
+
|
|
25
|
+
# OS
|
|
26
|
+
.DS_Store
|
|
27
|
+
|
|
28
|
+
# Env
|
|
29
|
+
.env
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: traceplane
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Streaming dataloader for robotics trajectory datasets
|
|
5
|
+
Project-URL: Homepage, https://traceplane.ai
|
|
6
|
+
Project-URL: Documentation, https://docs.traceplane.ai
|
|
7
|
+
Project-URL: Repository, https://github.com/traceplane/traceplane
|
|
8
|
+
Project-URL: Issues, https://github.com/traceplane/traceplane/issues
|
|
9
|
+
Author-email: Traceplane <hello@traceplane.ai>
|
|
10
|
+
License-Expression: Apache-2.0
|
|
11
|
+
Keywords: datasets,imitation-learning,lerobot,robotics,trajectories
|
|
12
|
+
Classifier: Development Status :: 4 - Beta
|
|
13
|
+
Classifier: Intended Audience :: Science/Research
|
|
14
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
|
15
|
+
Classifier: Programming Language :: Python :: 3
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
20
|
+
Classifier: Topic :: Scientific/Engineering
|
|
21
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
22
|
+
Requires-Python: >=3.10
|
|
23
|
+
Requires-Dist: fsspec>=2023.1
|
|
24
|
+
Requires-Dist: numpy>=1.24
|
|
25
|
+
Requires-Dist: pyarrow>=14.0
|
|
26
|
+
Requires-Dist: requests>=2.28
|
|
27
|
+
Provides-Extra: all
|
|
28
|
+
Requires-Dist: gcsfs>=2023.1; extra == 'all'
|
|
29
|
+
Requires-Dist: jax>=0.4; extra == 'all'
|
|
30
|
+
Requires-Dist: jaxlib>=0.4; extra == 'all'
|
|
31
|
+
Requires-Dist: numpy>=1.24; extra == 'all'
|
|
32
|
+
Requires-Dist: s3fs>=2023.1; extra == 'all'
|
|
33
|
+
Requires-Dist: sentence-transformers>=2.0; extra == 'all'
|
|
34
|
+
Requires-Dist: tensorflow>=2.14; extra == 'all'
|
|
35
|
+
Requires-Dist: torch>=2.0; extra == 'all'
|
|
36
|
+
Provides-Extra: dev
|
|
37
|
+
Requires-Dist: pytest>=7; extra == 'dev'
|
|
38
|
+
Requires-Dist: torch>=2.0; extra == 'dev'
|
|
39
|
+
Provides-Extra: embeddings
|
|
40
|
+
Requires-Dist: sentence-transformers>=2.0; extra == 'embeddings'
|
|
41
|
+
Provides-Extra: gcs
|
|
42
|
+
Requires-Dist: gcsfs>=2023.1; extra == 'gcs'
|
|
43
|
+
Provides-Extra: jax
|
|
44
|
+
Requires-Dist: jax>=0.4; extra == 'jax'
|
|
45
|
+
Requires-Dist: jaxlib>=0.4; extra == 'jax'
|
|
46
|
+
Provides-Extra: s3
|
|
47
|
+
Requires-Dist: s3fs>=2023.1; extra == 's3'
|
|
48
|
+
Provides-Extra: sim
|
|
49
|
+
Requires-Dist: numpy>=1.24; extra == 'sim'
|
|
50
|
+
Requires-Dist: torch>=2.0; extra == 'sim'
|
|
51
|
+
Provides-Extra: tf
|
|
52
|
+
Requires-Dist: tensorflow>=2.14; extra == 'tf'
|
|
53
|
+
Provides-Extra: torch
|
|
54
|
+
Requires-Dist: torch>=2.0; extra == 'torch'
|
|
55
|
+
Provides-Extra: training
|
|
56
|
+
Requires-Dist: torch>=2.0; extra == 'training'
|
|
57
|
+
Description-Content-Type: text/markdown
|
|
58
|
+
|
|
59
|
+
# Traceplane
|
|
60
|
+
|
|
61
|
+
Python SDK for the Traceplane trajectory data platform.
|
|
62
|
+
|
|
63
|
+
## Installation
|
|
64
|
+
|
|
65
|
+
```bash
|
|
66
|
+
pip install traceplane
|
|
67
|
+
```
|
|
68
|
+
|
|
69
|
+
With framework extras:
|
|
70
|
+
|
|
71
|
+
```bash
|
|
72
|
+
pip install traceplane[torch] # PyTorch DataLoader
|
|
73
|
+
pip install traceplane[jax] # JAX support
|
|
74
|
+
pip install traceplane[training] # Diffusion policy training
|
|
75
|
+
pip install traceplane[all] # Everything
|
|
76
|
+
```
|
|
77
|
+
|
|
78
|
+
## Quick Start
|
|
79
|
+
|
|
80
|
+
```python
|
|
81
|
+
from traceplane import TraceplaneClient
|
|
82
|
+
|
|
83
|
+
client = TraceplaneClient("https://api.traceplane.ai", api_key="tp_live_...")
|
|
84
|
+
|
|
85
|
+
# Register a dataset
|
|
86
|
+
client.register("my_data", "/path/to/dataset", include_data=True)
|
|
87
|
+
|
|
88
|
+
# Query with SQL
|
|
89
|
+
rows = client.sql_rows("SELECT * FROM my_data WHERE frame_count > 100")
|
|
90
|
+
|
|
91
|
+
# Upload data
|
|
92
|
+
client.upload_dataset("my_data", "/path/to/parquet/files/")
|
|
93
|
+
|
|
94
|
+
# Vector search
|
|
95
|
+
results = client.search_similar("my_data", episode_index=0, k=5)
|
|
96
|
+
```
|
|
97
|
+
|
|
98
|
+
## Features
|
|
99
|
+
|
|
100
|
+
- **SQL query engine** -- register datasets and query with full SQL, including vector UDFs (`vec_mean`, `vec_norm`, `vec_cosine_sim`, etc.)
|
|
101
|
+
- **Streaming dataloaders** -- PyTorch, JAX, and TensorFlow adapters with windowed sampling
|
|
102
|
+
- **LeRobot format** -- native reader for LeRobot v2/v3 datasets (Parquet + MP4)
|
|
103
|
+
- **Similarity search** -- find related episodes via embedding-based vector search
|
|
104
|
+
- **Dataset upload** -- push local Parquet files to the platform
|
|
105
|
+
- **Retargeting** -- XR hand poses to robot action space via calibration bridge
|
|
106
|
+
- **Training** -- built-in diffusion policy training with `traceplane-train` CLI
|
|
107
|
+
|
|
108
|
+
## Training Integration
|
|
109
|
+
|
|
110
|
+
```python
|
|
111
|
+
from traceplane import LeRobotReader
|
|
112
|
+
from traceplane.torch import TorchEpisodeLoader
|
|
113
|
+
|
|
114
|
+
reader = LeRobotReader("/path/to/lerobot/dataset")
|
|
115
|
+
loader = TorchEpisodeLoader(reader, batch_size=32, window_size=16)
|
|
116
|
+
|
|
117
|
+
for batch in loader:
|
|
118
|
+
observations = batch["observation"]
|
|
119
|
+
actions = batch["action"]
|
|
120
|
+
# ... your training loop
|
|
121
|
+
```
|
|
122
|
+
|
|
123
|
+
## API Reference
|
|
124
|
+
|
|
125
|
+
Full documentation: [docs.traceplane.ai](https://docs.traceplane.ai)
|
|
126
|
+
|
|
127
|
+
## License
|
|
128
|
+
|
|
129
|
+
Apache-2.0
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
# Traceplane
|
|
2
|
+
|
|
3
|
+
Python SDK for the Traceplane trajectory data platform.
|
|
4
|
+
|
|
5
|
+
## Installation
|
|
6
|
+
|
|
7
|
+
```bash
|
|
8
|
+
pip install traceplane
|
|
9
|
+
```
|
|
10
|
+
|
|
11
|
+
With framework extras:
|
|
12
|
+
|
|
13
|
+
```bash
|
|
14
|
+
pip install "traceplane[torch]" # PyTorch DataLoader
|
|
15
|
+
pip install "traceplane[jax]" # JAX support
|
|
16
|
+
pip install "traceplane[training]" # Diffusion policy training
|
|
17
|
+
pip install "traceplane[all]" # Everything (quotes keep zsh from globbing the brackets)
|
|
18
|
+
```
|
|
19
|
+
|
|
20
|
+
## Quick Start
|
|
21
|
+
|
|
22
|
+
```python
|
|
23
|
+
from traceplane import TraceplaneClient
|
|
24
|
+
|
|
25
|
+
client = TraceplaneClient("https://api.traceplane.ai", api_key="tp_live_...")
|
|
26
|
+
|
|
27
|
+
# Register a dataset
|
|
28
|
+
client.register("my_data", "/path/to/dataset", include_data=True)
|
|
29
|
+
|
|
30
|
+
# Query with SQL
|
|
31
|
+
rows = client.sql_rows("SELECT * FROM my_data WHERE frame_count > 100")
|
|
32
|
+
|
|
33
|
+
# Upload data
|
|
34
|
+
client.upload_dataset("my_data", "/path/to/parquet/files/")
|
|
35
|
+
|
|
36
|
+
# Vector search
|
|
37
|
+
results = client.search_similar("my_data", episode_index=0, k=5)
|
|
38
|
+
```
|
|
39
|
+
|
|
40
|
+
## Features
|
|
41
|
+
|
|
42
|
+
- **SQL query engine** -- register datasets and query with full SQL, including vector UDFs (`vec_mean`, `vec_norm`, `vec_cosine_sim`, etc.)
|
|
43
|
+
- **Streaming dataloaders** -- PyTorch, JAX, and TensorFlow adapters with windowed sampling
|
|
44
|
+
- **LeRobot format** -- native reader for LeRobot v2/v3 datasets (Parquet + MP4)
|
|
45
|
+
- **Similarity search** -- find related episodes via embedding-based vector search
|
|
46
|
+
- **Dataset upload** -- push local Parquet files to the platform
|
|
47
|
+
- **Retargeting** -- XR hand poses to robot action space via calibration bridge
|
|
48
|
+
- **Training** -- built-in diffusion policy training with `traceplane-train` CLI
|
|
49
|
+
|
|
50
|
+
## Training Integration
|
|
51
|
+
|
|
52
|
+
```python
|
|
53
|
+
from traceplane import LeRobotReader
|
|
54
|
+
from traceplane.torch import TorchEpisodeLoader
|
|
55
|
+
|
|
56
|
+
reader = LeRobotReader("/path/to/lerobot/dataset")
|
|
57
|
+
loader = TorchEpisodeLoader(reader, batch_size=32, window_size=16)
|
|
58
|
+
|
|
59
|
+
for batch in loader:
|
|
60
|
+
observations = batch["observation"]
|
|
61
|
+
actions = batch["action"]
|
|
62
|
+
# ... your training loop
|
|
63
|
+
```
|
|
64
|
+
|
|
65
|
+
## API Reference
|
|
66
|
+
|
|
67
|
+
Full documentation: [docs.traceplane.ai](https://docs.traceplane.ai)
|
|
68
|
+
|
|
69
|
+
## License
|
|
70
|
+
|
|
71
|
+
Apache-2.0
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "traceplane"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "Streaming dataloader for robotics trajectory datasets"
|
|
9
|
+
requires-python = ">=3.10"
|
|
10
|
+
license = "Apache-2.0"
|
|
11
|
+
authors = [{name = "Traceplane", email = "hello@traceplane.ai"}]
|
|
12
|
+
readme = "README.md"
|
|
13
|
+
keywords = ["robotics", "trajectories", "datasets", "imitation-learning", "lerobot"]
|
|
14
|
+
classifiers = [
|
|
15
|
+
"Development Status :: 4 - Beta",
|
|
16
|
+
"Intended Audience :: Science/Research",
|
|
17
|
+
"License :: OSI Approved :: Apache Software License",
|
|
18
|
+
"Programming Language :: Python :: 3",
|
|
19
|
+
"Programming Language :: Python :: 3.10",
|
|
20
|
+
"Programming Language :: Python :: 3.11",
|
|
21
|
+
"Programming Language :: Python :: 3.12",
|
|
22
|
+
"Programming Language :: Python :: 3.13",
|
|
23
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
24
|
+
"Topic :: Scientific/Engineering",
|
|
25
|
+
]
|
|
26
|
+
dependencies = [
|
|
27
|
+
"numpy>=1.24",
|
|
28
|
+
"pyarrow>=14.0",
|
|
29
|
+
"fsspec>=2023.1",
|
|
30
|
+
"requests>=2.28",
|
|
31
|
+
]
|
|
32
|
+
|
|
33
|
+
[project.optional-dependencies]
|
|
34
|
+
torch = ["torch>=2.0"]
|
|
35
|
+
jax = ["jax>=0.4", "jaxlib>=0.4"]
|
|
36
|
+
tf = ["tensorflow>=2.14"]
|
|
37
|
+
s3 = ["s3fs>=2023.1"]
|
|
38
|
+
gcs = ["gcsfs>=2023.1"]
|
|
39
|
+
training = ["torch>=2.0"]
|
|
40
|
+
embeddings = ["sentence-transformers>=2.0"]
|
|
41
|
+
sim = ["torch>=2.0", "numpy>=1.24"]
|
|
42
|
+
all = ["traceplane[torch,jax,tf,s3,gcs,training,embeddings,sim]"]
|
|
43
|
+
dev = ["pytest>=7", "traceplane[torch]"]
|
|
44
|
+
|
|
45
|
+
[project.scripts]
|
|
46
|
+
traceplane-train = "traceplane.training.cli:main"
|
|
47
|
+
traceplane-embed = "traceplane.embeddings:main"
|
|
48
|
+
traceplane-sim-viz = "traceplane.sim.cli:main_viz"
|
|
49
|
+
traceplane-sim-eval = "traceplane.sim.cli:main_eval"
|
|
50
|
+
|
|
51
|
+
[project.urls]
|
|
52
|
+
Homepage = "https://traceplane.ai"
|
|
53
|
+
Documentation = "https://docs.traceplane.ai"
|
|
54
|
+
Repository = "https://github.com/traceplane/traceplane"
|
|
55
|
+
Issues = "https://github.com/traceplane/traceplane/issues"
|
|
56
|
+
|
|
57
|
+
[tool.hatch.build.targets.wheel]
|
|
58
|
+
packages = ["src/traceplane"]
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
#!/usr/bin/env bash
# Traceplane Isaac Sim Setup — Ubuntu 22.04/24.04 + NVIDIA RTX 5080
#
# Prerequisites:
#   - Ubuntu 22.04 or 24.04
#   - NVIDIA driver >= 565 (required for RTX 5080 / Blackwell)
#   - Python 3.10+
#
# Usage:
#   chmod +x scripts/setup_isaac_sim.sh
#   ./scripts/setup_isaac_sim.sh
#
# Environment:
#   PYTHON  Interpreter to install into (default: python3). It is expanded
#           unquoted on purpose so a multi-word command still works.

set -euo pipefail

echo "=== Traceplane Isaac Sim Setup ==="
echo ""

# Check OS (warning only)
if ! grep -qE "22\.04|24\.04" /etc/lsb-release 2>/dev/null; then
  echo "WARNING: This script targets Ubuntu 22.04/24.04."
  echo "Current OS: $(lsb_release -ds 2>/dev/null || echo 'unknown')"
  echo ""
fi

# Check NVIDIA driver
if ! command -v nvidia-smi &>/dev/null; then
  echo "ERROR: nvidia-smi not found. Install NVIDIA driver first:"
  echo " sudo apt update && sudo apt install nvidia-driver-565"
  exit 1
fi

# nvidia-smi can be present yet fail (no visible GPU, driver/library
# mismatch). Report that explicitly instead of letting `set -e` abort
# without context. Querying GPU 0 directly (-i 0) avoids a `| head -1`
# pipe, which under `pipefail` could fail on SIGPIPE.
if ! DRIVER_VERSION=$(nvidia-smi -i 0 --query-gpu=driver_version --format=csv,noheader); then
  echo "ERROR: nvidia-smi failed to query the GPU. Check the NVIDIA driver installation."
  exit 1
fi
echo "NVIDIA driver: $DRIVER_VERSION"

DRIVER_MAJOR="${DRIVER_VERSION%%.*}"
if [[ ! "$DRIVER_MAJOR" =~ ^[0-9]+$ ]]; then
  echo "WARNING: Could not parse NVIDIA driver version '$DRIVER_VERSION'."
  echo ""
elif [ "$DRIVER_MAJOR" -lt 565 ]; then
  echo "WARNING: Driver $DRIVER_VERSION may be too old for RTX 5080."
  echo "Recommended: >= 565. Install with:"
  echo " sudo apt install nvidia-driver-565"
  echo ""
fi

# Check GPU
GPU_NAME=$(nvidia-smi -i 0 --query-gpu=name --format=csv,noheader)
echo "GPU: $GPU_NAME"
echo ""

# Check Python (enforce the 3.10+ prerequisite stated above)
PYTHON=${PYTHON:-python3}
PY_VERSION=$($PYTHON --version 2>&1)
echo "Python: $PY_VERSION"
if ! $PYTHON -c 'import sys; sys.exit(0 if sys.version_info >= (3, 10) else 1)'; then
  echo "ERROR: Python 3.10+ is required (found: $PY_VERSION)."
  exit 1
fi
echo ""

# Install Isaac Sim
echo "=== Installing Isaac Sim 5.x ==="
echo "This may take several minutes..."
# Quote the requirement: an unquoted `5.*` is a shell glob and could expand
# against files in the current directory.
$PYTHON -m pip install "isaacsim==5.*" --extra-index-url https://pypi.nvidia.com

# Install Traceplane sim module (package root is the parent of scripts/)
echo ""
echo "=== Installing Traceplane sim module ==="
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
DATALOADER_DIR="$(dirname "$SCRIPT_DIR")"
$PYTHON -m pip install -e "${DATALOADER_DIR}[sim]"

# Smoke test: boot a headless SimulationApp and shut it down again
echo ""
echo "=== Smoke test ==="
$PYTHON -c "
from isaacsim import SimulationApp
app = SimulationApp({'headless': True})
print('Isaac Sim loaded successfully')
app.close()
print('Smoke test passed!')
"

echo ""
echo "=== Setup complete ==="
echo ""
echo "Quick start:"
echo " # Replay a dataset episode in sim"
echo " traceplane-sim-viz --dataset-path /path/to/dataset --episode-index 0"
echo ""
echo " # Evaluate a policy"
echo " traceplane-sim-eval /path/to/checkpoint.pt --num-episodes 10 --headless"
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""Traceplane — streaming dataloader for robotics trajectory datasets."""
|
|
2
|
+
|
|
3
|
+
from traceplane.dataset import Episode, TrajectoryDataset
|
|
4
|
+
from traceplane.lerobot_reader import LeRobotReader
|
|
5
|
+
from traceplane.windowing import WindowedDataset
|
|
6
|
+
from traceplane.query import TraceplaneClient
|
|
7
|
+
|
|
8
|
+
# Package version string; mirrors ``version`` in pyproject.toml.
__version__ = "0.1.0"

# Names exported by ``from traceplane import *``.
__all__ = [
    "Episode", "TrajectoryDataset",
    "LeRobotReader", "WindowedDataset",
    "TraceplaneClient",
]
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
"""Core dataset abstractions."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Any, Iterator, Sequence
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class Episode:
    """A single trajectory episode.

    Attributes:
        episode_id: Unique identifier (e.g. "episode_000042").
        observations: Dict of observation arrays keyed by modality.
            Common keys: "state" (proprioception), "action", image camera names.
            Each value is shape (T, D) for vectors or (T, H, W, C) for images.
        actions: Action array, shape (T, action_dim).
        timestamps: Monotonic timestamps in seconds, shape (T,).
        metadata: Arbitrary episode metadata (task label, fps, success, etc.).
    """

    episode_id: str
    observations: dict[str, np.ndarray] = field(default_factory=dict)
    actions: np.ndarray = field(default_factory=lambda: np.empty((0,)))
    timestamps: np.ndarray = field(default_factory=lambda: np.empty((0,)))
    metadata: dict[str, Any] = field(default_factory=dict)

    @property
    def length(self) -> int:
        """Number of timesteps.

        Taken from the first source with a non-empty leading axis, in order:
        ``actions``, ``timestamps``, then each value of ``observations``.
        Returns 0 when none of them has any timesteps.
        """
        for arr in (self.actions, self.timestamps, *self.observations.values()):
            # 0-d arrays (shape ``()``) and values without ``shape`` have no
            # time axis; skip them rather than raising IndexError on shape[0].
            shape = getattr(arr, "shape", None)
            if shape and shape[0] > 0:
                return int(shape[0])
        return 0

    @property
    def action_dim(self) -> int:
        """Width of the action vector; 0 unless ``actions`` is 2-D."""
        if self.actions.ndim == 2:
            return self.actions.shape[1]
        return 0
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class TrajectoryDataset:
    """Abstract base for trajectory datasets.

    Subclasses must implement ``__len__`` and ``__getitem__``; ``filter``
    additionally requires ``episode_ids``.
    """

    def __len__(self) -> int:
        raise NotImplementedError

    def __getitem__(self, idx: int) -> Episode:
        raise NotImplementedError

    def __iter__(self) -> Iterator[Episode]:
        for i in range(len(self)):
            yield self[i]

    def episode_ids(self) -> list[str]:
        """Return all episode IDs in order."""
        raise NotImplementedError

    def filter(self, episode_ids: Sequence[str]) -> "FilteredDataset":
        """Return a view containing only the specified episodes.

        IDs not present in this dataset are dropped; the requested order
        (including any duplicates) is kept.
        """
        return FilteredDataset(self, list(episode_ids))


class FilteredDataset(TrajectoryDataset):
    """A filtered view over another dataset.

    Position ``i`` of the view is the parent episode whose ID is the
    ``i``-th requested ID that the parent actually contains.
    """

    def __init__(self, parent: TrajectoryDataset, episode_ids: list[str]):
        self._parent = parent
        # Build index map: episode_id -> parent index
        self._id_to_idx = {eid: i for i, eid in enumerate(parent.episode_ids())}
        # Keep only IDs the parent has so ``_ids`` and ``_indices`` stay
        # aligned position-by-position (previously unknown IDs shifted the
        # result of ``episode_ids``).
        self._ids = [eid for eid in episode_ids if eid in self._id_to_idx]
        self._indices = [self._id_to_idx[eid] for eid in self._ids]

    def __len__(self) -> int:
        return len(self._indices)

    def __getitem__(self, idx: int) -> Episode:
        return self._parent[self._indices[idx]]

    def episode_ids(self) -> list[str]:
        """Return the IDs of the episodes in this view, in view order."""
        return list(self._ids)
|
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
"""Episode embedding computation for similarity search.
|
|
2
|
+
|
|
3
|
+
Computes per-episode feature vectors and writes them to Parquet for
|
|
4
|
+
use with the DataFusion ``vec_cosine_sim`` UDF.
|
|
5
|
+
|
|
6
|
+
Two strategies:
|
|
7
|
+
- **Trajectory features** (always available): statistical aggregates
|
|
8
|
+
of observation.state and action vectors per episode.
|
|
9
|
+
- **Text embeddings** (optional, requires ``sentence-transformers``):
|
|
10
|
+
encodes task_label strings with a pretrained language model.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import argparse
|
|
16
|
+
import os
|
|
17
|
+
import sys
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
import numpy as np
|
|
22
|
+
import pyarrow as pa
|
|
23
|
+
import pyarrow.parquet as pq
|
|
24
|
+
|
|
25
|
+
from traceplane.dataset import Episode
|
|
26
|
+
from traceplane.lerobot_reader import LeRobotReader
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def compute_trajectory_embedding(episode: Episode) -> np.ndarray:
|
|
30
|
+
"""Compute a feature vector from an episode's state and action statistics.
|
|
31
|
+
|
|
32
|
+
Concatenates [mean, std, min, max] per dimension for both
|
|
33
|
+
observation.state and action, then L2-normalises.
|
|
34
|
+
|
|
35
|
+
Returns:
|
|
36
|
+
1-D float32 array.
|
|
37
|
+
"""
|
|
38
|
+
parts: list[np.ndarray] = []
|
|
39
|
+
|
|
40
|
+
# State features
|
|
41
|
+
state = episode.observations.get("state")
|
|
42
|
+
if state is not None and state.ndim == 2 and state.shape[0] > 0:
|
|
43
|
+
parts.extend(_stat_features(state))
|
|
44
|
+
|
|
45
|
+
# Action features
|
|
46
|
+
if episode.actions.ndim == 2 and episode.actions.shape[0] > 0:
|
|
47
|
+
parts.extend(_stat_features(episode.actions))
|
|
48
|
+
|
|
49
|
+
if not parts:
|
|
50
|
+
# Fallback: zero vector
|
|
51
|
+
return np.zeros(1, dtype=np.float32)
|
|
52
|
+
|
|
53
|
+
vec = np.concatenate(parts).astype(np.float32)
|
|
54
|
+
|
|
55
|
+
# L2 normalise
|
|
56
|
+
norm = np.linalg.norm(vec)
|
|
57
|
+
if norm > 1e-8:
|
|
58
|
+
vec /= norm
|
|
59
|
+
|
|
60
|
+
return vec
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _stat_features(arr: np.ndarray) -> list[np.ndarray]:
|
|
64
|
+
"""Compute [mean, std, min, max] per dimension."""
|
|
65
|
+
return [
|
|
66
|
+
arr.mean(axis=0),
|
|
67
|
+
arr.std(axis=0),
|
|
68
|
+
arr.min(axis=0),
|
|
69
|
+
arr.max(axis=0),
|
|
70
|
+
]
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def compute_text_embeddings(
    labels: list[str],
    model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
) -> np.ndarray:
    """Encode text labels with a sentence-transformer model.

    The model named by ``model_name`` is constructed on every call, and
    embeddings are requested with ``normalize_embeddings=True``.

    Args:
        labels: Strings to encode, one embedding row per label.
        model_name: Model identifier passed to ``SentenceTransformer``.

    Returns:
        (N, D) float32 array of embeddings.

    Raises:
        ImportError: If ``sentence-transformers`` is not installed; the
            original import error is attached as ``__cause__``.
    """
    try:
        from sentence_transformers import SentenceTransformer
    except ImportError as exc:
        raise ImportError(
            "sentence-transformers is required for text embeddings. "
            "Install with: pip install traceplane[embeddings]"
        ) from exc
    model = SentenceTransformer(model_name)
    embeddings = model.encode(labels, show_progress_bar=False, normalize_embeddings=True)
    return np.asarray(embeddings, dtype=np.float32)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _parse_trailing_int(text: str) -> int | None:
    """Return the integer after the last ``_`` in *text* (or *text* itself), else None."""
    try:
        return int(text.rsplit("_", 1)[-1])
    except ValueError:
        return None


def _episode_index(episode: Episode, fallback: int) -> int:
    """Resolve the dataset-level episode index of *episode*.

    Order of precedence:
      1. ``metadata["episode_index"]``: used directly when numeric, or
         parsed from its trailing ``_<int>`` when it is a string;
      2. the trailing integer of an ``episode_<int>`` style ``episode_id``;
      3. *fallback* (the episode's position within the reader).

    When ``episode_indices`` selects a subset, the reader position presumably
    no longer equals the episode index, so it is only the last resort.
    """
    raw = episode.metadata.get("episode_index")
    if isinstance(raw, str):
        parsed = _parse_trailing_int(raw)
        if parsed is not None:
            return parsed
    elif raw is not None:
        return int(raw)

    if episode.episode_id.startswith("episode_"):
        parsed = _parse_trailing_int(episode.episode_id)
        if parsed is not None:
            return parsed
    return fallback


def compute_embeddings(
    dataset_path: str,
    output_dir: str | None = None,
    include_text: bool = False,
    episode_indices: list[int] | None = None,
    storage_options: dict[str, Any] | None = None,
) -> str:
    """Compute episode embeddings and write to Parquet.

    The output table has columns ``episode_index`` (int64), ``task_label``
    (utf8), ``trajectory_embedding`` (list<float32>) and, when
    ``include_text`` is set, ``text_embedding`` (list<float32>).

    Args:
        dataset_path: Path to a LeRobot dataset.
        output_dir: Output directory. Defaults to ``{dataset_path}/embeddings``.
        include_text: Also compute text embeddings from task labels.
        episode_indices: Subset of episodes. None = all.
        storage_options: fsspec options for remote datasets.

    Returns:
        Path to the written Parquet file.

    Raises:
        ValueError: If the dataset has no episodes.
    """
    reader = LeRobotReader(
        dataset_path,
        storage_options=storage_options,
        episode_indices=episode_indices,
    )

    n = len(reader)
    if n == 0:
        raise ValueError("Dataset has no episodes")

    print(f"Computing embeddings for {n} episodes...", file=sys.stderr)

    indices: list[int] = []
    labels: list[str] = []
    traj_embeddings: list[np.ndarray] = []

    for i in range(n):
        ep = reader[i]
        traj_embeddings.append(compute_trajectory_embedding(ep))
        indices.append(_episode_index(ep, fallback=i))
        labels.append(ep.metadata.get("task_label", ""))

        if (i + 1) % 50 == 0 or i == n - 1:
            print(f" {i + 1}/{n}", file=sys.stderr)

    traj_dim = traj_embeddings[0].shape[0]
    dims = {emb.shape[0] for emb in traj_embeddings}
    if len(dims) > 1:
        # compute_trajectory_embedding returns a 1-dim zero vector for
        # episodes without usable state/action data, so lengths can differ;
        # cosine similarity between such rows is not meaningful.
        print(
            f"WARNING: trajectory embeddings have mixed dimensions {sorted(dims)}",
            file=sys.stderr,
        )

    columns: dict[str, Any] = {
        "episode_index": pa.array(indices, type=pa.int64()),
        "task_label": pa.array(labels, type=pa.utf8()),
        "trajectory_embedding": pa.array(
            [emb.tolist() for emb in traj_embeddings],
            type=pa.list_(pa.float32()),
        ),
    }

    # Optional text embeddings, computed once per distinct label
    if include_text:
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # batch sent to the model is identical on every run (set() order is not).
        unique_labels = list(dict.fromkeys(labels))
        print(f"Computing text embeddings for {len(unique_labels)} unique labels...", file=sys.stderr)
        text_embs = compute_text_embeddings(unique_labels)
        label_to_emb = dict(zip(unique_labels, text_embs))
        columns["text_embedding"] = pa.array(
            [label_to_emb[lbl].tolist() for lbl in labels],
            type=pa.list_(pa.float32()),
        )

    table = pa.table(columns)

    # NOTE(review): os.path / os.makedirs operate on the local filesystem only.
    # For a remote dataset_path (e.g. "s3://...") pass a local output_dir.
    if output_dir is None:
        output_dir = os.path.join(dataset_path, "embeddings")
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "episode_embeddings.parquet")
    pq.write_table(table, output_path)

    print(f"Written {n} embeddings ({traj_dim}-dim) to {output_path}", file=sys.stderr)
    return output_path
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def main(argv: list[str] | None = None) -> None:
|
|
186
|
+
"""CLI entry point for embedding computation."""
|
|
187
|
+
parser = argparse.ArgumentParser(
|
|
188
|
+
description="Compute episode embeddings for similarity search",
|
|
189
|
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
|
190
|
+
)
|
|
191
|
+
parser.add_argument("dataset_path", help="Path to LeRobot dataset")
|
|
192
|
+
parser.add_argument("--output-dir", help="Output directory (default: {dataset_path}/embeddings)")
|
|
193
|
+
parser.add_argument("--include-text", action="store_true", help="Compute text embeddings (requires sentence-transformers)")
|
|
194
|
+
parser.add_argument("--episodes", type=int, nargs="+", help="Specific episode indices")
|
|
195
|
+
|
|
196
|
+
args = parser.parse_args(argv)
|
|
197
|
+
path = compute_embeddings(
|
|
198
|
+
args.dataset_path,
|
|
199
|
+
output_dir=args.output_dir,
|
|
200
|
+
include_text=args.include_text,
|
|
201
|
+
episode_indices=args.episodes,
|
|
202
|
+
)
|
|
203
|
+
print(path)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
if __name__ == "__main__":
|
|
207
|
+
main()
|