rocket-welder-sdk 1.1.34.dev11__py3-none-any.whl → 1.1.34rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rocket_welder_sdk/__init__.py +1 -20
- rocket_welder_sdk/controllers.py +3 -7
- rocket_welder_sdk/frame_metadata.py +2 -2
- rocket_welder_sdk/high_level/__init__.py +25 -11
- rocket_welder_sdk/high_level/connection_strings.py +55 -56
- rocket_welder_sdk/high_level/data_context.py +11 -17
- rocket_welder_sdk/high_level/schema.py +27 -44
- rocket_welder_sdk/high_level/transport_protocol.py +108 -180
- rocket_welder_sdk/rocket_welder_client.py +14 -34
- rocket_welder_sdk/session_id.py +0 -123
- rocket_welder_sdk/transport/__init__.py +8 -0
- {rocket_welder_sdk-1.1.34.dev11.dist-info → rocket_welder_sdk-1.1.34rc1.dist-info}/METADATA +1 -1
- {rocket_welder_sdk-1.1.34.dev11.dist-info → rocket_welder_sdk-1.1.34rc1.dist-info}/RECORD +15 -16
- rocket_welder_sdk/high_level/client.py +0 -262
- {rocket_welder_sdk-1.1.34.dev11.dist-info → rocket_welder_sdk-1.1.34rc1.dist-info}/WHEEL +0 -0
- {rocket_welder_sdk-1.1.34.dev11.dist-info → rocket_welder_sdk-1.1.34rc1.dist-info}/top_level.txt +0 -0
rocket_welder_sdk/__init__.py
CHANGED
|
@@ -16,22 +16,11 @@ from .opencv_controller import OpenCvController
|
|
|
16
16
|
from .periodic_timer import PeriodicTimer, PeriodicTimerSync
|
|
17
17
|
from .rocket_welder_client import RocketWelderClient
|
|
18
18
|
from .session_id import (
|
|
19
|
-
# Explicit URL functions (PREFERRED - set by rocket-welder2)
|
|
20
|
-
ACTIONS_SINK_URL_ENV,
|
|
21
|
-
KEYPOINTS_SINK_URL_ENV,
|
|
22
|
-
SEGMENTATION_SINK_URL_ENV,
|
|
23
|
-
# SessionId-derived URL functions (fallback for backwards compatibility)
|
|
24
19
|
get_actions_url,
|
|
25
|
-
get_actions_url_from_env,
|
|
26
|
-
get_configured_nng_urls,
|
|
27
20
|
get_keypoints_url,
|
|
28
|
-
get_keypoints_url_from_env,
|
|
29
21
|
get_nng_urls,
|
|
30
|
-
get_nng_urls_from_env,
|
|
31
22
|
get_segmentation_url,
|
|
32
|
-
get_segmentation_url_from_env,
|
|
33
23
|
get_session_id_from_env,
|
|
34
|
-
has_explicit_nng_urls,
|
|
35
24
|
parse_session_id,
|
|
36
25
|
)
|
|
37
26
|
|
|
@@ -60,10 +49,7 @@ if _log_level:
|
|
|
60
49
|
pass # Invalid log level, ignore
|
|
61
50
|
|
|
62
51
|
__all__ = [
|
|
63
|
-
"ACTIONS_SINK_URL_ENV",
|
|
64
52
|
"FRAME_METADATA_SIZE",
|
|
65
|
-
"KEYPOINTS_SINK_URL_ENV",
|
|
66
|
-
"SEGMENTATION_SINK_URL_ENV",
|
|
67
53
|
"BytesSize",
|
|
68
54
|
"Client",
|
|
69
55
|
"ConnectionMode",
|
|
@@ -80,16 +66,11 @@ __all__ = [
|
|
|
80
66
|
"PeriodicTimerSync",
|
|
81
67
|
"Protocol",
|
|
82
68
|
"RocketWelderClient",
|
|
69
|
+
# SessionId utilities for NNG URL generation
|
|
83
70
|
"get_actions_url",
|
|
84
|
-
"get_actions_url_from_env",
|
|
85
|
-
"get_configured_nng_urls",
|
|
86
71
|
"get_keypoints_url",
|
|
87
|
-
"get_keypoints_url_from_env",
|
|
88
72
|
"get_nng_urls",
|
|
89
|
-
"get_nng_urls_from_env",
|
|
90
73
|
"get_segmentation_url",
|
|
91
|
-
"get_segmentation_url_from_env",
|
|
92
74
|
"get_session_id_from_env",
|
|
93
|
-
"has_explicit_nng_urls",
|
|
94
75
|
"parse_session_id",
|
|
95
76
|
]
|
rocket_welder_sdk/controllers.py
CHANGED
|
@@ -338,7 +338,7 @@ class OneWayShmController(IController):
|
|
|
338
338
|
Matches C# CreateMat behavior - creates Mat wrapping the data.
|
|
339
339
|
|
|
340
340
|
Frame data layout from GStreamer zerosink:
|
|
341
|
-
[FrameMetadata (16 bytes)][Pixel Data (
|
|
341
|
+
[FrameMetadata (16 bytes)][Pixel Data (W×H×C bytes)]
|
|
342
342
|
|
|
343
343
|
Args:
|
|
344
344
|
frame: ZeroBuffer frame
|
|
@@ -438,9 +438,7 @@ class OneWayShmController(IController):
|
|
|
438
438
|
sqrt_pixels = math.sqrt(pixels)
|
|
439
439
|
if sqrt_pixels == int(sqrt_pixels):
|
|
440
440
|
dimension = int(sqrt_pixels)
|
|
441
|
-
logger.info(
|
|
442
|
-
f"Pixel data size {pixel_data_size} suggests {dimension}x{dimension} RGB"
|
|
443
|
-
)
|
|
441
|
+
logger.info(f"Pixel data size {pixel_data_size} suggests {dimension}x{dimension} RGB")
|
|
444
442
|
pixel_data = np.frombuffer(frame.data[FRAME_METADATA_SIZE:], dtype=np.uint8)
|
|
445
443
|
return pixel_data.reshape((dimension, dimension, 3)) # type: ignore[no-any-return]
|
|
446
444
|
|
|
@@ -450,9 +448,7 @@ class OneWayShmController(IController):
|
|
|
450
448
|
sqrt_pixels = math.sqrt(pixels)
|
|
451
449
|
if sqrt_pixels == int(sqrt_pixels):
|
|
452
450
|
dimension = int(sqrt_pixels)
|
|
453
|
-
logger.info(
|
|
454
|
-
f"Pixel data size {pixel_data_size} suggests {dimension}x{dimension} RGBA"
|
|
455
|
-
)
|
|
451
|
+
logger.info(f"Pixel data size {pixel_data_size} suggests {dimension}x{dimension} RGBA")
|
|
456
452
|
pixel_data = np.frombuffer(frame.data[FRAME_METADATA_SIZE:], dtype=np.uint8)
|
|
457
453
|
return pixel_data.reshape((dimension, dimension, 4)) # type: ignore[no-any-return]
|
|
458
454
|
|
|
@@ -18,7 +18,7 @@ from __future__ import annotations
|
|
|
18
18
|
|
|
19
19
|
import struct
|
|
20
20
|
from dataclasses import dataclass
|
|
21
|
-
from typing import
|
|
21
|
+
from typing import Optional
|
|
22
22
|
|
|
23
23
|
# Size of the FrameMetadata structure in bytes
|
|
24
24
|
FRAME_METADATA_SIZE = 16
|
|
@@ -113,7 +113,7 @@ class GstVideoFormat:
|
|
|
113
113
|
GRAY16_BE = 26
|
|
114
114
|
GRAY16_LE = 27
|
|
115
115
|
|
|
116
|
-
_FORMAT_NAMES:
|
|
116
|
+
_FORMAT_NAMES: dict[int, str] = {
|
|
117
117
|
0: "UNKNOWN",
|
|
118
118
|
2: "I420",
|
|
119
119
|
3: "YV12",
|
|
@@ -1,18 +1,26 @@
|
|
|
1
1
|
"""
|
|
2
2
|
High-level API for RocketWelder SDK.
|
|
3
3
|
|
|
4
|
-
|
|
4
|
+
Provides a simplified, user-friendly API for common video processing workflows
|
|
5
|
+
with automatic transport management and schema definitions.
|
|
5
6
|
|
|
6
7
|
Example:
|
|
7
|
-
from rocket_welder_sdk.high_level import RocketWelderClient
|
|
8
|
+
from rocket_welder_sdk.high_level import RocketWelderClient, Transport
|
|
8
9
|
|
|
9
|
-
with RocketWelderClient.from_environment() as client:
|
|
10
|
+
async with RocketWelderClient.from_environment() as client:
|
|
11
|
+
# Define keypoints schema
|
|
10
12
|
nose = client.keypoints.define_point("nose")
|
|
13
|
+
left_eye = client.keypoints.define_point("left_eye")
|
|
14
|
+
|
|
15
|
+
# Define segmentation classes
|
|
11
16
|
person = client.segmentation.define_class(1, "person")
|
|
12
|
-
|
|
17
|
+
|
|
18
|
+
async for input_frame, seg_ctx, kp_ctx, output_frame in client.start():
|
|
19
|
+
# Process frame...
|
|
20
|
+
kp_ctx.add(nose, x=100, y=200, confidence=0.95)
|
|
21
|
+
seg_ctx.add(person, instance_id=0, points=contour_points)
|
|
13
22
|
"""
|
|
14
23
|
|
|
15
|
-
from .client import RocketWelderClient, RocketWelderClientOptions
|
|
16
24
|
from .connection_strings import (
|
|
17
25
|
KeyPointsConnectionString,
|
|
18
26
|
SegmentationConnectionString,
|
|
@@ -26,11 +34,15 @@ from .data_context import (
|
|
|
26
34
|
from .schema import (
|
|
27
35
|
IKeyPointsSchema,
|
|
28
36
|
ISegmentationSchema,
|
|
29
|
-
|
|
37
|
+
KeyPoint,
|
|
30
38
|
SegmentClass,
|
|
31
39
|
)
|
|
32
40
|
from .transport_protocol import (
|
|
33
|
-
|
|
41
|
+
MessagingLibrary,
|
|
42
|
+
MessagingPattern,
|
|
43
|
+
Transport,
|
|
44
|
+
TransportBuilder,
|
|
45
|
+
TransportLayer,
|
|
34
46
|
TransportProtocol,
|
|
35
47
|
)
|
|
36
48
|
|
|
@@ -39,13 +51,15 @@ __all__ = [
|
|
|
39
51
|
"IKeyPointsSchema",
|
|
40
52
|
"ISegmentationDataContext",
|
|
41
53
|
"ISegmentationSchema",
|
|
42
|
-
"
|
|
54
|
+
"KeyPoint",
|
|
43
55
|
"KeyPointsConnectionString",
|
|
44
|
-
"
|
|
45
|
-
"
|
|
56
|
+
"MessagingLibrary",
|
|
57
|
+
"MessagingPattern",
|
|
46
58
|
"SegmentClass",
|
|
47
59
|
"SegmentationConnectionString",
|
|
48
|
-
"
|
|
60
|
+
"Transport",
|
|
61
|
+
"TransportBuilder",
|
|
62
|
+
"TransportLayer",
|
|
49
63
|
"TransportProtocol",
|
|
50
64
|
"VideoSourceConnectionString",
|
|
51
65
|
"VideoSourceType",
|
|
@@ -6,8 +6,7 @@ Connection string format: protocol://path?param1=value1&param2=value2
|
|
|
6
6
|
Examples:
|
|
7
7
|
nng+push+ipc://tmp/keypoints?masterFrameInterval=300
|
|
8
8
|
nng+pub+tcp://localhost:5555
|
|
9
|
-
file
|
|
10
|
-
socket:///tmp/my.sock
|
|
9
|
+
file://path/to/output.bin
|
|
11
10
|
"""
|
|
12
11
|
|
|
13
12
|
from __future__ import annotations
|
|
@@ -145,19 +144,19 @@ class KeyPointsConnectionString:
|
|
|
145
144
|
"""
|
|
146
145
|
Strongly-typed connection string for KeyPoints output.
|
|
147
146
|
|
|
148
|
-
Supported protocols:
|
|
149
|
-
-
|
|
150
|
-
-
|
|
151
|
-
-
|
|
152
|
-
- nng+push+tcp://host:port - NNG Push over TCP
|
|
147
|
+
Supported protocols (composable with + operator):
|
|
148
|
+
- Transport.Nng + Transport.Push + Transport.Ipc → nng+push+ipc://tmp/keypoints
|
|
149
|
+
- Transport.Nng + Transport.Push + Transport.Tcp → nng+push+tcp://host:port
|
|
150
|
+
- file://path/to/file.bin - File output
|
|
153
151
|
|
|
154
152
|
Supported parameters:
|
|
155
153
|
- masterFrameInterval: Interval between master frames (default: 300)
|
|
156
154
|
"""
|
|
157
155
|
|
|
158
156
|
value: str
|
|
159
|
-
protocol: TransportProtocol
|
|
160
|
-
|
|
157
|
+
protocol: Optional[TransportProtocol] = None
|
|
158
|
+
is_file: bool = False
|
|
159
|
+
address: str = ""
|
|
161
160
|
master_frame_interval: int = 300
|
|
162
161
|
parameters: Dict[str, str] = field(default_factory=dict)
|
|
163
162
|
|
|
@@ -200,26 +199,25 @@ class KeyPointsConnectionString:
|
|
|
200
199
|
|
|
201
200
|
# Parse protocol and address
|
|
202
201
|
scheme_end = endpoint_part.find("://")
|
|
203
|
-
if scheme_end
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
elif
|
|
218
|
-
#
|
|
219
|
-
address =
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
address = protocol.create_nng_address(path_part)
|
|
202
|
+
if scheme_end > 0:
|
|
203
|
+
protocol_str = endpoint_part[:scheme_end]
|
|
204
|
+
path_part = endpoint_part[scheme_end + 3 :] # skip "://"
|
|
205
|
+
|
|
206
|
+
if protocol_str.lower() == "file":
|
|
207
|
+
address = "/" + path_part # Restore absolute path
|
|
208
|
+
is_file = True
|
|
209
|
+
protocol = None
|
|
210
|
+
else:
|
|
211
|
+
protocol = TransportProtocol.try_parse(protocol_str)
|
|
212
|
+
if protocol is None:
|
|
213
|
+
return None
|
|
214
|
+
address = protocol.create_nng_address(path_part)
|
|
215
|
+
is_file = False
|
|
216
|
+
elif s.startswith("/"):
|
|
217
|
+
# Assume absolute file path
|
|
218
|
+
address = s
|
|
219
|
+
is_file = True
|
|
220
|
+
protocol = None
|
|
223
221
|
else:
|
|
224
222
|
return None
|
|
225
223
|
|
|
@@ -232,6 +230,7 @@ class KeyPointsConnectionString:
|
|
|
232
230
|
return cls(
|
|
233
231
|
value=s,
|
|
234
232
|
protocol=protocol,
|
|
233
|
+
is_file=is_file,
|
|
235
234
|
address=address,
|
|
236
235
|
master_frame_interval=master_frame_interval,
|
|
237
236
|
parameters=parameters,
|
|
@@ -246,16 +245,16 @@ class SegmentationConnectionString:
|
|
|
246
245
|
"""
|
|
247
246
|
Strongly-typed connection string for Segmentation output.
|
|
248
247
|
|
|
249
|
-
Supported protocols:
|
|
250
|
-
-
|
|
251
|
-
-
|
|
252
|
-
-
|
|
253
|
-
- nng+push+tcp://host:port - NNG Push over TCP
|
|
248
|
+
Supported protocols (composable with + operator):
|
|
249
|
+
- Transport.Nng + Transport.Push + Transport.Ipc → nng+push+ipc://tmp/segmentation
|
|
250
|
+
- Transport.Nng + Transport.Push + Transport.Tcp → nng+push+tcp://host:port
|
|
251
|
+
- file://path/to/file.bin - File output
|
|
254
252
|
"""
|
|
255
253
|
|
|
256
254
|
value: str
|
|
257
|
-
protocol: TransportProtocol
|
|
258
|
-
|
|
255
|
+
protocol: Optional[TransportProtocol] = None
|
|
256
|
+
is_file: bool = False
|
|
257
|
+
address: str = ""
|
|
259
258
|
parameters: Dict[str, str] = field(default_factory=dict)
|
|
260
259
|
|
|
261
260
|
@classmethod
|
|
@@ -297,32 +296,32 @@ class SegmentationConnectionString:
|
|
|
297
296
|
|
|
298
297
|
# Parse protocol and address
|
|
299
298
|
scheme_end = endpoint_part.find("://")
|
|
300
|
-
if scheme_end
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
elif
|
|
315
|
-
#
|
|
316
|
-
address =
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
address = protocol.create_nng_address(path_part)
|
|
299
|
+
if scheme_end > 0:
|
|
300
|
+
protocol_str = endpoint_part[:scheme_end]
|
|
301
|
+
path_part = endpoint_part[scheme_end + 3 :] # skip "://"
|
|
302
|
+
|
|
303
|
+
if protocol_str.lower() == "file":
|
|
304
|
+
address = "/" + path_part # Restore absolute path
|
|
305
|
+
is_file = True
|
|
306
|
+
protocol = None
|
|
307
|
+
else:
|
|
308
|
+
protocol = TransportProtocol.try_parse(protocol_str)
|
|
309
|
+
if protocol is None:
|
|
310
|
+
return None
|
|
311
|
+
address = protocol.create_nng_address(path_part)
|
|
312
|
+
is_file = False
|
|
313
|
+
elif s.startswith("/"):
|
|
314
|
+
# Assume absolute file path
|
|
315
|
+
address = s
|
|
316
|
+
is_file = True
|
|
317
|
+
protocol = None
|
|
320
318
|
else:
|
|
321
319
|
return None
|
|
322
320
|
|
|
323
321
|
return cls(
|
|
324
322
|
value=s,
|
|
325
323
|
protocol=protocol,
|
|
324
|
+
is_file=is_file,
|
|
326
325
|
address=address,
|
|
327
326
|
parameters=parameters,
|
|
328
327
|
)
|
|
@@ -17,7 +17,7 @@ if TYPE_CHECKING:
|
|
|
17
17
|
from rocket_welder_sdk.keypoints_protocol import IKeyPointsWriter
|
|
18
18
|
from rocket_welder_sdk.segmentation_result import SegmentationResultWriter
|
|
19
19
|
|
|
20
|
-
from .schema import
|
|
20
|
+
from .schema import KeyPoint, SegmentClass
|
|
21
21
|
|
|
22
22
|
# Type aliases
|
|
23
23
|
Point = Tuple[int, int]
|
|
@@ -37,12 +37,12 @@ class IKeyPointsDataContext(ABC):
|
|
|
37
37
|
pass
|
|
38
38
|
|
|
39
39
|
@abstractmethod
|
|
40
|
-
def add(self, point:
|
|
40
|
+
def add(self, point: KeyPoint, x: int, y: int, confidence: float) -> None:
|
|
41
41
|
"""
|
|
42
42
|
Add a keypoint detection for this frame.
|
|
43
43
|
|
|
44
44
|
Args:
|
|
45
|
-
point:
|
|
45
|
+
point: KeyPoint from schema definition
|
|
46
46
|
x: X coordinate in pixels
|
|
47
47
|
y: Y coordinate in pixels
|
|
48
48
|
confidence: Detection confidence (0.0 to 1.0)
|
|
@@ -50,22 +50,17 @@ class IKeyPointsDataContext(ABC):
|
|
|
50
50
|
pass
|
|
51
51
|
|
|
52
52
|
@abstractmethod
|
|
53
|
-
def add_point(self, point:
|
|
53
|
+
def add_point(self, point: KeyPoint, position: Point, confidence: float) -> None:
|
|
54
54
|
"""
|
|
55
55
|
Add a keypoint detection using a Point tuple.
|
|
56
56
|
|
|
57
57
|
Args:
|
|
58
|
-
point:
|
|
58
|
+
point: KeyPoint from schema definition
|
|
59
59
|
position: (x, y) tuple
|
|
60
60
|
confidence: Detection confidence (0.0 to 1.0)
|
|
61
61
|
"""
|
|
62
62
|
pass
|
|
63
63
|
|
|
64
|
-
@abstractmethod
|
|
65
|
-
def commit(self) -> None:
|
|
66
|
-
"""Commit the context (called automatically when delegate returns)."""
|
|
67
|
-
pass
|
|
68
|
-
|
|
69
64
|
|
|
70
65
|
class ISegmentationDataContext(ABC):
|
|
71
66
|
"""
|
|
@@ -97,11 +92,6 @@ class ISegmentationDataContext(ABC):
|
|
|
97
92
|
"""
|
|
98
93
|
pass
|
|
99
94
|
|
|
100
|
-
@abstractmethod
|
|
101
|
-
def commit(self) -> None:
|
|
102
|
-
"""Commit the context (called automatically when delegate returns)."""
|
|
103
|
-
pass
|
|
104
|
-
|
|
105
95
|
|
|
106
96
|
class KeyPointsDataContext(IKeyPointsDataContext):
|
|
107
97
|
"""Implementation of keypoints data context."""
|
|
@@ -111,6 +101,8 @@ class KeyPointsDataContext(IKeyPointsDataContext):
|
|
|
111
101
|
frame_id: int,
|
|
112
102
|
writer: IKeyPointsWriter,
|
|
113
103
|
) -> None:
|
|
104
|
+
from .schema import KeyPoint # noqa: F401
|
|
105
|
+
|
|
114
106
|
self._frame_id = frame_id
|
|
115
107
|
self._writer = writer
|
|
116
108
|
|
|
@@ -118,11 +110,11 @@ class KeyPointsDataContext(IKeyPointsDataContext):
|
|
|
118
110
|
def frame_id(self) -> int:
|
|
119
111
|
return self._frame_id
|
|
120
112
|
|
|
121
|
-
def add(self, point:
|
|
113
|
+
def add(self, point: KeyPoint, x: int, y: int, confidence: float) -> None:
|
|
122
114
|
"""Add a keypoint detection for this frame."""
|
|
123
115
|
self._writer.append(point.id, x, y, confidence)
|
|
124
116
|
|
|
125
|
-
def add_point(self, point:
|
|
117
|
+
def add_point(self, point: KeyPoint, position: Point, confidence: float) -> None:
|
|
126
118
|
"""Add a keypoint detection using a Point tuple."""
|
|
127
119
|
self._writer.append_point(point.id, position, confidence)
|
|
128
120
|
|
|
@@ -139,6 +131,8 @@ class SegmentationDataContext(ISegmentationDataContext):
|
|
|
139
131
|
frame_id: int,
|
|
140
132
|
writer: SegmentationResultWriter,
|
|
141
133
|
) -> None:
|
|
134
|
+
from .schema import SegmentClass # noqa: F401
|
|
135
|
+
|
|
142
136
|
self._frame_id = frame_id
|
|
143
137
|
self._writer = writer
|
|
144
138
|
|
|
@@ -10,11 +10,11 @@ from __future__ import annotations
|
|
|
10
10
|
import json
|
|
11
11
|
from abc import ABC, abstractmethod
|
|
12
12
|
from dataclasses import dataclass
|
|
13
|
-
from typing import Dict, List
|
|
13
|
+
from typing import Dict, List
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
@dataclass(frozen=True)
|
|
17
|
-
class
|
|
17
|
+
class KeyPoint:
|
|
18
18
|
"""
|
|
19
19
|
A keypoint definition with ID and name.
|
|
20
20
|
|
|
@@ -26,7 +26,7 @@ class KeyPointDefinition:
|
|
|
26
26
|
name: str
|
|
27
27
|
|
|
28
28
|
def __str__(self) -> str:
|
|
29
|
-
return f"
|
|
29
|
+
return f"KeyPoint({self.id}, '{self.name}')"
|
|
30
30
|
|
|
31
31
|
|
|
32
32
|
@dataclass(frozen=True)
|
|
@@ -54,7 +54,7 @@ class IKeyPointsSchema(ABC):
|
|
|
54
54
|
"""
|
|
55
55
|
|
|
56
56
|
@abstractmethod
|
|
57
|
-
def define_point(self, name: str) ->
|
|
57
|
+
def define_point(self, name: str) -> KeyPoint:
|
|
58
58
|
"""
|
|
59
59
|
Define a new keypoint.
|
|
60
60
|
|
|
@@ -62,13 +62,13 @@ class IKeyPointsSchema(ABC):
|
|
|
62
62
|
name: Human-readable name for the keypoint (e.g., "nose", "left_eye")
|
|
63
63
|
|
|
64
64
|
Returns:
|
|
65
|
-
|
|
65
|
+
KeyPoint handle for use with IKeyPointsDataContext.add()
|
|
66
66
|
"""
|
|
67
67
|
pass
|
|
68
68
|
|
|
69
69
|
@property
|
|
70
70
|
@abstractmethod
|
|
71
|
-
def defined_points(self) -> List[
|
|
71
|
+
def defined_points(self) -> List[KeyPoint]:
|
|
72
72
|
"""Get all defined keypoints."""
|
|
73
73
|
pass
|
|
74
74
|
|
|
@@ -116,41 +116,34 @@ class KeyPointsSchema(IKeyPointsSchema):
|
|
|
116
116
|
"""Implementation of keypoints schema."""
|
|
117
117
|
|
|
118
118
|
def __init__(self) -> None:
|
|
119
|
-
self._points: Dict[str,
|
|
119
|
+
self._points: Dict[str, KeyPoint] = {}
|
|
120
120
|
self._next_id = 0
|
|
121
121
|
|
|
122
|
-
def define_point(self, name: str) ->
|
|
122
|
+
def define_point(self, name: str) -> KeyPoint:
|
|
123
123
|
"""Define a new keypoint."""
|
|
124
124
|
if name in self._points:
|
|
125
125
|
raise ValueError(f"Keypoint '{name}' already defined")
|
|
126
126
|
|
|
127
|
-
point =
|
|
127
|
+
point = KeyPoint(id=self._next_id, name=name)
|
|
128
128
|
self._points[name] = point
|
|
129
129
|
self._next_id += 1
|
|
130
130
|
return point
|
|
131
131
|
|
|
132
132
|
@property
|
|
133
|
-
def defined_points(self) -> List[
|
|
133
|
+
def defined_points(self) -> List[KeyPoint]:
|
|
134
134
|
"""Get all defined keypoints."""
|
|
135
135
|
return list(self._points.values())
|
|
136
136
|
|
|
137
137
|
def get_metadata_json(self) -> str:
|
|
138
|
-
"""
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
"""
|
|
148
|
-
metadata: Dict[str, Any] = {
|
|
149
|
-
"version": 1,
|
|
150
|
-
"type": "keypoints",
|
|
151
|
-
"points": [{"id": p.id, "name": p.name} for p in self._points.values()],
|
|
152
|
-
}
|
|
153
|
-
return json.dumps(metadata, indent=2)
|
|
138
|
+
"""Get JSON metadata for serialization."""
|
|
139
|
+
return json.dumps(
|
|
140
|
+
{
|
|
141
|
+
"version": "1.0",
|
|
142
|
+
"compute_module_name": "",
|
|
143
|
+
"points": {p.name: p.id for p in self._points.values()},
|
|
144
|
+
},
|
|
145
|
+
indent=2,
|
|
146
|
+
)
|
|
154
147
|
|
|
155
148
|
|
|
156
149
|
class SegmentationSchema(ISegmentationSchema):
|
|
@@ -177,21 +170,11 @@ class SegmentationSchema(ISegmentationSchema):
|
|
|
177
170
|
return list(self._classes.values())
|
|
178
171
|
|
|
179
172
|
def get_metadata_json(self) -> str:
|
|
180
|
-
"""
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
}
|
|
189
|
-
"""
|
|
190
|
-
metadata: Dict[str, Any] = {
|
|
191
|
-
"version": 1,
|
|
192
|
-
"type": "segmentation",
|
|
193
|
-
"classes": [
|
|
194
|
-
{"classId": c.class_id, "name": c.name} for c in self._classes.values()
|
|
195
|
-
],
|
|
196
|
-
}
|
|
197
|
-
return json.dumps(metadata, indent=2)
|
|
173
|
+
"""Get JSON metadata for serialization."""
|
|
174
|
+
return json.dumps(
|
|
175
|
+
{
|
|
176
|
+
"version": "1.0",
|
|
177
|
+
"classes": {str(c.class_id): c.name for c in self._classes.values()},
|
|
178
|
+
},
|
|
179
|
+
indent=2,
|
|
180
|
+
)
|