tairos-data-convert 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_pipeline/data_process/data_convert/data_convert +609 -0
- data_pipeline/data_process/data_convert/data_convert.py +142 -0
- data_pipeline/data_process/utils/data_check.py +40 -0
- data_pipeline/data_process/utils/data_load.py +144 -0
- data_pipeline/data_process/utils/data_read.py +90 -0
- data_pipeline/data_process/utils/message_convert.py +41 -0
- data_pipeline/data_process/utils/output_dataset.py +56 -0
- data_pipeline/data_process/utils/topic_mapping.py +137 -0
- data_pipeline/py_api/api_utils.py +406 -0
- data_pipeline/py_api/raw_api.py +519 -0
- data_pipeline/py_api/utils.py +11 -0
- tairos_data_convert-1.0.3.dist-info/METADATA +7 -0
- tairos_data_convert-1.0.3.dist-info/RECORD +14 -0
- tairos_data_convert-1.0.3.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
import concurrent
|
|
2
|
+
import os
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
import yaml
|
|
6
|
+
from loguru import logger
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from data_pipeline.data_process.utils.data_load import (
|
|
10
|
+
check_config_for_loading_dataset,
|
|
11
|
+
check_config_for_loading_data,
|
|
12
|
+
load_dataset,
|
|
13
|
+
load_data,
|
|
14
|
+
)
|
|
15
|
+
from data_pipeline.data_process.utils.data_read import BagReader
|
|
16
|
+
from data_pipeline.data_process.utils.data_check import populate_dof_from_data
|
|
17
|
+
from data_pipeline.data_process.utils.output_dataset import Hdf5Dataset
|
|
18
|
+
from data_pipeline.data_process.utils.topic_mapping import get_data_for_value
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def check_config(config: dict):
    """Validate the top-level conversion config and create the output dir.

    Raises:
        ValueError: when any required key is missing or empty.
    """
    # Simple required scalars first.
    for required_key in ("dataset_id", "fps"):
        if not config.get(required_key):
            raise ValueError(f"配置中未设置{required_key}")

    output_dataset_dir = config.get("output_dataset_dir")
    if not output_dataset_dir:
        raise ValueError("配置中未设置output_dataset_dir")
    os.makedirs(output_dataset_dir, exist_ok=True)

    if not config.get("frame_structure"):
        raise ValueError("配置中未设置frame_structure")

    if not config.get("convert_num_workers"):
        raise ValueError("配置中未设置convert_num_workers")

    data_load_cfg = config.get("data_load")
    if not data_load_cfg:
        raise ValueError("配置中未设置data_load")
    # Delegate validation of the data_load sub-section.
    check_config_for_loading_data(data_load_cfg)
    check_config_for_loading_dataset(data_load_cfg)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def get_frame_data(frame_structure_values_set: set, frame_structure: dict, topic_to_msg: dict, data_source: int):
    """Build one output frame from the latest messages.

    Every referenced value is resolved once, then the per-key arrays are
    formed by concatenating that key's value vectors in order.

    Raises:
        KeyError/ValueError: propagated from get_data_for_value when a value
            cannot be resolved yet.
    """
    resolved = {
        value: get_data_for_value(topic_to_msg, value, data_source)
        for value in frame_structure_values_set
    }
    return {
        key: np.concatenate([resolved[v] for v in spec["values"]], axis=0)
        for key, spec in frame_structure.items()
    }
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def convert(metadata, config: dict):
    """Convert one raw recording (rosbag/mcap) into an HDF5 dataset.

    Args:
        metadata: one dataset item as returned by load_dataset().
        config: the full conversion config (see check_config()).

    Returns:
        dict with a "success" flag and, on failure, an "error_message".
    """
    data_path = load_data(metadata, config.get("data_load"))

    data_source = metadata.get("metadata", {}).get("data_source", 0)
    # Copy one level deeper than dict.copy(): populate_dof_from_data() writes
    # the "dof" field into the per-key specs, and a shallow copy would leak
    # that mutation back into the shared config across conversions.
    frame_structure = {key: dict(spec) for key, spec in config.get("frame_structure").items()}
    try:
        populate_dof_from_data(frame_structure, data_path, data_source)
    except (KeyError, ValueError) as e:
        return {
            "success": False,
            "error_message": str(e),
        }

    # A float32 key that still has dof == 0 means probing found no data.
    error_keys = [
        key for key, spec in frame_structure.items()
        if spec["type"] == "float32" and spec["dof"] == 0
    ]
    if error_keys:
        return {
            "success": False,
            "error_message": f"frame_structure中存在dof为0的key: {error_keys}",
        }

    frame_structure_values_set = {
        v for spec in frame_structure.values() for v in spec["values"]
    }

    output_dataset_path = os.path.join(
        config.get("output_dataset_dir"),
        os.path.basename(data_path).rsplit('.', 1)[0] + ".hdf5"
    )

    with Hdf5Dataset(output_dataset_path, frame_structure) as output_dataset:
        with BagReader(data_path) as bag_reader:
            start_time = bag_reader.get_start_time()
            end_time = bag_reader.get_end_time()
            # Uniform sampling grid at the requested output fps.
            sample_timestamps = np.arange(start_time, end_time, 1 / config.get("fps"))

            sample_index = 1
            topic_to_msg = {}  # latest message seen per topic
            for topic, msg, t in bag_reader.read_messages():
                if t > sample_timestamps[sample_index]:
                    try:
                        frame_data = get_frame_data(
                            frame_structure_values_set,
                            frame_structure,
                            topic_to_msg,
                            data_source
                        )
                        output_dataset.add_frame(frame_data)
                    except (KeyError, ValueError) as e:
                        # Missing topics are tolerated during the first second
                        # (warm-up); afterwards they abort the conversion.
                        # Bug fix: format spec was ":2f" (min width 2), the
                        # intended two-decimal form is ":.2f".
                        if t - start_time > 1:
                            return {
                                "success": False,
                                "error_message": f"failed at relative time {t - start_time:.2f}s: {str(e)}",
                            }

                    sample_index += 1
                    if sample_index >= len(sample_timestamps):
                        break

                topic_to_msg[topic] = msg

    return {
        "success": True,
    }
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def convert_data(config_path: str):
    """Entry point: read the YAML config and convert the dataset in parallel.

    Args:
        config_path: path to the YAML conversion config.
    """
    # Bug fix: the module only does `import concurrent`, which does NOT make
    # the `concurrent.futures` submodule importable; import it explicitly.
    import concurrent.futures

    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    check_config(config)

    metadatas = load_dataset(config.get("dataset_id"), config.get("data_load"))
    # convert data in parallel: one worker process per recording (CPU-bound).
    with concurrent.futures.ProcessPoolExecutor(max_workers=config.get("convert_num_workers")) as executor:
        futures = {executor.submit(convert, metadata, config): metadata for metadata in metadatas}
        for future in concurrent.futures.as_completed(futures):
            result = future.result()
            logger.info(f"converted {futures[future]['metadata']['name']}, result: {result}")
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
# Script entry point: use the config.yaml that sits next to this file.
if __name__ == "__main__":
    convert_data(Path(__file__).parent / "config.yaml")
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from data_pipeline.data_process.utils.data_read import BagReader
|
|
2
|
+
from data_pipeline.data_process.utils.topic_mapping import get_data_for_value
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def populate_dof(topic_to_msg, frame_structure, data_source):
    """Fill in the "dof" field of every float32 entry of frame_structure.

    Every referenced value is resolved (raising KeyError/ValueError when a
    required topic has not been seen yet); a float32 entry's dof is the sum
    of the lengths of its resolved value vectors.
    """
    for key, spec in frame_structure.items():
        is_float = spec["type"] == "float32"
        total = 0
        for value in spec["values"]:
            # Resolve even non-float values: a failed lookup must raise so
            # the caller knows the topics are not all available yet.
            data = get_data_for_value(topic_to_msg, value, data_source)
            if is_float:
                total += data.shape[0]
        if is_float:
            frame_structure[key]["dof"] = total
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def populate_dof_from_data(frame_structure, data_path, data_source):
    """Determine the "dof" of every float32 frame_structure entry by probing
    the first second of the recording at *data_path*.

    Messages are accumulated per topic; populate_dof() is retried roughly
    every 0.1 s of bag time until it succeeds.

    Raises:
        ValueError: if the required topics never all appear within the first
            second of the recording (or the recording ends before that).
    """
    with BagReader(data_path) as bag_reader:
        start_time = bag_reader.get_start_time()
        last_t = start_time
        topic_to_msg = {}  # latest message seen per topic
        last_error = ""
        for topic, msg, t in bag_reader.read_messages():
            topic_to_msg[topic] = msg

            # Give up once we are more than one second into the recording.
            if t > start_time + 1:
                raise ValueError(f"Cannot get all required values in the first second, last error: {last_error}")

            # Throttle: retry populate_dof about every 0.1 s of bag time.
            if t > last_t + 0.1:
                last_t = t
                try:
                    populate_dof(topic_to_msg, frame_structure, data_source)
                except (KeyError, ValueError) as e:
                    last_error = str(e)
                    continue

                # populate_dof succeeded: every dof is filled in.
                return

        # Recording ended before all required topics were observed.
        raise ValueError(f"Cannot get all required values in the first second, last error: {last_error}")
