rocket-welder-sdk 1.1.36.dev14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rocket_welder_sdk/__init__.py +95 -0
- rocket_welder_sdk/bytes_size.py +234 -0
- rocket_welder_sdk/connection_string.py +291 -0
- rocket_welder_sdk/controllers.py +831 -0
- rocket_welder_sdk/external_controls/__init__.py +30 -0
- rocket_welder_sdk/external_controls/contracts.py +100 -0
- rocket_welder_sdk/external_controls/contracts_old.py +105 -0
- rocket_welder_sdk/frame_metadata.py +138 -0
- rocket_welder_sdk/gst_metadata.py +411 -0
- rocket_welder_sdk/high_level/__init__.py +54 -0
- rocket_welder_sdk/high_level/client.py +235 -0
- rocket_welder_sdk/high_level/connection_strings.py +331 -0
- rocket_welder_sdk/high_level/data_context.py +169 -0
- rocket_welder_sdk/high_level/frame_sink_factory.py +118 -0
- rocket_welder_sdk/high_level/schema.py +195 -0
- rocket_welder_sdk/high_level/transport_protocol.py +238 -0
- rocket_welder_sdk/keypoints_protocol.py +642 -0
- rocket_welder_sdk/opencv_controller.py +278 -0
- rocket_welder_sdk/periodic_timer.py +303 -0
- rocket_welder_sdk/py.typed +2 -0
- rocket_welder_sdk/rocket_welder_client.py +497 -0
- rocket_welder_sdk/segmentation_result.py +420 -0
- rocket_welder_sdk/session_id.py +238 -0
- rocket_welder_sdk/transport/__init__.py +31 -0
- rocket_welder_sdk/transport/frame_sink.py +122 -0
- rocket_welder_sdk/transport/frame_source.py +74 -0
- rocket_welder_sdk/transport/nng_transport.py +197 -0
- rocket_welder_sdk/transport/stream_transport.py +193 -0
- rocket_welder_sdk/transport/tcp_transport.py +154 -0
- rocket_welder_sdk/transport/unix_socket_transport.py +339 -0
- rocket_welder_sdk/ui/__init__.py +48 -0
- rocket_welder_sdk/ui/controls.py +362 -0
- rocket_welder_sdk/ui/icons.py +21628 -0
- rocket_welder_sdk/ui/ui_events_projection.py +226 -0
- rocket_welder_sdk/ui/ui_service.py +358 -0
- rocket_welder_sdk/ui/value_types.py +72 -0
- rocket_welder_sdk-1.1.36.dev14.dist-info/METADATA +845 -0
- rocket_welder_sdk-1.1.36.dev14.dist-info/RECORD +40 -0
- rocket_welder_sdk-1.1.36.dev14.dist-info/WHEEL +5 -0
- rocket_welder_sdk-1.1.36.dev14.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,331 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Strongly-typed connection strings with parsing support.
|
|
3
|
+
|
|
4
|
+
Connection string format: protocol://path?param1=value1&param2=value2
|
|
5
|
+
|
|
6
|
+
Examples:
|
|
7
|
+
nng+push+ipc://tmp/keypoints?masterFrameInterval=300
|
|
8
|
+
nng+pub+tcp://localhost:5555
|
|
9
|
+
file:///path/to/output.bin
|
|
10
|
+
socket:///tmp/my.sock
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import contextlib
|
|
16
|
+
import os
|
|
17
|
+
from dataclasses import dataclass, field
|
|
18
|
+
from enum import Enum, auto
|
|
19
|
+
from typing import Dict, Optional
|
|
20
|
+
from urllib.parse import parse_qs
|
|
21
|
+
|
|
22
|
+
from .transport_protocol import TransportProtocol
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class VideoSourceType(Enum):
    """Kind of input a VideoSourceConnectionString refers to."""

    # Explicit values are identical to what auto() assigns in declaration order.
    CAMERA = 1
    FILE = 2
    SHARED_MEMORY = 3
    RTSP = 4
    HTTP = 5
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass(frozen=True)
class VideoSourceConnectionString:
    """
    Strongly-typed connection string for video source input.

    Supported formats:
    - "0", "1", etc. - Camera device index
    - file:///path/to/video.mp4 or file://path/to/video.mp4 - Video file
      (both yield path "/path/to/video.mp4")
    - shm://buffer_name - Shared memory buffer
    - rtsp://host/stream - RTSP stream
    - http://host/stream or https://host/stream - HTTP stream
    - any other string without "://" - treated as a video file path

    An optional ``?key=value&...`` query is split off into ``parameters``
    (keys lower-cased, first value kept). Neither ``value`` nor ``path``
    includes the query. Scheme prefixes are matched case-insensitively.
    """

    value: str
    source_type: VideoSourceType
    camera_index: Optional[int] = None
    path: Optional[str] = None
    parameters: Dict[str, str] = field(default_factory=dict)

    @classmethod
    def default(cls) -> VideoSourceConnectionString:
        """Default video source (camera 0)."""
        return cls.parse("0")

    @classmethod
    def from_environment(cls, variable_name: str = "VIDEO_SOURCE") -> VideoSourceConnectionString:
        """
        Create from an environment variable, falling back to the default.

        ``variable_name`` is read first, then ``CONNECTION_STRING``. If both are
        unset or empty, :meth:`default` (camera 0) is returned.

        Raises:
            ValueError: If a non-empty value is found but cannot be parsed.
        """
        value = os.environ.get(variable_name) or os.environ.get("CONNECTION_STRING")
        return cls.parse(value) if value else cls.default()

    @classmethod
    def parse(cls, s: str) -> VideoSourceConnectionString:
        """
        Parse a connection string.

        Raises:
            ValueError: If ``s`` is empty or is not a recognised video source.
        """
        result = cls.try_parse(s)
        if result is None:
            raise ValueError(f"Invalid video source connection string: {s}")
        return result

    @classmethod
    def try_parse(cls, s: str) -> Optional[VideoSourceConnectionString]:
        """Try to parse a connection string; return None if it is not valid."""
        if not s or not s.strip():
            return None

        s = s.strip()
        parameters: Dict[str, str] = {}

        # Extract query parameters
        if "?" in s:
            base, query = s.split("?", 1)
            for key, values in parse_qs(query).items():
                parameters[key.lower()] = values[0] if values else ""
            s = base

        # Nothing is left once the query is removed (e.g. "?fps=30"), so there is no source.
        if not s:
            return None

        # Camera index. isdecimal() accepts only characters that int() can convert.
        # isdigit() also accepts e.g. superscripts ("²"), which made int() raise here.
        if s.isdecimal():
            return cls(
                value=s,
                source_type=VideoSourceType.CAMERA,
                camera_index=int(s),
                parameters=parameters,
            )

        if "://" not in s:
            # No scheme: assume a plain file path.
            return cls(
                value=s,
                source_type=VideoSourceType.FILE,
                path=s,
                parameters=parameters,
            )

        scheme, _, rest = s.partition("://")
        # URI schemes are case-insensitive (RFC 3986, section 3.1).
        scheme = scheme.lower()

        if scheme == "file":
            # file:///abs/path (RFC 8089 form) and file://abs/path both map to
            # /abs/path. Only prepend "/" when it is missing, to avoid "//abs/path".
            path = rest if rest.startswith("/") else "/" + rest
            return cls(
                value=s,
                source_type=VideoSourceType.FILE,
                path=path,
                parameters=parameters,
            )
        if scheme == "shm":
            return cls(
                value=s,
                source_type=VideoSourceType.SHARED_MEMORY,
                path=rest,
                parameters=parameters,
            )
        if scheme == "rtsp":
            # NOTE(review): the query string was stripped above and is only
            # available via `parameters`. Confirm this is intended for stream
            # URLs whose query is part of the address.
            return cls(
                value=s,
                source_type=VideoSourceType.RTSP,
                path=s,
                parameters=parameters,
            )
        if scheme in ("http", "https"):
            return cls(
                value=s,
                source_type=VideoSourceType.HTTP,
                path=s,
                parameters=parameters,
            )

        # Unknown scheme.
        return None

    def __str__(self) -> str:
        return self.value
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
@dataclass(frozen=True)
class KeyPointsConnectionString:
    """
    Typed connection string describing where KeyPoints output is sent.

    Accepted schemes:
    - file:///path/to/file.bin - write to a file (address is an absolute path)
    - socket:///tmp/socket.sock - Unix domain socket
    - nng+push+ipc://tmp/keypoints - NNG push over IPC
    - nng+push+tcp://host:port - NNG push over TCP

    Recognised query parameters:
    - masterFrameInterval: interval between master frames (300 when absent
      or not an integer)

    All query parameters are also exposed through ``parameters`` with
    lower-cased keys.
    """

    value: str
    protocol: TransportProtocol
    address: str
    master_frame_interval: int = 300
    parameters: Dict[str, str] = field(default_factory=dict)

    @classmethod
    def default(cls) -> KeyPointsConnectionString:
        """Return the built-in KeyPoints endpoint (NNG push over IPC)."""
        return cls.parse("nng+push+ipc://tmp/rocket-welder-keypoints?masterFrameInterval=300")

    @classmethod
    def from_environment(
        cls, variable_name: str = "KEYPOINTS_CONNECTION_STRING"
    ) -> KeyPointsConnectionString:
        """
        Build from ``variable_name`` in the environment.

        Uses :meth:`default` when the variable is unset or empty.

        Raises:
            ValueError: If the variable holds a string that cannot be parsed.
        """
        raw = os.environ.get(variable_name)
        if not raw:
            return cls.default()
        return cls.parse(raw)

    @classmethod
    def parse(cls, s: str) -> KeyPointsConnectionString:
        """
        Parse ``s`` or fail loudly.

        Raises:
            ValueError: If ``s`` is not a valid KeyPoints connection string.
        """
        parsed = cls.try_parse(s)
        if parsed is not None:
            return parsed
        raise ValueError(f"Invalid KeyPoints connection string: {s}")

    @classmethod
    def try_parse(cls, s: str) -> Optional[KeyPointsConnectionString]:
        """Parse ``s``, returning None instead of raising when it is invalid."""
        text = s.strip() if s else ""
        if not text:
            return None

        # Split off the query; parse_qs("") yields {} when there is no "?".
        endpoint, _, query = text.partition("?")
        parameters: Dict[str, str] = {
            key.lower(): (vals[0] if vals else "") for key, vals in parse_qs(query).items()
        }

        # Require a non-empty scheme followed by "://".
        schema_str, separator, path_part = endpoint.partition("://")
        if not separator or not schema_str:
            return None

        protocol = TransportProtocol.try_parse(schema_str)
        if protocol is None:
            return None

        if protocol.is_file or protocol.is_socket:
            # Filesystem-style addresses: prepend "/" when it is missing.
            address = path_part if path_part.startswith("/") else "/" + path_part
        elif protocol.is_nng:
            address = protocol.create_nng_address(path_part)
        else:
            return None

        # Best effort: a non-integer value keeps the 300 default.
        interval_text = parameters.get("masterframeinterval")
        master_frame_interval = 300
        if interval_text is not None:
            with contextlib.suppress(ValueError):
                master_frame_interval = int(interval_text)

        return cls(
            value=text,
            protocol=protocol,
            address=address,
            master_frame_interval=master_frame_interval,
            parameters=parameters,
        )

    def __str__(self) -> str:
        return self.value
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
@dataclass(frozen=True)
class SegmentationConnectionString:
    """
    Typed connection string describing where segmentation output is sent.

    Accepted schemes:
    - file:///path/to/file.bin - write to a file (address is an absolute path)
    - socket:///tmp/socket.sock - Unix domain socket
    - nng+push+ipc://tmp/segmentation - NNG push over IPC
    - nng+push+tcp://host:port - NNG push over TCP

    Any ``?key=value`` query is exposed through ``parameters`` with
    lower-cased keys.
    """

    value: str
    protocol: TransportProtocol
    address: str
    parameters: Dict[str, str] = field(default_factory=dict)

    @classmethod
    def default(cls) -> SegmentationConnectionString:
        """Return the built-in Segmentation endpoint (NNG push over IPC)."""
        return cls.parse("nng+push+ipc://tmp/rocket-welder-segmentation")

    @classmethod
    def from_environment(
        cls, variable_name: str = "SEGMENTATION_CONNECTION_STRING"
    ) -> SegmentationConnectionString:
        """
        Build from ``variable_name`` in the environment.

        Uses :meth:`default` when the variable is unset or empty.

        Raises:
            ValueError: If the variable holds a string that cannot be parsed.
        """
        raw = os.environ.get(variable_name)
        if not raw:
            return cls.default()
        return cls.parse(raw)

    @classmethod
    def parse(cls, s: str) -> SegmentationConnectionString:
        """
        Parse ``s`` or fail loudly.

        Raises:
            ValueError: If ``s`` is not a valid Segmentation connection string.
        """
        parsed = cls.try_parse(s)
        if parsed is not None:
            return parsed
        raise ValueError(f"Invalid Segmentation connection string: {s}")

    @classmethod
    def try_parse(cls, s: str) -> Optional[SegmentationConnectionString]:
        """Parse ``s``, returning None instead of raising when it is invalid."""
        text = s.strip() if s else ""
        if not text:
            return None

        # Split off the query; parse_qs("") yields {} when there is no "?".
        endpoint, _, query = text.partition("?")
        parameters: Dict[str, str] = {
            key.lower(): (vals[0] if vals else "") for key, vals in parse_qs(query).items()
        }

        # Require a non-empty scheme followed by "://".
        schema_str, separator, path_part = endpoint.partition("://")
        if not separator or not schema_str:
            return None

        protocol = TransportProtocol.try_parse(schema_str)
        if protocol is None:
            return None

        if protocol.is_file or protocol.is_socket:
            # Filesystem-style addresses: prepend "/" when it is missing.
            address = path_part if path_part.startswith("/") else "/" + path_part
        elif protocol.is_nng:
            address = protocol.create_nng_address(path_part)
        else:
            return None

        return cls(
            value=text,
            protocol=protocol,
            address=address,
            parameters=parameters,
        )

    def __str__(self) -> str:
        return self.value
