rocket-welder-sdk 1.1.32__py3-none-any.whl → 1.1.33__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rocket_welder_sdk/__init__.py +5 -6
- rocket_welder_sdk/controllers.py +134 -101
- rocket_welder_sdk/frame_metadata.py +138 -0
- rocket_welder_sdk/high_level/__init__.py +66 -0
- rocket_welder_sdk/high_level/connection_strings.py +330 -0
- rocket_welder_sdk/high_level/data_context.py +163 -0
- rocket_welder_sdk/high_level/schema.py +180 -0
- rocket_welder_sdk/high_level/transport_protocol.py +166 -0
- rocket_welder_sdk/keypoints_protocol.py +642 -0
- rocket_welder_sdk/rocket_welder_client.py +17 -3
- rocket_welder_sdk/segmentation_result.py +420 -0
- rocket_welder_sdk/transport/__init__.py +38 -0
- rocket_welder_sdk/transport/frame_sink.py +77 -0
- rocket_welder_sdk/transport/frame_source.py +74 -0
- rocket_welder_sdk/transport/nng_transport.py +197 -0
- rocket_welder_sdk/transport/stream_transport.py +193 -0
- rocket_welder_sdk/transport/tcp_transport.py +154 -0
- rocket_welder_sdk/transport/unix_socket_transport.py +339 -0
- {rocket_welder_sdk-1.1.32.dist-info → rocket_welder_sdk-1.1.33.dist-info}/METADATA +15 -2
- rocket_welder_sdk-1.1.33.dist-info/RECORD +37 -0
- rocket_welder_sdk-1.1.32.dist-info/RECORD +0 -22
- {rocket_welder_sdk-1.1.32.dist-info → rocket_welder_sdk-1.1.33.dist-info}/WHEEL +0 -0
- {rocket_welder_sdk-1.1.32.dist-info → rocket_welder_sdk-1.1.33.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,330 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Strongly-typed connection strings with parsing support.
|
|
3
|
+
|
|
4
|
+
Connection string format: protocol://path?param1=value1&param2=value2
|
|
5
|
+
|
|
6
|
+
Examples:
|
|
7
|
+
nng+push+ipc://tmp/keypoints?masterFrameInterval=300
|
|
8
|
+
nng+pub+tcp://localhost:5555
|
|
9
|
+
file://path/to/output.bin
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import contextlib
|
|
15
|
+
import os
|
|
16
|
+
from dataclasses import dataclass, field
|
|
17
|
+
from enum import Enum, auto
|
|
18
|
+
from typing import Dict, Optional
|
|
19
|
+
from urllib.parse import parse_qs
|
|
20
|
+
|
|
21
|
+
from .transport_protocol import TransportProtocol
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class VideoSourceType(Enum):
    """Type of video source."""

    CAMERA = auto()
    FILE = auto()
    SHARED_MEMORY = auto()
    RTSP = auto()
    HTTP = auto()


@dataclass(frozen=True)
class VideoSourceConnectionString:
    """
    Strongly-typed connection string for video source input.

    Supported formats:
    - "0", "1", etc. - Camera device index
    - file://path/to/video.mp4 - Video file (interpreted as /path/to/video.mp4;
      file:///abs/path.mp4 is also accepted)
    - shm://buffer_name - Shared memory buffer
    - rtsp://host/stream - RTSP stream
    - http://host/stream or https://host/stream - HTTP stream
    - any other string without "://" - treated as a file path as-is

    A trailing "?key=value&..." query is split off into ``parameters``
    (keys lower-cased, first value kept). ``value``, ``path`` and ``__str__``
    refer to the string with that query removed.
    """

    value: str
    source_type: VideoSourceType
    camera_index: Optional[int] = None
    path: Optional[str] = None
    parameters: Dict[str, str] = field(default_factory=dict)

    @classmethod
    def default(cls) -> VideoSourceConnectionString:
        """Default video source (camera 0)."""
        return cls.parse("0")

    @classmethod
    def from_environment(cls, variable_name: str = "VIDEO_SOURCE") -> VideoSourceConnectionString:
        """
        Create from an environment variable or use the default.

        Reads ``variable_name`` first, then ``CONNECTION_STRING``; if neither is
        set (or both are empty) the default camera 0 source is returned.

        Raises:
            ValueError: If the variable is set but cannot be parsed.
        """
        value = os.environ.get(variable_name) or os.environ.get("CONNECTION_STRING")
        return cls.parse(value) if value else cls.default()

    @classmethod
    def parse(cls, s: str) -> VideoSourceConnectionString:
        """
        Parse a connection string.

        Raises:
            ValueError: If ``s`` is empty/blank or not a recognized format.
        """
        result = cls.try_parse(s)
        if result is None:
            raise ValueError(f"Invalid video source connection string: {s}")
        return result

    @classmethod
    def try_parse(cls, s: str) -> Optional[VideoSourceConnectionString]:
        """Try to parse a connection string; return None if it is not recognized."""
        if not s or not s.strip():
            return None

        s = s.strip()
        parameters: Dict[str, str] = {}

        # Extract query parameters
        if "?" in s:
            base, query = s.split("?", 1)
            for key, values in parse_qs(query).items():
                parameters[key.lower()] = values[0] if values else ""
            s = base

        # Camera index first. isdecimal() (rather than isdigit()) guarantees
        # int() succeeds: isdigit() also accepts characters such as "²" that
        # int() rejects, which made try_parse raise instead of returning.
        if s.isdecimal():
            return cls(
                value=s,
                source_type=VideoSourceType.CAMERA,
                camera_index=int(s),
                parameters=parameters,
            )

        # Parse protocol
        if s.startswith("file://"):
            rest = s[len("file://") :]
            # "file://tmp/x" means /tmp/x; "file:///tmp/x" must not become "//tmp/x".
            path = rest if rest.startswith("/") else "/" + rest
            return cls(
                value=s,
                source_type=VideoSourceType.FILE,
                path=path,
                parameters=parameters,
            )
        elif s.startswith("shm://"):
            return cls(
                value=s,
                source_type=VideoSourceType.SHARED_MEMORY,
                path=s[len("shm://") :],
                parameters=parameters,
            )
        elif s.startswith("rtsp://"):
            return cls(
                value=s,
                source_type=VideoSourceType.RTSP,
                path=s,
                parameters=parameters,
            )
        elif s.startswith("http://") or s.startswith("https://"):
            return cls(
                value=s,
                source_type=VideoSourceType.HTTP,
                path=s,
                parameters=parameters,
            )
        elif "://" not in s:
            # Assume a plain (absolute or relative) file path
            return cls(
                value=s,
                source_type=VideoSourceType.FILE,
                path=s,
                parameters=parameters,
            )

        # Unknown scheme
        return None

    def __str__(self) -> str:
        return self.value
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
@dataclass(frozen=True)
class KeyPointsConnectionString:
    """
    Strongly-typed connection string for KeyPoints output.

    Supported protocols (composable with + operator):
    - Transport.Nng + Transport.Push + Transport.Ipc → nng+push+ipc://tmp/keypoints
    - Transport.Nng + Transport.Push + Transport.Tcp → nng+push+tcp://host:port
    - file://path/to/file.bin - File output (interpreted as /path/to/file.bin;
      file:///abs/file.bin is also accepted)
    - /absolute/path/to/file.bin - File output

    Supported parameters:
    - masterFrameInterval: Interval between master frames (default: 300).
      A value that is not an integer is ignored and the default is kept.

    ``value`` / ``__str__`` keep the full stripped input, query included.
    """

    value: str
    protocol: Optional[TransportProtocol] = None
    is_file: bool = False
    address: str = ""
    master_frame_interval: int = 300
    parameters: Dict[str, str] = field(default_factory=dict)

    @classmethod
    def default(cls) -> KeyPointsConnectionString:
        """Default connection string for KeyPoints."""
        return cls.parse("nng+push+ipc://tmp/rocket-welder-keypoints?masterFrameInterval=300")

    @classmethod
    def from_environment(
        cls, variable_name: str = "KEYPOINTS_CONNECTION_STRING"
    ) -> KeyPointsConnectionString:
        """
        Create from an environment variable or use the default.

        Raises:
            ValueError: If the variable is set but cannot be parsed.
        """
        value = os.environ.get(variable_name)
        return cls.parse(value) if value else cls.default()

    @classmethod
    def parse(cls, s: str) -> KeyPointsConnectionString:
        """
        Parse a connection string.

        Raises:
            ValueError: If ``s`` is empty/blank or not a recognized format.
        """
        result = cls.try_parse(s)
        if result is None:
            raise ValueError(f"Invalid KeyPoints connection string: {s}")
        return result

    @classmethod
    def try_parse(cls, s: str) -> Optional[KeyPointsConnectionString]:
        """Try to parse a connection string; return None if it is not recognized."""
        if not s or not s.strip():
            return None

        s = s.strip()
        parameters: Dict[str, str] = {}

        # Split off the query string; only endpoint_part describes the sink.
        endpoint_part = s
        if "?" in s:
            endpoint_part, query = s.split("?", 1)
            for key, values in parse_qs(query).items():
                parameters[key.lower()] = values[0] if values else ""

        # Parse protocol and address
        scheme_end = endpoint_part.find("://")
        if scheme_end > 0:
            protocol_str = endpoint_part[:scheme_end]
            path_part = endpoint_part[scheme_end + 3 :]  # skip "://"

            if protocol_str.lower() == "file":
                # "file://tmp/x" means /tmp/x; "file:///tmp/x" must not become "//tmp/x".
                address = path_part if path_part.startswith("/") else "/" + path_part
                is_file = True
                protocol = None
            else:
                protocol = TransportProtocol.try_parse(protocol_str)
                if protocol is None:
                    return None
                address = protocol.create_nng_address(path_part)
                is_file = False
        elif endpoint_part.startswith("/"):
            # Bare absolute file path. Use endpoint_part (not s) so the query
            # string does not end up inside the file name.
            address = endpoint_part
            is_file = True
            protocol = None
        else:
            return None

        # Parse masterFrameInterval; a non-integer value keeps the default.
        master_frame_interval = 300
        if "masterframeinterval" in parameters:
            with contextlib.suppress(ValueError):
                master_frame_interval = int(parameters["masterframeinterval"])

        return cls(
            value=s,
            protocol=protocol,
            is_file=is_file,
            address=address,
            master_frame_interval=master_frame_interval,
            parameters=parameters,
        )

    def __str__(self) -> str:
        return self.value
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
@dataclass(frozen=True)
class SegmentationConnectionString:
    """
    Strongly-typed connection string for Segmentation output.

    Supported protocols (composable with + operator):
    - Transport.Nng + Transport.Push + Transport.Ipc → nng+push+ipc://tmp/segmentation
    - Transport.Nng + Transport.Push + Transport.Tcp → nng+push+tcp://host:port
    - file://path/to/file.bin - File output (interpreted as /path/to/file.bin;
      file:///abs/file.bin is also accepted)
    - /absolute/path/to/file.bin - File output

    ``value`` / ``__str__`` keep the full stripped input, query included;
    query parameters are also exposed (lower-cased keys) via ``parameters``.
    """

    value: str
    protocol: Optional[TransportProtocol] = None
    is_file: bool = False
    address: str = ""
    parameters: Dict[str, str] = field(default_factory=dict)

    @classmethod
    def default(cls) -> SegmentationConnectionString:
        """Default connection string for Segmentation."""
        return cls.parse("nng+push+ipc://tmp/rocket-welder-segmentation")

    @classmethod
    def from_environment(
        cls, variable_name: str = "SEGMENTATION_CONNECTION_STRING"
    ) -> SegmentationConnectionString:
        """
        Create from an environment variable or use the default.

        Raises:
            ValueError: If the variable is set but cannot be parsed.
        """
        value = os.environ.get(variable_name)
        return cls.parse(value) if value else cls.default()

    @classmethod
    def parse(cls, s: str) -> SegmentationConnectionString:
        """
        Parse a connection string.

        Raises:
            ValueError: If ``s`` is empty/blank or not a recognized format.
        """
        result = cls.try_parse(s)
        if result is None:
            raise ValueError(f"Invalid Segmentation connection string: {s}")
        return result

    @classmethod
    def try_parse(cls, s: str) -> Optional[SegmentationConnectionString]:
        """Try to parse a connection string; return None if it is not recognized."""
        if not s or not s.strip():
            return None

        s = s.strip()
        parameters: Dict[str, str] = {}

        # Split off the query string; only endpoint_part describes the sink.
        endpoint_part = s
        if "?" in s:
            endpoint_part, query = s.split("?", 1)
            for key, values in parse_qs(query).items():
                parameters[key.lower()] = values[0] if values else ""

        # Parse protocol and address
        scheme_end = endpoint_part.find("://")
        if scheme_end > 0:
            protocol_str = endpoint_part[:scheme_end]
            path_part = endpoint_part[scheme_end + 3 :]  # skip "://"

            if protocol_str.lower() == "file":
                # "file://tmp/x" means /tmp/x; "file:///tmp/x" must not become "//tmp/x".
                address = path_part if path_part.startswith("/") else "/" + path_part
                is_file = True
                protocol = None
            else:
                protocol = TransportProtocol.try_parse(protocol_str)
                if protocol is None:
                    return None
                address = protocol.create_nng_address(path_part)
                is_file = False
        elif endpoint_part.startswith("/"):
            # Bare absolute file path. Use endpoint_part (not s) so the query
            # string does not end up inside the file name.
            address = endpoint_part
            is_file = True
            protocol = None
        else:
            return None

        return cls(
            value=s,
            protocol=protocol,
            is_file=is_file,
            address=address,
            parameters=parameters,
        )

    def __str__(self) -> str:
        return self.value
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Data context types for per-frame keypoints and segmentation data.
|
|
3
|
+
|
|
4
|
+
Implements the Unit of Work pattern - contexts are created per-frame
|
|
5
|
+
and auto-commit when the processing delegate returns.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from abc import ABC, abstractmethod
|
|
11
|
+
from typing import TYPE_CHECKING, Sequence, Tuple, Union
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
import numpy.typing as npt
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from rocket_welder_sdk.keypoints_protocol import IKeyPointsWriter
|
|
18
|
+
from rocket_welder_sdk.segmentation_result import SegmentationResultWriter
|
|
19
|
+
|
|
20
|
+
from .schema import KeyPoint, SegmentClass
|
|
21
|
+
|
|
22
|
+
# Type aliases
|
|
23
|
+
# (x, y) pixel coordinate pair, used for keypoint positions and contour points.
Point = Tuple[int, int]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class IKeyPointsDataContext(ABC):
    """
    Per-frame unit of work that collects keypoint detections.

    An instance covers exactly one frame and is committed automatically
    once the processing delegate returns.
    """

    @property
    @abstractmethod
    def frame_id(self) -> int:
        """ID of the frame this context belongs to."""
        ...

    @abstractmethod
    def add(self, point: KeyPoint, x: int, y: int, confidence: float) -> None:
        """
        Record one keypoint detection for the current frame.

        Args:
            point: Keypoint handle obtained from the schema definition
            x: Horizontal pixel coordinate
            y: Vertical pixel coordinate
            confidence: Detection confidence, from 0.0 to 1.0
        """
        ...

    @abstractmethod
    def add_point(self, point: KeyPoint, position: Point, confidence: float) -> None:
        """
        Record one keypoint detection given as a single (x, y) tuple.

        Args:
            point: Keypoint handle obtained from the schema definition
            position: Pixel position as an (x, y) tuple
            confidence: Detection confidence, from 0.0 to 1.0
        """
        ...
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class ISegmentationDataContext(ABC):
    """
    Per-frame unit of work that collects segmentation instances.

    An instance covers exactly one frame and is committed automatically
    once the processing delegate returns.
    """

    @property
    @abstractmethod
    def frame_id(self) -> int:
        """ID of the frame this context belongs to."""
        ...

    @abstractmethod
    def add(
        self,
        segment_class: SegmentClass,
        instance_id: int,
        points: Union[Sequence[Point], npt.NDArray[np.int32]],
    ) -> None:
        """
        Record one segmentation instance for the current frame.

        Args:
            segment_class: Class handle obtained from the schema definition
            instance_id: Distinguishes several instances of the same class (0-255)
            points: Contour points outlining the instance boundary
        """
        ...
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class KeyPointsDataContext(IKeyPointsDataContext):
    """
    Keypoints data context that forwards detections to an IKeyPointsWriter.

    Each add()/add_point() call is passed straight to the writer; commit()
    closes the writer.
    """

    def __init__(
        self,
        frame_id: int,
        writer: IKeyPointsWriter,
    ) -> None:
        """
        Args:
            frame_id: ID of the frame this context is scoped to.
            writer: Writer that receives this frame's keypoints; closed by commit().
        """
        self._frame_id = frame_id
        self._writer = writer

    @property
    def frame_id(self) -> int:
        """ID of the frame this context is scoped to."""
        return self._frame_id

    def add(self, point: KeyPoint, x: int, y: int, confidence: float) -> None:
        """Add a keypoint detection for this frame (forwarded as point.id, x, y, confidence)."""
        self._writer.append(point.id, x, y, confidence)

    def add_point(self, point: KeyPoint, position: Point, confidence: float) -> None:
        """Add a keypoint detection using an (x, y) Point tuple."""
        self._writer.append_point(point.id, position, confidence)

    def commit(self) -> None:
        """Commit the context by closing the writer (called automatically when the delegate returns)."""
        self._writer.close()
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class SegmentationDataContext(ISegmentationDataContext):
    """
    Segmentation data context that forwards instances to a SegmentationResultWriter.

    add() validates instance_id and passes each instance straight to the
    writer; commit() closes the writer.
    """

    def __init__(
        self,
        frame_id: int,
        writer: SegmentationResultWriter,
    ) -> None:
        """
        Args:
            frame_id: ID of the frame this context is scoped to.
            writer: Writer that receives this frame's instances; closed by commit().
        """
        self._frame_id = frame_id
        self._writer = writer

    @property
    def frame_id(self) -> int:
        """ID of the frame this context is scoped to."""
        return self._frame_id

    def add(
        self,
        segment_class: SegmentClass,
        instance_id: int,
        points: Union[Sequence[Point], npt.NDArray[np.int32]],
    ) -> None:
        """
        Add a segmentation instance for this frame.

        Args:
            segment_class: SegmentClass from schema definition
            instance_id: Instance ID (for multiple instances of same class, 0-255)
            points: Contour points defining the instance boundary

        Raises:
            ValueError: If instance_id is outside 0-255.
        """
        if instance_id < 0 or instance_id > 255:
            raise ValueError(f"instance_id must be 0-255, got {instance_id}")

        # Sequences of (x, y) tuples are converted to an int32 array; arrays are
        # forwarded unchanged.
        # NOTE(review): an ndarray whose dtype is not int32 (e.g. int64) is not
        # converted — confirm the writer accepts it.
        if isinstance(points, np.ndarray):
            points_array = points
        else:
            points_array = np.array(points, dtype=np.int32)

        self._writer.append(segment_class.class_id, instance_id, points_array)

    def commit(self) -> None:
        """Commit the context by closing the writer (called automatically when the delegate returns)."""
        self._writer.close()
|
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Schema types for KeyPoints and Segmentation.
|
|
3
|
+
|
|
4
|
+
Provides type-safe definitions for keypoints and segmentation classes
|
|
5
|
+
that are defined at initialization time and used during processing.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import json
|
|
11
|
+
from abc import ABC, abstractmethod
|
|
12
|
+
from dataclasses import dataclass
|
|
13
|
+
from typing import Dict, List
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass(frozen=True)
class KeyPoint:
    """
    Immutable handle naming one keypoint and carrying its numeric ID.

    Obtained from IKeyPointsSchema.define_point() and passed back whenever a
    keypoint detection is added to a data context.
    """

    id: int  # numeric identifier assigned when the point is defined
    name: str  # human-readable label, e.g. "nose"

    def __str__(self) -> str:
        return f"KeyPoint({self.id}, '{self.name}')"
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass(frozen=True)
class SegmentClass:
    """
    Immutable handle naming one segmentation class and carrying its class ID.

    Obtained from ISegmentationSchema.define_class() and passed back whenever
    a segmentation instance is added to a data context.
    """

    class_id: int  # class identifier chosen by the caller at definition time
    name: str  # human-readable label, e.g. "person"

    def __str__(self) -> str:
        return f"SegmentClass({self.class_id}, '{self.name}')"
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class IKeyPointsSchema(ABC):
    """
    Contract for declaring the set of keypoints a module emits.

    Points are declared once during initialization; the returned handles are
    then used to refer to them when data is added to a context.
    """

    @abstractmethod
    def define_point(self, name: str) -> KeyPoint:
        """
        Declare a new keypoint.

        Args:
            name: Human-readable keypoint label (for example "nose" or "left_eye")

        Returns:
            KeyPoint handle to pass to IKeyPointsDataContext.add()
        """
        ...

    @property
    @abstractmethod
    def defined_points(self) -> List[KeyPoint]:
        """All keypoints declared so far."""
        ...

    @abstractmethod
    def get_metadata_json(self) -> str:
        """Serialize the schema as a JSON metadata string."""
        ...
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class ISegmentationSchema(ABC):
    """
    Contract for declaring the set of segmentation classes a module emits.

    Classes are declared once during initialization; the returned handles are
    then used to refer to them when instances are added to a context.
    """

    @abstractmethod
    def define_class(self, class_id: int, name: str) -> SegmentClass:
        """
        Declare a new segmentation class.

        Args:
            class_id: Unique class identifier in the range 0-255
            name: Human-readable class label (for example "person" or "car")

        Returns:
            SegmentClass handle to pass to ISegmentationDataContext.add()
        """
        ...

    @property
    @abstractmethod
    def defined_classes(self) -> List[SegmentClass]:
        """All segmentation classes declared so far."""
        ...

    @abstractmethod
    def get_metadata_json(self) -> str:
        """Serialize the schema as a JSON metadata string."""
        ...
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class KeyPointsSchema(IKeyPointsSchema):
    """
    Default keypoints schema.

    IDs are handed out sequentially, starting at 0, in the order in which
    points are defined; names must be unique.
    """

    def __init__(self) -> None:
        # name -> KeyPoint, kept in definition order
        self._points: Dict[str, KeyPoint] = {}
        # ID that the next define_point() call will assign
        self._next_id = 0

    def define_point(self, name: str) -> KeyPoint:
        """
        Declare a keypoint called ``name`` and return its handle.

        Raises:
            ValueError: If a keypoint with this name already exists.
        """
        if name in self._points:
            raise ValueError(f"Keypoint '{name}' already defined")

        handle = self._points[name] = KeyPoint(id=self._next_id, name=name)
        self._next_id += 1
        return handle

    @property
    def defined_points(self) -> List[KeyPoint]:
        """All keypoints declared so far, in definition order."""
        return [*self._points.values()]

    def get_metadata_json(self) -> str:
        """Serialize the schema as indented JSON mapping point names to IDs."""
        metadata = {
            "version": "1.0",
            "compute_module_name": "",
            "points": {label: kp.id for label, kp in self._points.items()},
        }
        return json.dumps(metadata, indent=2)
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class SegmentationSchema(ISegmentationSchema):
    """
    Default segmentation schema.

    Class IDs are supplied by the caller, must lie in 0-255, and must be unique.
    """

    def __init__(self) -> None:
        # class_id -> SegmentClass, kept in definition order
        self._classes: Dict[int, SegmentClass] = {}

    def define_class(self, class_id: int, name: str) -> SegmentClass:
        """
        Declare a segmentation class and return its handle.

        Raises:
            ValueError: If class_id is outside 0-255 or already defined.
        """
        if class_id < 0 or class_id > 255:
            raise ValueError(f"class_id must be 0-255, got {class_id}")
        if class_id in self._classes:
            raise ValueError(f"Class ID {class_id} already defined")

        handle = self._classes[class_id] = SegmentClass(class_id=class_id, name=name)
        return handle

    @property
    def defined_classes(self) -> List[SegmentClass]:
        """All classes declared so far, in definition order."""
        return [*self._classes.values()]

    def get_metadata_json(self) -> str:
        """Serialize the schema as indented JSON mapping stringified class IDs to names."""
        metadata = {
            "version": "1.0",
            "classes": {str(cid): sc.name for cid, sc in self._classes.items()},
        }
        return json.dumps(metadata, indent=2)
|