neuracore_types-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,5 @@
+ """Neuracore Types - Shared type definitions for Neuracore."""
+
+ from neuracore_types.neuracore_types import *  # noqa: F403
+
+ __version__ = "1.0.0"
@@ -0,0 +1,1077 @@
+ """Defines the core data structures used throughout Neuracore."""
+
+ import base64
+ import time
+ from datetime import datetime, timezone
+ from enum import Enum
+ from typing import Any, List, NamedTuple, Optional, Tuple, Union
+ from uuid import uuid4
+
+ import numpy as np
+ from pydantic import (
+     BaseModel,
+     ConfigDict,
+     Field,
+     NonNegativeInt,
+     field_serializer,
+     field_validator,
+ )
+
+
+ def _sort_dict_by_keys(data_dict: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]:
+     """Sort a dictionary by its keys to ensure consistent ordering.
+
+     This is a helper function used internally by the data models to ensure
+     consistent dictionary ordering. Use the model's order() or
+     sort_in_place() methods instead of calling this directly.
+
+     Args:
+         data_dict: Dictionary to sort, or None
+
+     Returns:
+         New dictionary with keys sorted alphabetically, or None if input was None
+     """
+     if data_dict is None:
+         return None
+     return {key: data_dict[key] for key in sorted(data_dict.keys())}
+
+
+ class NCData(BaseModel):
+     """Base class for all Neuracore data with automatic timestamping.
+
+     Provides a common base for all data types in the system with automatic
+     timestamp generation for temporal synchronization and data ordering.
+     """
+
+     timestamp: float = Field(default_factory=lambda: time.time())
+
+     def order(self) -> "NCData":
+         """Return a new instance with sorted data.
+
+         This method should be overridden by subclasses to implement specific
+         ordering logic for the data type. The base class implementation does
+         nothing and returns self.
+         """
+         return self
+
+
+ class JointData(NCData):
+     """Robot joint state data including positions, velocities, or torques.
+
+     Represents joint-space data for robotic systems with support for named
+     joints and additional auxiliary values. Used for positions, velocities,
+     torques, and target positions.
+     """
+
+     values: dict[str, float]
+     additional_values: Optional[dict[str, float]] = None
+
+     def order(self) -> "JointData":
+         """Return a new JointData instance with sorted joint names.
+
+         Returns:
+             New JointData with alphabetically sorted joint names.
+         """
+         return JointData(
+             timestamp=self.timestamp,
+             values=_sort_dict_by_keys(self.values) or {},
+             additional_values=_sort_dict_by_keys(self.additional_values),
+         )
+
+     def numpy(self, order: Optional[List[str]] = None) -> np.ndarray:
+         """Convert the joint values to a NumPy array.
+
+         Args:
+             order: The order in which the numpy array is returned.
+
+         Returns:
+             NumPy array of joint values.
+         """
+         if order is not None:
+             values = [self.values[name] for name in order]
+         else:
+             values = list(self.values.values())
+         return np.array(values, dtype=np.float32)
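A quick sketch of why the `order` argument matters for deterministic feature vectors (joint names and values here are illustrative; assumes `JointData` is imported from `neuracore_types`):

```python
from neuracore_types import JointData

jd = JointData(values={"shoulder": 1.1, "elbow": 0.2})
# An explicit order guarantees a stable layout regardless of dict insertion order.
vec = jd.numpy(order=["elbow", "shoulder"])  # array([0.2, 1.1], dtype=float32)
```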
+
+
+ class CameraData(NCData):
+     """Camera sensor data including images and calibration information.
+
+     Contains image data along with camera intrinsic and extrinsic parameters
+     for 3D reconstruction and computer vision applications. The frame field
+     is populated during dataset iteration for efficiency.
+     """
+
+     frame_idx: int = 0  # Needed so we can index video after sync
+     extrinsics: Optional[list[list[float]]] = None
+     intrinsics: Optional[list[list[float]]] = None
+     frame: Optional[Union[Any, str]] = None  # Only filled in when using dataset iter
+
+
+ class PoseData(NCData):
+     """6DOF pose data for objects, end-effectors, or coordinate frames.
+
+     Represents position and orientation information for tracking objects
+     or robot components in 3D space. Poses are stored as dictionaries
+     mapping pose names to [x, y, z, rx, ry, rz] values.
+     """
+
+     pose: dict[str, list[float]]
+
+     def order(self) -> "PoseData":
+         """Return a new PoseData instance with sorted pose coordinates.
+
+         Returns:
+             New PoseData with alphabetically sorted pose coordinate names.
+         """
+         return PoseData(
+             timestamp=self.timestamp, pose=_sort_dict_by_keys(self.pose) or {}
+         )
+
+
+ class EndEffectorData(NCData):
+     """End-effector state data including gripper and tool configurations.
+
+     Contains the state of robot end-effectors such as gripper opening amounts,
+     tool activations, or other end-effector specific parameters.
+     """
+
+     open_amounts: dict[str, float]
+
+     def order(self) -> "EndEffectorData":
+         """Return a new EndEffectorData instance with sorted effector names.
+
+         Returns:
+             New EndEffectorData with alphabetically sorted effector names.
+         """
+         return EndEffectorData(
+             timestamp=self.timestamp,
+             open_amounts=_sort_dict_by_keys(self.open_amounts) or {},
+         )
+
+
+ class EndEffectorPoseData(NCData):
+     """End-effector pose data.
+
+     Contains the pose of end-effectors as a 7-element list containing the
+     position and unit quaternion orientation [x, y, z, qx, qy, qz, qw].
+     """
+
+     poses: dict[str, list[float]]
+
+     def order(self) -> "EndEffectorPoseData":
+         """Return a new EndEffectorPoseData instance with sorted effector names.
+
+         Returns:
+             New EndEffectorPoseData with alphabetically sorted effector names.
+         """
+         return EndEffectorPoseData(
+             timestamp=self.timestamp,
+             poses=_sort_dict_by_keys(self.poses) or {},
+         )
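To illustrate the 7-element convention above (hypothetical effector name and values, identity orientation):

```python
from neuracore_types import EndEffectorPoseData

# [x, y, z, qx, qy, qz, qw]: position plus unit quaternion.
ee = EndEffectorPoseData(poses={"left_gripper": [0.1, 0.0, 0.3, 0.0, 0.0, 0.0, 1.0]})
```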
+
+
+ class ParallelGripperOpenAmountData(NCData):
+     """Open amount data for a parallel end-effector gripper.
+
+     Contains the state of parallel gripper opening amounts.
+     """
+
+     open_amounts: dict[str, float]
+
+     def order(self) -> "ParallelGripperOpenAmountData":
+         """Return a new ParallelGripperOpenAmountData instance with sorted gripper names.
+
+         Returns:
+             New ParallelGripperOpenAmountData with alphabetically sorted gripper names.
+         """
+         return ParallelGripperOpenAmountData(
+             timestamp=self.timestamp,
+             open_amounts=_sort_dict_by_keys(self.open_amounts) or {},
+         )
+
+
+ class PointCloudData(NCData):
+     """3D point cloud data with optional RGB colouring and camera parameters.
+
+     Represents 3D spatial data from depth sensors or LiDAR systems with
+     optional colour information and camera calibration for registration.
+     """
+
+     model_config = ConfigDict(arbitrary_types_allowed=True)
+
+     points: Optional[np.ndarray] = None  # (N, 3) float16
+     rgb_points: Optional[np.ndarray] = None  # (N, 3) uint8
+     extrinsics: Optional[np.ndarray] = None  # (4, 4) float16
+     intrinsics: Optional[np.ndarray] = None  # (3, 3) float16
+
+     @staticmethod
+     def _encode(arr: np.ndarray, dtype: Any) -> str:
+         return base64.b64encode(arr.astype(dtype).tobytes()).decode("utf-8")
+
+     @staticmethod
+     def _decode(data: str, dtype: Any, shape: Tuple[int, ...]) -> np.ndarray:
+         return np.frombuffer(
+             base64.b64decode(data.encode("utf-8")), dtype=dtype
+         ).reshape(*shape)
+
+     @field_validator("points", mode="before")
+     @classmethod
+     def decode_points(cls, v: Union[str, np.ndarray]) -> Optional[np.ndarray]:
+         """Decode base64 string to NumPy array if needed.
+
+         Args:
+             v: Base64 encoded string or NumPy array
+
+         Returns:
+             Decoded NumPy array or None
+         """
+         return cls._decode(v, np.float16, (-1, 3)) if isinstance(v, str) else v
+
+     @field_validator("rgb_points", mode="before")
+     @classmethod
+     def decode_rgb_points(cls, v: Union[str, np.ndarray]) -> Optional[np.ndarray]:
+         """Decode base64 string to NumPy array if needed.
+
+         Args:
+             v: Base64 encoded string or NumPy array
+
+         Returns:
+             Decoded NumPy array or None
+         """
+         return cls._decode(v, np.uint8, (-1, 3)) if isinstance(v, str) else v
+
+     @field_validator("extrinsics", mode="before")
+     @classmethod
+     def decode_extrinsics(cls, v: Union[str, np.ndarray]) -> Optional[np.ndarray]:
+         """Decode base64 string to NumPy array if needed.
+
+         Args:
+             v: Base64 encoded string or NumPy array
+
+         Returns:
+             Decoded NumPy array or None
+         """
+         return cls._decode(v, np.float16, (4, 4)) if isinstance(v, str) else v
+
+     @field_validator("intrinsics", mode="before")
+     @classmethod
+     def decode_intrinsics(cls, v: Union[str, np.ndarray]) -> Optional[np.ndarray]:
+         """Decode base64 string to NumPy array if needed.
+
+         Args:
+             v: Base64 encoded string or NumPy array
+
+         Returns:
+             Decoded NumPy array or None
+         """
+         return cls._decode(v, np.float16, (3, 3)) if isinstance(v, str) else v
+
+     # --- Serializers (encode on dump) ---
+     @field_serializer("points", when_used="json")
+     def serialize_points(self, v: Optional[np.ndarray]) -> Optional[str]:
+         """Encode NumPy array to base64 string if needed.
+
+         Args:
+             v: NumPy array to encode
+
+         Returns:
+             Base64 encoded string or None
+         """
+         return self._encode(v, np.float16) if v is not None else None
+
+     @field_serializer("rgb_points", when_used="json")
+     def serialize_rgb_points(self, v: Optional[np.ndarray]) -> Optional[str]:
+         """Encode NumPy array to base64 string if needed.
+
+         Args:
+             v: NumPy array to encode
+
+         Returns:
+             Base64 encoded string or None
+         """
+         return self._encode(v, np.uint8) if v is not None else None
+
+     @field_serializer("extrinsics", when_used="json")
+     def serialize_extrinsics(self, v: Optional[np.ndarray]) -> Optional[str]:
+         """Encode NumPy array to base64 string if needed.
+
+         Args:
+             v: NumPy array to encode
+
+         Returns:
+             Base64 encoded string or None
+         """
+         return self._encode(v, np.float16) if v is not None else None
+
+     @field_serializer("intrinsics", when_used="json")
+     def serialize_intrinsics(self, v: Optional[np.ndarray]) -> Optional[str]:
+         """Encode NumPy array to base64 string if needed.
+
+         Args:
+             v: NumPy array to encode
+
+         Returns:
+             Base64 encoded string or None
+         """
+         return self._encode(v, np.float16) if v is not None else None
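A sketch of the base64 round trip that these validators and serializers implement together (array contents are arbitrary; assumes `PointCloudData` is imported from `neuracore_types`):

```python
import numpy as np

from neuracore_types import PointCloudData

pc = PointCloudData(points=np.zeros((100, 3), dtype=np.float16))
payload = pc.model_dump_json()  # ndarray fields become base64 strings
restored = PointCloudData.model_validate_json(payload)  # and decode back
assert restored.points.shape == (100, 3)
```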
+
+
+ class LanguageData(NCData):
+     """Natural language instruction or description data.
+
+     Contains text-based information such as task descriptions, voice commands,
+     or other linguistic data associated with robot demonstrations.
+     """
+
+     text: str
+
+
+ class CustomData(NCData):
+     """Generic container for application-specific data types.
+
+     Provides a flexible way to include custom sensor data or application-specific
+     information that doesn't fit into the standard data categories.
+     """
+
+     data: Any
+
+
+ class SyncPoint(BaseModel):
+     """Synchronized collection of all sensor data at a single time point.
+
+     Represents a complete snapshot of robot state and sensor information
+     at a specific timestamp. Used for creating temporally aligned datasets
+     and ensuring consistent data relationships across different sensors.
+     """
+
+     timestamp: float = Field(default_factory=lambda: time.time())
+     joint_positions: Optional[JointData] = None
+     joint_velocities: Optional[JointData] = None
+     joint_torques: Optional[JointData] = None
+     joint_target_positions: Optional[JointData] = None
+     end_effectors: Optional[EndEffectorData] = None
+     end_effector_poses: Optional[EndEffectorPoseData] = None
+     parallel_gripper_open_amounts: Optional[ParallelGripperOpenAmountData] = None
+     poses: Optional[PoseData] = None
+     rgb_images: Optional[dict[str, CameraData]] = None
+     depth_images: Optional[dict[str, CameraData]] = None
+     point_clouds: Optional[dict[str, PointCloudData]] = None
+     language_data: Optional[LanguageData] = None
+     custom_data: Optional[dict[str, CustomData]] = None
+     robot_id: Optional[str] = None
+
+     def order(self) -> "SyncPoint":
+         """Return a new SyncPoint with all dictionary data consistently ordered.
+
+         This method ensures all dictionary keys in the sync point are sorted
+         alphabetically to provide consistent ordering for machine learning models.
+         This is critical for model training and inference as it ensures deterministic
+         input ordering across different sync points.
+
+         The following fields are ordered:
+         - RGB images (by camera name)
+         - Depth images (by camera name)
+         - Point clouds (by sensor name)
+         - Custom data (by data type name)
+         - Joint data values (by joint name)
+         - Pose data (by pose name and pose coordinate names)
+         - End effector data (by effector name)
+
+         Returns:
+             New SyncPoint with all dictionary data consistently ordered.
+
+         Example:
+             >>> sync_point = SyncPoint(
+             ...     rgb_images={"cam_2": data2, "cam_1": data1},
+             ...     joint_positions=JointData(values={"joint_2": 1.0, "joint_1": 0.5})
+             ... )
+             >>> ordered = sync_point.order()
+             >>> list(ordered.rgb_images.keys())
+             ['cam_1', 'cam_2']
+             >>> list(ordered.joint_positions.values.keys())
+             ['joint_1', 'joint_2']
+         """
+         return SyncPoint(
+             timestamp=self.timestamp,
+             # Order joint data using their order() methods
+             joint_positions=(
+                 self.joint_positions.order() if self.joint_positions else None
+             ),
+             joint_velocities=(
+                 self.joint_velocities.order() if self.joint_velocities else None
+             ),
+             joint_torques=(self.joint_torques.order() if self.joint_torques else None),
+             joint_target_positions=(
+                 self.joint_target_positions.order()
+                 if self.joint_target_positions
+                 else None
+             ),
+             # Order end effector data
+             end_effectors=(self.end_effectors.order() if self.end_effectors else None),
+             # Order pose data (both pose names and pose coordinates)
+             poses=self.poses.order() if self.poses else None,
+             # Order end effector pose data
+             end_effector_poses=(
+                 self.end_effector_poses.order() if self.end_effector_poses else None
+             ),
+             # Order parallel gripper open amount data
+             parallel_gripper_open_amounts=(
+                 self.parallel_gripper_open_amounts.order()
+                 if self.parallel_gripper_open_amounts
+                 else None
+             ),
+             # Order camera data by camera/sensor names
+             rgb_images=_sort_dict_by_keys(self.rgb_images),
+             depth_images=_sort_dict_by_keys(self.depth_images),
+             point_clouds=_sort_dict_by_keys(self.point_clouds),
+             # Language data doesn't need ordering (single value)
+             language_data=self.language_data,
+             # Order custom data by data type names
+             custom_data=_sort_dict_by_keys(self.custom_data),
+             robot_id=self.robot_id,
+         )
+
+
+ class SyncedData(BaseModel):
+     """Complete synchronized dataset containing a sequence of data points.
+
+     Represents an entire recording or demonstration as a time-ordered sequence
+     of synchronized data points with start and end timestamps for temporal
+     reference.
+     """
+
+     frames: list[SyncPoint]
+     start_time: float
+     end_time: float
+     robot_id: str
+
+     def order(self) -> "SyncedData":
+         """Return a new SyncedData with all sync points ordered.
+
+         Returns:
+             New SyncedData with all sync points having consistent ordering.
+         """
+         return SyncedData(
+             frames=[frame.order() for frame in self.frames],
+             start_time=self.start_time,
+             end_time=self.end_time,
+             robot_id=self.robot_id,
+         )
+
+
+ class DataType(str, Enum):
+     """Enumeration of supported data types in the Neuracore system.
+
+     Defines the standard data categories used for dataset organization,
+     model training, and data processing pipelines.
+     """
+
+     # Robot state
+     JOINT_POSITIONS = "joint_positions"
+     JOINT_VELOCITIES = "joint_velocities"
+     JOINT_TORQUES = "joint_torques"
+     JOINT_TARGET_POSITIONS = "joint_target_positions"
+     END_EFFECTORS = "end_effectors"
+     END_EFFECTOR_POSES = "end_effector_poses"
+     PARALLEL_GRIPPER_OPEN_AMOUNTS = "parallel_gripper_open_amounts"
+
+     # Vision
+     RGB_IMAGE = "rgb_image"
+     DEPTH_IMAGE = "depth_image"
+     POINT_CLOUD = "point_cloud"
+
+     # Other
+     POSES = "poses"
+     LANGUAGE = "language"
+     CUSTOM = "custom"
+
+
+ class DataItemStats(BaseModel):
+     """Statistical summary of data dimensions and distributions.
+
+     Contains statistical information about data arrays including means,
+     standard deviations, counts, and maximum lengths for normalization
+     and model configuration purposes.
+
+     Attributes:
+         mean: List of means for each data dimension
+         std: List of standard deviations for each data dimension
+         count: List of counts for each data dimension
+         max_len: Maximum length of the data arrays
+         robot_to_ncdata_keys: Mapping of robot ids to their associated
+             data keys for this data type
+     """
+
+     mean: list[float] = Field(default_factory=list)
+     std: list[float] = Field(default_factory=list)
+     count: list[int] = Field(default_factory=list)
+     max_len: int = Field(default_factory=lambda data: len(data["mean"]))
+     robot_to_ncdata_keys: dict[str, list[str]] = Field(default_factory=dict)
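Note that `max_len` uses a `default_factory` that receives the already-validated data, so it defaults to `len(mean)` when not set explicitly; this factory signature requires pydantic >= 2.10. A small illustration (values are made up):

```python
from neuracore_types import DataItemStats

stats = DataItemStats(mean=[0.0, 1.0], std=[1.0, 1.0], count=[5, 5])
assert stats.max_len == 2  # falls back to len(mean)
```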
+
+
+ class DatasetDescription(BaseModel):
+     """Comprehensive description of dataset contents and statistics.
+
+     Provides metadata about a complete dataset including statistical summaries
+     for all data types, maximum counts for variable-length data, and methods
+     for determining which data types are present.
+     """
+
+     # Joint data statistics
+     joint_positions: DataItemStats = Field(default_factory=DataItemStats)
+     joint_velocities: DataItemStats = Field(default_factory=DataItemStats)
+     joint_torques: DataItemStats = Field(default_factory=DataItemStats)
+     joint_target_positions: DataItemStats = Field(default_factory=DataItemStats)
+
+     # End-effector statistics
+     end_effector_states: DataItemStats = Field(default_factory=DataItemStats)
+
+     # End-effector poses statistics
+     end_effector_poses: DataItemStats = Field(default_factory=DataItemStats)
+
+     # Parallel gripper open amount statistics
+     parallel_gripper_open_amounts: DataItemStats = Field(default_factory=DataItemStats)
+
+     # Pose statistics
+     poses: DataItemStats = Field(default_factory=DataItemStats)
+
+     # Visual data counts
+     rgb_images: DataItemStats = Field(default_factory=DataItemStats)
+     depth_images: DataItemStats = Field(default_factory=DataItemStats)
+     point_clouds: DataItemStats = Field(default_factory=DataItemStats)
+
+     # Language data
+     language: DataItemStats = Field(default_factory=DataItemStats)
+
+     # Custom data statistics
+     custom_data: dict[str, DataItemStats] = Field(default_factory=dict)
+
+     def get_data_types(self) -> list[DataType]:
+         """Determine which data types are present in the dataset.
+
+         Analyzes the dataset statistics to identify which data modalities
+         contain actual data (non-zero lengths/counts).
+
+         Returns:
+             List of DataType enums representing the data modalities
+             present in this dataset.
+         """
+         data_types = []
+
+         # Joint data
+         if self.joint_positions.max_len > 0:
+             data_types.append(DataType.JOINT_POSITIONS)
+         if self.joint_velocities.max_len > 0:
+             data_types.append(DataType.JOINT_VELOCITIES)
+         if self.joint_torques.max_len > 0:
+             data_types.append(DataType.JOINT_TORQUES)
+         if self.joint_target_positions.max_len > 0:
+             data_types.append(DataType.JOINT_TARGET_POSITIONS)
+
+         # End-effector data
+         if self.end_effector_states.max_len > 0:
+             data_types.append(DataType.END_EFFECTORS)
+
+         # End effector pose data
+         if self.end_effector_poses.max_len > 0:
+             data_types.append(DataType.END_EFFECTOR_POSES)
+
+         # Parallel gripper open amount data
+         if self.parallel_gripper_open_amounts.max_len > 0:
+             data_types.append(DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS)
+
+         # Pose data
+         if self.poses.max_len > 0:
+             data_types.append(DataType.POSES)
+
+         # Visual data
+         if self.rgb_images.max_len > 0:
+             data_types.append(DataType.RGB_IMAGE)
+         if self.depth_images.max_len > 0:
+             data_types.append(DataType.DEPTH_IMAGE)
+         if self.point_clouds.max_len > 0:
+             data_types.append(DataType.POINT_CLOUD)
+
+         # Language data
+         if self.language.max_len > 0:
+             data_types.append(DataType.LANGUAGE)
+
+         # Custom data
+         if self.custom_data:
+             data_types.append(DataType.CUSTOM)
+
+         return data_types
+
+     def add_custom_data(
+         self, key: str, stats: DataItemStats, max_length: int = 0
+     ) -> None:
+         """Add statistics for a custom data type.
+
+         Args:
+             key: Name of the custom data type
+             stats: Statistical information for the custom data
+             max_length: Maximum length of the custom data arrays (currently unused)
+         """
+         self.custom_data[key] = stats
+
+
+ class RecordingDescription(BaseModel):
+     """Description of a single recording episode with statistics and counts.
+
+     Provides metadata about an individual recording including data statistics,
+     sensor counts, and episode length for analysis and processing.
+     """
+
+     # Joint data statistics
+     joint_positions: DataItemStats = Field(default_factory=DataItemStats)
+     joint_velocities: DataItemStats = Field(default_factory=DataItemStats)
+     joint_torques: DataItemStats = Field(default_factory=DataItemStats)
+     joint_target_positions: DataItemStats = Field(default_factory=DataItemStats)
+
+     # End-effector statistics
+     end_effector_states: DataItemStats = Field(default_factory=DataItemStats)
+
+     # End-effector pose statistics
+     end_effector_poses: DataItemStats = Field(default_factory=DataItemStats)
+
+     # Parallel gripper open amount statistics
+     parallel_gripper_open_amounts: DataItemStats = Field(default_factory=DataItemStats)
+
+     # Pose statistics
+     poses: DataItemStats = Field(default_factory=DataItemStats)
+
+     # Visual data counts
+     rgb_images: DataItemStats = Field(default_factory=DataItemStats)
+     depth_images: DataItemStats = Field(default_factory=DataItemStats)
+     point_clouds: DataItemStats = Field(default_factory=DataItemStats)
+
+     # Language data
+     language: DataItemStats = Field(default_factory=DataItemStats)
+
+     # Episode metadata
+     episode_length: int = 0
+
+     # Custom data statistics
+     custom_data: dict[str, DataItemStats] = Field(default_factory=dict)
+
+     def get_data_types(self) -> list[DataType]:
+         """Determine which data types are present in the recording.
+
+         Analyzes the recording statistics to identify which data modalities
+         contain actual data (non-zero lengths/counts).
+
+         Returns:
+             List of DataType enums representing the data modalities
+             present in this recording.
+         """
+         data_types = []
+
+         # Joint data
+         if self.joint_positions.max_len > 0:
+             data_types.append(DataType.JOINT_POSITIONS)
+         if self.joint_velocities.max_len > 0:
+             data_types.append(DataType.JOINT_VELOCITIES)
+         if self.joint_torques.max_len > 0:
+             data_types.append(DataType.JOINT_TORQUES)
+         if self.joint_target_positions.max_len > 0:
+             data_types.append(DataType.JOINT_TARGET_POSITIONS)
+
+         # End-effector data
+         if self.end_effector_states.max_len > 0:
+             data_types.append(DataType.END_EFFECTORS)
+
+         # End-effector pose data
+         if self.end_effector_poses.max_len > 0:
+             data_types.append(DataType.END_EFFECTOR_POSES)
+
+         # Parallel gripper open amount data
+         if self.parallel_gripper_open_amounts.max_len > 0:
+             data_types.append(DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS)
+
+         # Pose data
+         if self.poses.max_len > 0:
+             data_types.append(DataType.POSES)
+
+         # Visual data
+         if self.rgb_images.max_len > 0:
+             data_types.append(DataType.RGB_IMAGE)
+         if self.depth_images.max_len > 0:
+             data_types.append(DataType.DEPTH_IMAGE)
+         if self.point_clouds.max_len > 0:
+             data_types.append(DataType.POINT_CLOUD)
+
+         # Language data
+         if self.language.max_len > 0:
+             data_types.append(DataType.LANGUAGE)
+
+         # Custom data
+         if self.custom_data:
+             data_types.append(DataType.CUSTOM)
+
+         return data_types
+
+
+ class ModelInitDescription(BaseModel):
+     """Configuration specification for initializing Neuracore models.
+
+     Defines the model architecture requirements including dataset characteristics,
+     input/output data types, and prediction horizons for model initialization
+     and training configuration.
+     """
+
+     dataset_description: DatasetDescription
+     input_data_types: list[DataType]
+     output_data_types: list[DataType]
+     output_prediction_horizon: int = 1
+
+
+ class ModelPrediction(BaseModel):
+     """Model inference output containing predictions and timing information.
+
+     Represents the results of model inference including predicted outputs
+     for each configured data type and optional timing information for
+     performance monitoring.
+     """
+
+     outputs: dict[DataType, Any] = Field(default_factory=dict)
+     prediction_time: Optional[float] = None
+
+
+ class SyncedDataset(BaseModel):
+     """Represents a dataset of robot demonstrations.
+
+     A synchronized dataset groups related robot demonstrations together
+     and maintains metadata about the collection as a whole.
+
+     Attributes:
+         id: Unique identifier for the synced dataset.
+         parent_id: Unique identifier of the corresponding dataset.
+         freq: Frequency at which the dataset was processed.
+         name: Human-readable name for the dataset.
+         created_at: Unix timestamp of dataset creation.
+         modified_at: Unix timestamp of last modification.
+         description: Optional description of the dataset.
+         recording_ids: List of recording IDs in this dataset.
+         num_demonstrations: Total number of demonstrations.
+         total_duration_seconds: Total duration of all demonstrations.
+         is_shared: Whether the dataset is shared with other users.
+         metadata: Additional arbitrary metadata.
+     """
+
+     id: str
+     parent_id: str
+     freq: int
+     name: str
+     created_at: float
+     modified_at: float
+     description: Optional[str] = None
+     recording_ids: list[str] = Field(default_factory=list)
+     num_demonstrations: int = 0
+     num_processed_demonstrations: int = 0
+     total_duration_seconds: float = 0.0
+     is_shared: bool = False
+     metadata: dict[str, Any] = Field(default_factory=dict)
+     dataset_description: DatasetDescription = Field(default_factory=DatasetDescription)
+     all_data_types: dict[DataType, int] = Field(default_factory=dict)
+     common_data_types: dict[DataType, int] = Field(default_factory=dict)
+
+
+ class Dataset(BaseModel):
+     """Represents a dataset of robot demonstrations.
+
+     A dataset groups related robot demonstrations together and maintains metadata
+     about the collection as a whole.
+
+     Attributes:
+         id: Unique identifier for the dataset.
+         name: Human-readable name for the dataset.
+         created_at: Unix timestamp of dataset creation.
+         modified_at: Unix timestamp of last modification.
+         description: Optional description of the dataset.
+         tags: List of tags for categorizing the dataset.
+         recording_ids: List of recording IDs in this dataset.
+         num_demonstrations: Total number of demonstrations.
+         total_duration_seconds: Total duration of all demonstrations.
+         size_bytes: Total size of all demonstrations.
+         is_shared: Whether the dataset is shared with other users.
+         metadata: Additional arbitrary metadata.
+         synced_dataset_ids: Synced dataset IDs for this dataset.
+             They point to synced datasets that synchronized
+             this dataset at a particular frequency.
+     """
+
+     id: str
+     name: str
+     created_at: float
+     modified_at: float
+     description: Optional[str] = None
+     tags: list[str] = Field(default_factory=list)
+     recording_ids: list[str] = Field(default_factory=list)
+     num_demonstrations: int = 0
+     total_duration_seconds: float = 0.0
+     size_bytes: int = 0
+     is_shared: bool = False
+     metadata: dict[str, Any] = Field(default_factory=dict)
+     synced_dataset_ids: dict[str, Any] = Field(default_factory=dict)
+     all_data_types: dict[DataType, int] = Field(default_factory=dict)
+     common_data_types: dict[DataType, int] = Field(default_factory=dict)
+     recording_ids_in_bucket: bool = False
+
+
+ class MessageType(str, Enum):
+     """Enumerates the types of signaling messages for WebRTC handshakes.
+
+     These types are used to identify the purpose of a message sent through
+     the signaling server during connection establishment.
+     """
+
+     SDP_OFFER = "offer"  # Session Description Protocol (SDP) offer from the caller
+     SDP_ANSWER = "answer"  # Session Description Protocol (SDP) answer from the callee
+     ICE_CANDIDATE = "ice"  # Interactive Connectivity Establishment (ICE) candidate
+     OPEN_CONNECTION = "open_connection"  # Request to open a new connection
+
+
+ class HandshakeMessage(BaseModel):
+     """Represents a signaling message for the WebRTC handshake process.
+
+     This message is exchanged between two peers via a signaling server to
+     negotiate the connection details, such as SDP offers/answers and ICE
+     candidates.
+
+     Attributes:
+         from_id: The unique identifier of the sender peer.
+         to_id: The unique identifier of the recipient peer.
+         data: The payload of the message, typically an SDP string or a JSON
+             object with ICE candidate information.
+         connection_id: The unique identifier for the connection session.
+         type: The type of the handshake message, as defined by MessageType.
+         id: A unique identifier for the message itself.
+     """
+
+     from_id: str
+     to_id: str
+     data: str
+     connection_id: str
+     type: MessageType
+     id: str = Field(default_factory=lambda: uuid4().hex)
+
+
+ class VideoFormat(str, Enum):
+     """Enumerates video format styles over a WebRTC connection."""
+
+     # Uses a standard video track with a negotiated codec; this is more efficient.
+     WEB_RTC_NEGOTIATED = "WEB_RTC_NEGOTIATED"
+     # Uses Neuracore's data URI format over a custom data channel.
+     NEURACORE_CUSTOM = "NEURACORE_CUSTOM"
+
+
+ class OpenConnectionRequest(BaseModel):
+     """Represents a request to open a new WebRTC connection.
+
+     Attributes:
+         from_id: The unique identifier of the consumer peer.
+         to_id: The unique identifier of the producer peer.
+         robot_id: The unique identifier for the robot to be created.
+         robot_instance: The identifier for the instance of the robot to connect to.
+         video_format: The type of video the consumer expects to receive.
+         id: The identifier for this connection request.
+         created_at: The UTC timestamp when the request was created.
+     """
+
+     from_id: str
+     to_id: str
+     robot_id: str
+     robot_instance: NonNegativeInt
+     video_format: VideoFormat
+     id: str = Field(default_factory=lambda: uuid4().hex)
+     created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
+
+
+ class OpenConnectionDetails(BaseModel):
+     """The details describing properties of the new connection.
+
+     Attributes:
+         connection_token: The token used for security to establish the connection.
+         robot_id: The unique identifier for the robot to connect to.
+         robot_instance: The identifier for the instance of the robot to connect to.
+         video_format: The type of video the consumer expects to receive.
+     """
+
+     connection_token: str
+     robot_id: str
+     robot_instance: NonNegativeInt
+     video_format: VideoFormat
+
+
+ class StreamAliveResponse(BaseModel):
+     """Represents the response from asserting a stream is alive.
+
+     This is returned when a client pings a stream to keep it active.
+
+     Attributes:
+         resurrected: A boolean indicating if the stream was considered dead
+             and has been successfully resurrected by this request.
+     """
+
+     resurrected: bool
+
+
+ class RobotInstanceIdentifier(NamedTuple):
+     """A tuple that uniquely identifies a robot instance.
+
+     Attributes:
+         robot_id: The unique identifier of the robot providing the stream.
+         robot_instance: The specific instance number of the robot.
+     """
+
+     robot_id: str
+     robot_instance: int
+
+
+ class TrackKind(str, Enum):
+     """Enumerates the supported track kinds for streaming."""
+
+     JOINTS = "joints"
+     RGB = "rgb"
+     DEPTH = "depth"
+     LANGUAGE = "language"
+     GRIPPER = "gripper"
+     END_EFFECTOR_POSE = "end_effector_pose"
+     PARALLEL_GRIPPER_OPEN_AMOUNT = "parallel_gripper_open_amount"
+     POINT_CLOUD = "point_cloud"
+     POSE = "pose"
+     CUSTOM = "custom"
+
+
+ class RobotStreamTrack(BaseModel):
+     """Metadata for a robot's media stream track.
+
+     This model holds all the necessary information to identify and manage
+     a single media track (e.g., a video or audio feed) from a specific
+     robot instance.
+
+     Attributes:
+         robot_id: The unique identifier of the robot providing the stream.
+         robot_instance: The specific instance number of the robot.
+         stream_id: The identifier for the overall media stream session.
+         kind: The type of media track, as defined by TrackKind (e.g., 'rgb', 'joints').
+         label: A human-readable label for the track (e.g., 'front_camera').
+         mid: The media ID used in SDP, essential for WebRTC negotiation.
+         id: A unique identifier for this track metadata object.
+         created_at: The UTC timestamp when this track metadata was created.
+     """
+
+     robot_id: str
+     robot_instance: NonNegativeInt
+     stream_id: str
+     kind: TrackKind
+     label: str
+     mid: str
+     id: str = Field(default_factory=lambda: uuid4().hex)
+     created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
+
+
+ class AvailableRobotInstance(BaseModel):
+     """Represents a single, available instance of a robot.
+
+     Attributes:
+         robot_instance: The unique identifier for this robot instance.
+         tracks: A dictionary of available media stream tracks for this instance.
+         connections: The number of current connections to this instance.
+     """
+
+     robot_instance: NonNegativeInt
+     # stream_id to list of tracks
+     tracks: dict[str, list[RobotStreamTrack]]
+     connections: int
+
+
+ class AvailableRobot(BaseModel):
+     """Represents an available robot, including all its running instances.
+
+     Attributes:
+         robot_id: The unique identifier for the robot model/type.
+         instances: A dictionary of all available instances for this robot,
+             keyed by instance ID.
+     """
+
+     robot_id: str
+     instances: dict[int, AvailableRobotInstance]
+
+
+ class AvailableRobotCapacityUpdate(BaseModel):
+     """Represents an update on the available capacity of all robots.
+
+     This model is used to broadcast the current state of all available
+     robots and their instances.
+
+     Attributes:
+         robots: A list of all available robots and their instances.
+     """
+
+     robots: list[AvailableRobot]
+
+
+ class BaseRecodingUpdatePayload(BaseModel):
+     """Base payload for recording update notifications.
+
+     Contains the minimum information needed to identify a recording
+     and the robot instance it belongs to.
+     """
+
+     recording_id: str
+     robot_id: str
+     instance: NonNegativeInt
+
+
+ class RecodingRequestedPayload(BaseRecodingUpdatePayload):
+     """Payload for recording request notifications.
+
+     Contains information about who requested the recording and what
+     data types should be captured.
+     """
+
+     created_by: str
+     dataset_ids: list[str] = Field(default_factory=list)
+     data_types: set[DataType] = Field(default_factory=set)
+
+
+ class RecordingStartPayload(RecodingRequestedPayload):
+     """Payload for recording start notifications.
+
+     Extends the request payload with the actual start timestamp
+     when recording begins.
+     """
+
+     start_time: float
+
+
+ class RecordingNotificationType(str, Enum):
+     """Types of recording lifecycle notifications."""
+
+     INIT = "init"
+     REQUESTED = "requested"
+     START = "start"
+     STOP = "stop"
+     SAVED = "saved"
+     DISCARDED = "discarded"
+     EXPIRED = "expired"
+
+
+ class RecordingNotification(BaseModel):
+     """Notification message for recording lifecycle events.
+
+     Used to communicate recording state changes across the system,
+     including initialization, start/stop events, and final disposition.
+     """
+
+     type: RecordingNotificationType
+     payload: Union[
+         RecordingStartPayload,
+         RecodingRequestedPayload,
+         list[Union[RecordingStartPayload, RecodingRequestedPayload]],
+         BaseRecodingUpdatePayload,
+     ]
+     id: str = Field(default_factory=lambda: uuid4().hex)
@@ -0,0 +1,89 @@
+ Metadata-Version: 2.4
+ Name: neuracore-types
+ Version: 1.0.0
+ Summary: Shared type definitions for Neuracore.
+ Author: Neuracore
+ License: MIT
+ License-File: LICENSE
+ Requires-Python: >=3.10
+ Requires-Dist: numpy>=1.24.0
+ Requires-Dist: pydantic>=2.0.0
+ Provides-Extra: dev
+ Requires-Dist: pre-commit; extra == 'dev'
+ Requires-Dist: pydantic-to-typescript2>=1.0.0; extra == 'dev'
+ Description-Content-Type: text/markdown
+
+ # Neuracore Types
+
+ Shared type definitions for the Neuracore platform. This package maintains a single source of truth for data types in Python (Pydantic models) and automatically generates TypeScript types.
+
+ ## Overview
+
+ - **Python Package**: `neuracore-types` - Pydantic models for the Python backend
+ - **NPM Package**: `@neuracore/types` - TypeScript types for the frontend
+
+ ## Installation
+
+ ### Python
+
+ ```bash
+ pip install neuracore-types
+ ```
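Once installed, the models behave like ordinary Pydantic models. A minimal sketch (field values here are illustrative):

```python
from neuracore_types import JointData, SyncPoint

sync = SyncPoint(
    joint_positions=JointData(values={"joint_2": 1.0, "joint_1": 0.5})
).order()  # deterministic, alphabetical key ordering

print(list(sync.joint_positions.values.keys()))  # ['joint_1', 'joint_2']
```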
+
+ ### TypeScript/JavaScript
+
+ ```bash
+ npm install @neuracore/types
+ # or
+ yarn add @neuracore/types
+ # or
+ pnpm add @neuracore/types
+ ```
+
+ ## Development
+
+ ### Setup
+
+ ```bash
+ # Clone the repository
+ git clone https://github.com/neuracoreai/neuracore_types.git
+ cd neuracore_types
+
+ # Install Python dependencies
+ pip install -e ".[dev]"
+
+ # Install Node dependencies
+ npm install
+ ```
+
+ ### Generate TypeScript Types
+
+ The TypeScript types are automatically generated from the Python Pydantic models:
+
+ ```bash
+ npm install json-schema-to-typescript
+ python scripts/generate_types.py
+ ```
+
+ This will:
+ 1. Read the Pydantic models from `neuracore_types/neuracore_types.py`
+ 2. Generate TypeScript definitions in `typescript/neuracore_types.ts`
+ 3. Create an index file at `typescript/index.ts`
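For reference, a minimal sketch of what such a generation script can look like when built on `pydantic-to-typescript2` (the actual contents of `scripts/generate_types.py` are not shown in this package, so treat this as an assumption about its shape):

```python
from pydantic2ts import generate_typescript_defs

# Converts the Pydantic models to TypeScript via their JSON Schema;
# requires json-schema-to-typescript to be installed through npm.
generate_typescript_defs(
    "neuracore_types.neuracore_types",
    "typescript/neuracore_types.ts",
)
```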
+
+ ### Build TypeScript Package
+
+ ```bash
+ npm run build
+ ```
+
+ This compiles the TypeScript files to JavaScript and generates type declarations in the `dist/` directory.
+
+ ## CI/CD
+
+ The repository includes GitHub Actions workflows that:
+
+ 1. **On every push to `main` or PR**:
+    - Automatically generates TypeScript types from Python models
+    - Builds and validates both packages
+    - Publishes Python package to PyPI
+    - Publishes NPM package to npm registry
@@ -0,0 +1,6 @@
+ neuracore_types/__init__.py,sha256=oxBrxwWgs4DxiD5ITELQGVKnXyN4aFwb3BE-1ANXihE,147
+ neuracore_types/neuracore_types.py,sha256=Nsl0Kg-50rKQCWRDD-8O8WpuoGDdugGdwDFRaPSPQfQ,37546
+ neuracore_types-1.0.0.dist-info/METADATA,sha256=X7QO87UcsFy5VCU2hUlIBliACQ5I4PwyuSRJfuyq8B0,2073
+ neuracore_types-1.0.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ neuracore_types-1.0.0.dist-info/licenses/LICENSE,sha256=TPLH9MVhc33h8soaIkSnaVcBTspRM8QK4ipd61wQlJk,1066
+ neuracore_types-1.0.0.dist-info/RECORD,,
@@ -0,0 +1,4 @@
+ Wheel-Version: 1.0
+ Generator: hatchling 1.27.0
+ Root-Is-Purelib: true
+ Tag: py3-none-any
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Neuracore
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.