|
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Data context types for per-frame keypoints and segmentation data.
|
|
3
|
+
|
|
4
|
+
Implements the Unit of Work pattern - contexts are created per-frame
|
|
5
|
+
and auto-commit when the processing delegate returns.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from abc import ABC, abstractmethod
|
|
11
|
+
from typing import TYPE_CHECKING, Sequence, Tuple, Union
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
import numpy.typing as npt
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from rocket_welder_sdk.keypoints_protocol import IKeyPointsWriter
|
|
18
|
+
from rocket_welder_sdk.segmentation_result import SegmentationResultWriter
|
|
19
|
+
|
|
20
|
+
from .schema import KeyPointDefinition, SegmentClass
|
|
21
|
+
|
|
22
|
+
# Type aliases
|
|
23
|
+
# (x, y) pixel coordinate pair. Used as the `position` argument of add_point()
# and as the element type of contour point sequences passed to the segmentation add().
Point = Tuple[int, int]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class IKeyPointsDataContext(ABC):
    """
    Per-frame unit of work for emitting keypoint detections.

    An instance is scoped to a single frame. Detections are recorded through
    :meth:`add` or :meth:`add_point`, and :meth:`commit` is invoked
    automatically once the processing delegate returns.
    """

    @property
    @abstractmethod
    def frame_id(self) -> int:
        """Identifier of the frame this context belongs to."""
        ...

    @abstractmethod
    def add(self, point: KeyPointDefinition, x: int, y: int, confidence: float) -> None:
        """
        Record one keypoint detection for the current frame.

        Args:
            point: Keypoint definition taken from the schema.
            x: Horizontal coordinate, in pixels.
            y: Vertical coordinate, in pixels.
            confidence: Detection confidence, expected in 0.0-1.0.
        """
        ...

    @abstractmethod
    def add_point(self, point: KeyPointDefinition, position: Point, confidence: float) -> None:
        """
        Record one keypoint detection whose location is given as a tuple.

        Args:
            point: Keypoint definition taken from the schema.
            position: ``(x, y)`` pixel coordinates.
            confidence: Detection confidence, expected in 0.0-1.0.
        """
        ...

    @abstractmethod
    def commit(self) -> None:
        """Finalize this frame's data (invoked automatically when the delegate returns)."""
        ...
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class ISegmentationDataContext(ABC):
    """
    Per-frame unit of work for emitting segmentation instances.

    An instance is scoped to a single frame. Instances are recorded through
    :meth:`add`, and :meth:`commit` is invoked automatically once the
    processing delegate returns.
    """

    @property
    @abstractmethod
    def frame_id(self) -> int:
        """Identifier of the frame this context belongs to."""
        ...

    @abstractmethod
    def add(
        self,
        segment_class: SegmentClass,
        instance_id: int,
        points: Union[Sequence[Point], npt.NDArray[np.int32]],
    ) -> None:
        """
        Record one segmented instance for the current frame.

        Args:
            segment_class: Segment class taken from the schema.
            instance_id: Distinguishes several instances of the same class (0-255).
            points: Contour points outlining the instance.
        """
        ...

    @abstractmethod
    def commit(self) -> None:
        """Finalize this frame's data (invoked automatically when the delegate returns)."""
        ...
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class KeyPointsDataContext(IKeyPointsDataContext):
    """Keypoints context that forwards every detection directly to an IKeyPointsWriter."""

    def __init__(
        self,
        frame_id: int,
        writer: IKeyPointsWriter,
    ) -> None:
        # The frame id is reported via `frame_id`; the writer receives every detection.
        self._frame_id, self._writer = frame_id, writer

    @property
    def frame_id(self) -> int:
        """Identifier of the frame this context was created for."""
        return self._frame_id

    def add(self, point: KeyPointDefinition, x: int, y: int, confidence: float) -> None:
        """Forward one (x, y, confidence) detection for ``point`` to the writer."""
        keypoint_id = point.id
        self._writer.append(keypoint_id, x, y, confidence)

    def add_point(self, point: KeyPointDefinition, position: Point, confidence: float) -> None:
        """Forward one detection, given as an (x, y) tuple, to the writer."""
        keypoint_id = point.id
        self._writer.append_point(keypoint_id, position, confidence)

    def commit(self) -> None:
        """Finish this frame by closing the underlying writer."""
        self._writer.close()
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class SegmentationDataContext(ISegmentationDataContext):
    """Segmentation context that forwards every instance to a SegmentationResultWriter."""

    def __init__(
        self,
        frame_id: int,
        writer: SegmentationResultWriter,
    ) -> None:
        # The frame id is reported via `frame_id`; the writer receives every instance.
        self._frame_id, self._writer = frame_id, writer

    @property
    def frame_id(self) -> int:
        """Identifier of the frame this context was created for."""
        return self._frame_id

    def add(
        self,
        segment_class: SegmentClass,
        instance_id: int,
        points: Union[Sequence[Point], npt.NDArray[np.int32]],
    ) -> None:
        """
        Forward one segmented instance to the writer.

        Raises:
            ValueError: If ``instance_id`` is outside 0-255.
        """
        if instance_id < 0 or instance_id > 255:
            raise ValueError(f"instance_id must be 0-255, got {instance_id}")

        # ndarray input is passed through as-is (no copy, no dtype cast).
        # Any other sequence is materialized as an int32 array.
        contour = points if isinstance(points, np.ndarray) else np.array(points, dtype=np.int32)
        self._writer.append(segment_class.class_id, instance_id, contour)

    def commit(self) -> None:
        """Finish this frame by closing the underlying writer."""
        self._writer.close()
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Factory for creating IFrameSink instances from parsed protocol and address.
|
|
3
|
+
|
|
4
|
+
Does NOT parse URLs - use SegmentationConnectionString or KeyPointsConnectionString for parsing.
|
|
5
|
+
|
|
6
|
+
This mirrors the C# FrameSinkFactory class for API consistency.
|
|
7
|
+
|
|
8
|
+
Usage:
|
|
9
|
+
from rocket_welder_sdk.high_level import FrameSinkFactory, SegmentationConnectionString
|
|
10
|
+
|
|
11
|
+
cs = SegmentationConnectionString.parse("socket:///tmp/seg.sock")
|
|
12
|
+
sink = FrameSinkFactory.create(cs.protocol, cs.address)
|
|
13
|
+
|
|
14
|
+
# For null sink (no output configured):
|
|
15
|
+
sink = FrameSinkFactory.create_null()
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
import logging
|
|
21
|
+
from typing import TYPE_CHECKING, Optional
|
|
22
|
+
|
|
23
|
+
from .transport_protocol import TransportProtocol
|
|
24
|
+
|
|
25
|
+
if TYPE_CHECKING:
|
|
26
|
+
from rocket_welder_sdk.transport.frame_sink import IFrameSink
|
|
27
|
+
|
|
28
|
+
# Module-level logger. FrameSinkFactory.create falls back to it when no logger_instance is passed.
logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class FrameSinkFactory:
    """
    Factory for creating IFrameSink instances from parsed protocol and address.

    Does NOT parse URLs - use SegmentationConnectionString or KeyPointsConnectionString for parsing.

    Mirrors C# RocketWelder.SDK.Transport.FrameSinkFactory.
    """

    @staticmethod
    def create(
        protocol: Optional[TransportProtocol],
        address: str,
        *,
        logger_instance: Optional[logging.Logger] = None,
    ) -> IFrameSink:
        """
        Create a frame sink from parsed protocol and address.

        Returns NullFrameSink if protocol is None (no URL specified).

        Args:
            protocol: The transport protocol (from ConnectionString.protocol), or None
            address: The address (file path, socket path, or NNG address)
            logger_instance: Optional logger for diagnostics. The module logger is
                used when omitted.

        Returns:
            An IFrameSink connected to the specified address, or NullFrameSink if protocol is None

        Raises:
            TypeError: If protocol is neither None nor a TransportProtocol
            ValueError: If protocol is not supported for sinks
            OSError: If the output file (file protocol) cannot be opened for writing

        Example:
            cs = SegmentationConnectionString.parse("socket:///tmp/seg.sock")
            sink = FrameSinkFactory.create(cs.protocol, cs.address)
        """
        # Imported lazily, as in the original module layout.
        from rocket_welder_sdk.transport import NngFrameSink, NullFrameSink
        from rocket_welder_sdk.transport.stream_transport import StreamFrameSink
        from rocket_welder_sdk.transport.unix_socket_transport import UnixSocketFrameSink

        log = logger_instance or logger

        # Handle None protocol - return null sink
        if protocol is None:
            log.debug("No protocol specified, using NullFrameSink")
            return NullFrameSink.instance()

        if not isinstance(protocol, TransportProtocol):
            raise TypeError(f"Expected TransportProtocol, got {type(protocol).__name__}")

        if protocol.is_file:
            log.info("Creating file frame sink at: %s", address)
            # Ownership of the handle passes to the sink on success.
            file_handle = open(address, "wb")  # noqa: SIM115
            try:
                return StreamFrameSink(file_handle)
            except BaseException:
                # Do not leak the open file if the sink could not be constructed.
                file_handle.close()
                raise

        if protocol.is_socket:
            log.info("Creating Unix socket frame sink at: %s", address)
            return UnixSocketFrameSink.connect(address)

        if protocol.is_nng:
            log.info("Creating NNG frame sink (%s) at: %s", protocol.schema, address)

            if protocol.is_pub:
                return NngFrameSink.create_publisher(address)
            if protocol.is_push:
                return NngFrameSink.create_pusher(address)

            raise ValueError(
                f"NNG protocol '{protocol.schema}' is not supported for sinks "
                "(only pub and push are supported)"
            )

        raise ValueError(f"Transport protocol '{protocol.schema}' is not supported for frame sinks")

    @staticmethod
    def create_null() -> IFrameSink:
        """
        Create a null frame sink that discards all data.

        Use when no output URL is configured.
        """
        from rocket_welder_sdk.transport import NullFrameSink

        return NullFrameSink.instance()
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
# Re-export for convenience
|
|
118
|
+
# Names exported by `from ... import *` for this module.
__all__ = ["FrameSinkFactory"]
|