|
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
"""数据加载模块,支持通过goosefs或API获取数据集信息和bag路径"""
|
|
2
|
+
import os
|
|
3
|
+
import json
|
|
4
|
+
import shutil
|
|
5
|
+
from typing import List, Dict, Any, Tuple
|
|
6
|
+
|
|
7
|
+
import requests
|
|
8
|
+
|
|
9
|
+
from data_pipeline.py_api.api_utils import (
|
|
10
|
+
get_download_url_by_data_id,
|
|
11
|
+
get_all_data_from_dataset,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def check_config_for_loading_dataset(config: dict):
    """Validate the data_load keys used when loading the dataset listing.

    Raises:
        ValueError: if read_dataset_file is absent, or enabled without
            goosefs_path.
    """
    if (read_dataset_file := config.get("read_dataset_file")) is None:
        raise ValueError("配置中未设置read_dataset_file")

    # goosefs_path is only required when reading the listing from a file.
    if read_dataset_file and not config.get("goosefs_path"):
        raise ValueError("read_dataset_file为true,但配置中未设置goosefs_path")
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def load_dataset(dataset_id: str, config: dict) -> List[Dict[str, Any]]:
    """Load the dataset listing from a goosefs JSON file or via the API.

    Args:
        dataset_id: dataset ID
        config: the data_load section of the config

    Returns:
        list of metadata dicts

    Raises:
        FileNotFoundError: when file loading is configured but the JSON
            listing is missing.
    """
    if not config.get("read_dataset_file"):
        # API path: fetch the listing remotely.
        return get_all_data_from_dataset(dataset_id)

    file_path = os.path.join(config.get("goosefs_path"), f"dataset_conf/{dataset_id}.json")
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Dataset文件不存在: {file_path}")

    with open(file_path, "r", encoding='utf-8') as f:
        return json.load(f)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _get_file_key(metadata: Dict[str, Any]) -> str:
|
|
46
|
+
"""从metadata中提取文件key"""
|
|
47
|
+
files = metadata.get("cos_storage", {}).get("files", [])
|
|
48
|
+
if len(files) == 0:
|
|
49
|
+
raise ValueError("failed to get cos_storage.files")
|
|
50
|
+
|
|
51
|
+
key = files[0].get("key", "")
|
|
52
|
+
if not key:
|
|
53
|
+
raise ValueError("failed to get cos_storage.files[0].key")
|
|
54
|
+
|
|
55
|
+
return key
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _get_name_data_id_type(metadata: Dict[str, Any]) -> Tuple[str, str, str]:
    """Return (metadata.name, metadata.data_id, file extension).

    Raises:
        ValueError: when name or data_id is missing from the metadata.
    """
    key = _get_file_key(metadata)
    meta = metadata.get("metadata", {})

    name = meta.get("name")
    if not name:
        raise ValueError("failed to get metadata.name")

    data_id = meta.get("data_id")
    if not data_id:
        raise ValueError("failed to get metadata.data_id")

    # Extension = text after the last dot of the COS key.
    return name, data_id, key.split(".")[-1]
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def get_path_on_goosefs_or_goosefsx(metadata: Dict[str, Any], config: dict) -> str:
    """Resolve the local path of the recording on goosefs, optionally staging
    a copy on goosefsx when it is configured as a cache.

    Raises:
        FileNotFoundError: when the file is absent from goosefs.
    """
    # NOTE(review): [1:] assumes the COS key always starts with "/" — confirm.
    relative_key = _get_file_key(metadata)[1:]

    goosefs_filepath = os.path.join(config.get("goosefs_path"), relative_key)
    if not os.path.exists(goosefs_filepath):
        raise FileNotFoundError(f"goosefs文件不存在: {goosefs_filepath}")

    if not config.get("use_goosefsx_as_cache"):
        return goosefs_filepath

    # Stage a copy on the goosefsx cache and hand back that path instead.
    metadata_name, data_id, file_type = _get_name_data_id_type(metadata)
    goosefsx_filepath = os.path.join(config.get("bag_cache_dir"), f"{metadata_name}_{data_id}.{file_type}")
    shutil.copy(goosefs_filepath, goosefsx_filepath)
    return goosefsx_filepath
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def download_bag_from_url(metadata: Dict[str, Any], config: dict) -> str:
    """Download the recording behind *metadata* into bag_download_dir.

    Returns:
        local path of the downloaded file.

    Raises:
        ValueError: when no download URL can be obtained for the data id.
    """
    metadata_name, data_id, file_type = _get_name_data_id_type(metadata)

    download_url = get_download_url_by_data_id(data_id)
    if not download_url:
        raise ValueError(f"无法获取下载链接: {metadata_name}_{data_id}")

    save_path = os.path.join(config.get("bag_download_dir"), f'{metadata_name}_{data_id}.{file_type}')
    # Stream to disk in chunks so large bags do not sit in memory.
    with requests.get(download_url, stream=True) as response:
        response.raise_for_status()

        with open(save_path, 'wb') as out:
            for chunk in response.iter_content(chunk_size=8192):
                out.write(chunk)

    return save_path
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def check_config_for_loading_data(config: dict):
    """Validate the data_load keys used when fetching raw recordings and
    create the directories that will actually be used.

    Raises:
        ValueError: when a required key is missing or inconsistent.
    """
    download_bag = config.get("download_bag")
    if download_bag is None:
        raise ValueError("配置中未设置download_bag")

    if download_bag:
        bag_download_dir = config.get("bag_download_dir")
        if not bag_download_dir:
            raise ValueError("download_bag为true,但配置中未设置bag_download_dir")
        # Bug fix: only create the directory when it will be used —
        # os.makedirs(None) raises TypeError when download_bag is false and
        # bag_download_dir is unset.
        os.makedirs(bag_download_dir, exist_ok=True)
        return

    if not config.get("goosefs_path"):
        raise ValueError("download_bag为false,但配置中未设置goosefs_path")

    use_goosefsx_as_cache = config.get("use_goosefsx_as_cache")
    if use_goosefsx_as_cache is None:
        raise ValueError("download_bag为false,但配置中未设置use_goosefsx_as_cache")

    if use_goosefsx_as_cache:
        bag_cache_dir = config.get("bag_cache_dir")
        if not bag_cache_dir:
            raise ValueError("use_goosefsx_as_cache为true,但配置中未设置bag_cache_dir")
        # Same guard as above: only create the cache dir when caching is on.
        os.makedirs(bag_cache_dir, exist_ok=True)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
# Return path of the raw data
|
|
131
|
+
def load_data(metadata: Dict[str, Any], config: dict) -> str:
    """Resolve the local path of the raw recording for *metadata*.

    Depending on the config, the bag is either downloaded from its URL or
    resolved on goosefs (optionally staged on the goosefsx cache).

    Args:
        metadata: metadata object
        config: the data_load section of the config

    Returns:
        local path of the bag file
    """
    if config.get("download_bag"):
        return download_bag_from_url(metadata, config)
    return get_path_on_goosefs_or_goosefsx(metadata, config)
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
"""数据读取模块,支持rosbag和mcap格式"""
|
|
2
|
+
import os
|
|
3
|
+
from typing import Iterator, Tuple, Optional, Any
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
import rosbag
|
|
7
|
+
from mcap_ros1.decoder import DecoderFactory as ROS1DecoderFactory
|
|
8
|
+
from mcap.reader import make_reader
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class BagReader:
    """Reader for robot recordings in rosbag (.bag) or mcap (.mcap) format."""

    def __init__(self, bag_path: str):
        """Open *bag_path* for reading.

        Args:
            bag_path: path to a .bag or .mcap file.

        Raises:
            FileNotFoundError: when the file does not exist.
            ValueError: for any other file extension.
        """
        if not os.path.exists(bag_path):
            raise FileNotFoundError(f"文件不存在: {bag_path}")

        self.bag_path = bag_path
        self.is_rosbag = bag_path.endswith(".bag")
        self.is_mcap = bag_path.endswith(".mcap")

        if not self.is_rosbag and not self.is_mcap:
            raise ValueError(f"不支持的文件格式: {bag_path}")

        if self.is_rosbag:
            self.reader = rosbag.Bag(bag_path, 'r')
        else:
            # The mcap reader works on an open binary stream; keep the handle
            # so close() can release it later.
            self.file = open(bag_path, "rb")
            self.reader = make_reader(self.file, decoder_factories=[ROS1DecoderFactory()])
            self.summary = self.reader.get_summary()

    def get_start_time(self) -> float:
        """Return the timestamp of the first message, in seconds."""
        if self.is_rosbag:
            return self.reader.get_start_time()
        # mcap statistics are in nanoseconds.
        return self.summary.statistics.message_start_time / 1e9

    def get_end_time(self) -> float:
        """Return the timestamp of the last message, in seconds."""
        if self.is_rosbag:
            return self.reader.get_end_time()
        return self.summary.statistics.message_end_time / 1e9

    def read_messages(self, topics: Optional[list] = None) -> Iterator[Tuple[str, Any, float]]:
        """Iterate over the recording's messages.

        Args:
            topics: topics to read; None reads every topic.

        Yields:
            (topic, message, timestamp): topic name, decoded message object,
            and timestamp in seconds.
        """
        if self.is_rosbag:
            for topic, msg, stamp in self.reader.read_messages(topics=topics):
                yield topic, msg, stamp.to_sec()
        else:
            for _, channel, raw, decoded in self.reader.iter_decoded_messages(topics=topics):
                yield channel.topic, decoded, raw.log_time / 1e9

    def get_file_type(self) -> str:
        """Return the recording's file type ("bag" or "mcap")."""
        if self.is_rosbag:
            return "bag"
        if self.is_mcap:
            return "mcap"
        # Unreachable given the constructor check; kept for safety.
        return Path(self.bag_path).suffix[1:]  # drop the leading dot

    def close(self):
        """Release the underlying bag / file handle."""
        if self.is_rosbag:
            self.reader.close()
        else:
            self.file.close()

    def __enter__(self):
        # Resources are opened eagerly in __init__.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def convert_message(msg, data_type=None):
    """Convert a ROS message to a numpy array, dispatching on its class name.

    Args:
        msg: decoded ROS message.
        data_type: for component observations/actions, which part to extract
            ("pose" or "joint"); ignored for images.

    Raises:
        ValueError: for message types with no converter.
    """
    kind = type(msg).__name__
    if kind == "_sensor_msgs__CompressedImage":
        return convert_compressed_image(msg)
    if kind == "_data_msgs__ComponentObservation":
        return convert_component_observation(msg, data_type)
    if kind == "_data_msgs__ComponentAction":
        return convert_component_action(msg, data_type)

    raise ValueError(f"message type {kind} not supported")
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def convert_compressed_image(msg):
    """Return the compressed image payload as a flat uint8 numpy array."""
    payload = msg.data
    return np.frombuffer(payload, dtype=np.uint8)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def convert_component_observation(msg, data_type):
    """Extract a pose (xyz + quaternion) or joint vector from a ComponentObservation.

    Raises:
        ValueError: for any other data_type.
    """
    if data_type == "pose":
        p = msg.multibody_pose.pose
        position = [p.position.x, p.position.y, p.position.z]
        orientation = [p.orientation.x, p.orientation.y, p.orientation.z, p.orientation.w]
        return np.array(position + orientation)
    if data_type == "joint":
        # One joint position per body state.
        return np.array([state.q for state in msg.multibody_state.states])
    raise ValueError(f"data_type {data_type} not supported")
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def convert_component_action(msg, data_type):
    """Extract a pose (xyz + quaternion) or joint-command vector from a ComponentAction.

    Raises:
        ValueError: for any other data_type.
    """
    if data_type == "joint":
        return np.array(msg.joint_commands)
    if data_type == "pose":
        p = msg.pose_command.pose
        return np.array([
            p.position.x, p.position.y, p.position.z,
            p.orientation.x, p.orientation.y, p.orientation.z, p.orientation.w,
        ])
    raise ValueError(f"data_type {data_type} not supported")
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
import h5py
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class OutputDataset:
    """Abstract sink for converted frames; subclasses implement the storage."""

    def __init__(self, path: str, frame_structure: dict):
        # Destination path and per-key schema of each frame.
        self.path = path
        self.frame_structure = frame_structure

    def add_frame(self, frame_data: dict):
        """Append one frame; no-op in the base class."""
        pass

    def save_episode(self):
        """Finalize the current episode; no-op in the base class."""
        pass
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Hdf5Dataset(OutputDataset):
    """OutputDataset backed by an HDF5 file with one growable dataset per key."""

    def __init__(self, path: str, frame_structure: dict):
        super().__init__(path, frame_structure)
        self.hdf5_file = h5py.File(path, 'w')
        self.datasets = {}

        # One resizable dataset per frame key, typed by the frame structure.
        for key, spec in self.frame_structure.items():
            kind = spec["type"]
            if kind == "image":
                # Variable-length byte blobs (compressed images).
                dset = self.hdf5_file.create_dataset(
                    key, (0,), maxshape=(None,), dtype=h5py.special_dtype(vlen=np.dtype('uint8')))
            elif kind == "string":
                dset = self.hdf5_file.create_dataset(
                    key, (0,), maxshape=(None,), dtype=h5py.string_dtype(encoding='utf-8'))
            elif kind == "float32":
                dof = spec["dof"]
                dset = self.hdf5_file.create_dataset(
                    key, (0, dof), maxshape=(None, dof), dtype='float32')
            else:
                raise ValueError(f"Unsupported frame structure type: {spec['type']}")
            self.datasets[key] = dset

    def add_frame(self, frame_data: dict):
        """Append one frame: grow every dataset by one row and write into it."""
        for key in self.frame_structure:
            dset = self.datasets[key]
            dset.resize(dset.shape[0] + 1, axis=0)
            dset[-1] = frame_data[key]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Delegate closing to the h5py file's own context-manager exit.
        self.hdf5_file.__exit__(exc_type, exc_value, traceback)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class LerobotDataset(OutputDataset):
    """Placeholder for a LeRobot-format output dataset (not yet implemented)."""

    def __init__(self, path: str, frame_structure: dict = None):
        # Bug fix: OutputDataset.__init__ requires frame_structure too; the
        # previous super().__init__(path) call raised TypeError on
        # instantiation. frame_structure defaults to an empty schema to stay
        # backward-compatible with one-argument callers.
        super().__init__(path, frame_structure if frame_structure is not None else {})

    def add_frame(self, frame_data: dict):
        """Append one frame (not yet implemented)."""
        pass

    def save_episode(self):
        """Finalize the current episode (not yet implemented)."""
        pass
|
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
from data_pipeline.data_process.utils.message_convert import convert_message
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def get_data_for_value(topic_to_msg, value, data_source):
    """Resolve *value* to a numpy array from the latest messages.

    A value that is itself a topic name is converted directly; otherwise the
    lookup is delegated to the per-platform mapping selected by *data_source*
    (1 = umi, 2 = xtrainer, 3 = agibot).

    Raises:
        ValueError: when the value cannot be resolved for the data source.
    """
    if value in topic_to_msg:
        return convert_message(topic_to_msg[value])

    if data_source == 1:
        return get_umi_data_for_value(topic_to_msg, value)
    if data_source == 2:
        return get_xtrainer_data_for_value(topic_to_msg, value)
    if data_source == 3:
        return get_agibot_data_for_value(topic_to_msg, value)

    raise ValueError(f"value {value} not supported for data source {data_source}")
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def get_umi_data_for_value(topic_to_msg, value):
    """Resolve *value* for umi recordings via a value -> (topic, data_type)
    table; a data_type of None means a raw message conversion (images).

    Raises:
        ValueError: for values the umi platform does not provide.
        KeyError: when the mapped topic has not been seen yet.
    """
    topic_map = {
        "head/camera/color": ("/robot/data/head_realsense/color_image", None),
        "left_wrist/camera/color": ("/robot/data/left_hand_realsense/color_image", None),
        "right_wrist/camera/color": ("/robot/data/right_hand_realsense/color_image", None),
        "left_gripper/state": ("/robot/data/changingtek_hand_left/observation", "joint"),
        "right_gripper/state": ("/robot/data/changingtek_hand_right/observation", "joint"),
        "left_gripper/command": ("/robot/data/changingtek_hand_left/action", "joint"),
        "right_gripper/command": ("/robot/data/changingtek_hand_right/action", "joint"),
        "left_wrist/pose/state": ("/robot/data/jaka_arm_left/observation", "pose"),
        "right_wrist/pose/state": ("/robot/data/jaka_arm_right/observation", "pose"),
        "left_wrist/pose/command": ("/robot/data/jaka_arm_left/action", "pose"),
        "right_wrist/pose/command": ("/robot/data/jaka_arm_right/action", "pose"),
        "left_wrist/joint/state": ("/robot/data/jaka_arm_left/observation", "joint"),
        "right_wrist/joint/state": ("/robot/data/jaka_arm_right/observation", "joint"),
    }
    entry = topic_map.get(value)
    if entry is None:
        raise ValueError(f"value {value} not supported for umi data")
    topic, data_type = entry
    return convert_message(topic_to_msg[topic], data_type=data_type)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def get_agibot_data_for_value(topic_to_msg, value):
    """Resolve *value* for agibot recordings via a value -> (topic, data_type)
    table; a data_type of None means a raw message conversion (images).

    Raises:
        ValueError: for values the agibot platform does not provide.
        KeyError: when the mapped topic has not been seen yet.
    """
    topic_map = {
        "head/camera/color": ("/robot/data/head_realsense/color_image", None),
        # agibot wrist cameras publish under the gripper namespace.
        "left_wrist/camera/color": ("/robot/data/left_gripper/color_image", None),
        "right_wrist/camera/color": ("/robot/data/right_gripper/color_image", None),
        "left_gripper/state": ("/robot/data/changingtek_hand_left/observation", "joint"),
        "right_gripper/state": ("/robot/data/changingtek_hand_right/observation", "joint"),
        "left_gripper/command": ("/robot/data/changingtek_hand_left/action", "joint"),
        "right_gripper/command": ("/robot/data/changingtek_hand_right/action", "joint"),
        "left_wrist/pose/state": ("/robot/data/jaka_arm_left/observation", "pose"),
        "right_wrist/pose/state": ("/robot/data/jaka_arm_right/observation", "pose"),
        "left_wrist/pose/command": ("/robot/data/jaka_arm_left/action", "pose"),
        "right_wrist/pose/command": ("/robot/data/jaka_arm_right/action", "pose"),
        "left_wrist/joint/state": ("/robot/data/jaka_arm_left/observation", "joint"),
        "right_wrist/joint/state": ("/robot/data/jaka_arm_right/observation", "joint"),
    }
    entry = topic_map.get(value)
    if entry is None:
        raise ValueError(f"value {value} not supported for agibot data")
    topic, data_type = entry
    return convert_message(topic_to_msg[topic], data_type=data_type)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def get_xtrainer_data_for_value(topic_to_msg, value):
    """Resolve *value* for xtrainer recordings.

    Each entry maps to (preferred_topic, fallback_topic, data_type): the
    fallback (aloha_* namespace) is used only when the preferred topic has not
    been seen; fallback None means the preferred topic is always used.
    A data_type of None means a raw message conversion (images).

    Raises:
        ValueError: for values the xtrainer platform does not provide.
        KeyError: when the chosen topic has not been seen yet.
    """
    topic_map = {
        "head/camera/color": ("/robot/data/head_realsense/color_image", None, None),
        "left_wrist/camera/color": ("/robot/data/left_hand_realsense/color_image", None, None),
        "right_wrist/camera/color": ("/robot/data/right_hand_realsense/color_image", None, None),
        "left_gripper/state": ("/robot/data/left_hand/observation", "/robot/data/aloha_left_hand/observation", "joint"),
        "right_gripper/state": ("/robot/data/right_hand/observation", "/robot/data/aloha_right_hand/observation", "joint"),
        "left_gripper/command": ("/robot/data/left_hand/action", "/robot/data/aloha_left_hand/action", "joint"),
        "right_gripper/command": ("/robot/data/right_hand/action", "/robot/data/aloha_right_hand/action", "joint"),
        "left_wrist/pose/state": ("/robot/data/left_arm/observation", "/robot/data/aloha_left_arm/observation", "pose"),
        "right_wrist/pose/state": ("/robot/data/right_arm/observation", "/robot/data/aloha_right_arm/observation", "pose"),
        "left_wrist/joint/state": ("/robot/data/left_arm/observation", "/robot/data/aloha_left_arm/observation", "joint"),
        "right_wrist/joint/state": ("/robot/data/right_arm/observation", "/robot/data/aloha_right_arm/observation", "joint"),
        "left_wrist/joint/command": ("/robot/data/left_arm/action", "/robot/data/aloha_left_arm/action", "joint"),
        "right_wrist/joint/command": ("/robot/data/right_arm/action", "/robot/data/aloha_right_arm/action", "joint"),
    }
    entry = topic_map.get(value)
    if entry is None:
        raise ValueError(f"value {value} not supported for xtrainer data")
    preferred, fallback, data_type = entry
    topic = preferred if fallback is None or preferred in topic_to_msg else fallback
    return convert_message(topic_to_msg[topic], data_type=data_type)
|