rocket-welder-sdk 1.1.36.dev14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40) hide show
  1. rocket_welder_sdk/__init__.py +95 -0
  2. rocket_welder_sdk/bytes_size.py +234 -0
  3. rocket_welder_sdk/connection_string.py +291 -0
  4. rocket_welder_sdk/controllers.py +831 -0
  5. rocket_welder_sdk/external_controls/__init__.py +30 -0
  6. rocket_welder_sdk/external_controls/contracts.py +100 -0
  7. rocket_welder_sdk/external_controls/contracts_old.py +105 -0
  8. rocket_welder_sdk/frame_metadata.py +138 -0
  9. rocket_welder_sdk/gst_metadata.py +411 -0
  10. rocket_welder_sdk/high_level/__init__.py +54 -0
  11. rocket_welder_sdk/high_level/client.py +235 -0
  12. rocket_welder_sdk/high_level/connection_strings.py +331 -0
  13. rocket_welder_sdk/high_level/data_context.py +169 -0
  14. rocket_welder_sdk/high_level/frame_sink_factory.py +118 -0
  15. rocket_welder_sdk/high_level/schema.py +195 -0
  16. rocket_welder_sdk/high_level/transport_protocol.py +238 -0
  17. rocket_welder_sdk/keypoints_protocol.py +642 -0
  18. rocket_welder_sdk/opencv_controller.py +278 -0
  19. rocket_welder_sdk/periodic_timer.py +303 -0
  20. rocket_welder_sdk/py.typed +2 -0
  21. rocket_welder_sdk/rocket_welder_client.py +497 -0
  22. rocket_welder_sdk/segmentation_result.py +420 -0
  23. rocket_welder_sdk/session_id.py +238 -0
  24. rocket_welder_sdk/transport/__init__.py +31 -0
  25. rocket_welder_sdk/transport/frame_sink.py +122 -0
  26. rocket_welder_sdk/transport/frame_source.py +74 -0
  27. rocket_welder_sdk/transport/nng_transport.py +197 -0
  28. rocket_welder_sdk/transport/stream_transport.py +193 -0
  29. rocket_welder_sdk/transport/tcp_transport.py +154 -0
  30. rocket_welder_sdk/transport/unix_socket_transport.py +339 -0
  31. rocket_welder_sdk/ui/__init__.py +48 -0
  32. rocket_welder_sdk/ui/controls.py +362 -0
  33. rocket_welder_sdk/ui/icons.py +21628 -0
  34. rocket_welder_sdk/ui/ui_events_projection.py +226 -0
  35. rocket_welder_sdk/ui/ui_service.py +358 -0
  36. rocket_welder_sdk/ui/value_types.py +72 -0
  37. rocket_welder_sdk-1.1.36.dev14.dist-info/METADATA +845 -0
  38. rocket_welder_sdk-1.1.36.dev14.dist-info/RECORD +40 -0
  39. rocket_welder_sdk-1.1.36.dev14.dist-info/WHEEL +5 -0
  40. rocket_welder_sdk-1.1.36.dev14.dist-info/top_level.txt +1 -0
@@ -0,0 +1,411 @@
1
+ """
2
+ GStreamer metadata structures for RocketWelder SDK.
3
+ Matches C# GstCaps and GstMetadata functionality.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import json
9
+ import re
10
+ from dataclasses import dataclass
11
+ from typing import Any
12
+
13
+ import numpy as np
14
+ import numpy.typing as npt
15
+
16
+
17
@dataclass
class GstCaps:
    """
    GStreamer capabilities representation.

    Represents video format capabilities including format, dimensions, framerate, etc.
    Matches the C# GstCaps implementation with proper parsing and numpy integration.
    """

    width: int  # frame width in pixels
    height: int  # frame height in pixels
    format: str  # GStreamer pixel format name, e.g. "RGB"
    depth_type: type[np.uint8] | type[np.uint16]  # numpy scalar type of one sample
    channels: int  # samples per pixel
    bytes_per_pixel: int  # used by frame_size
    framerate_num: int | None = None
    framerate_den: int | None = None
    interlace_mode: str | None = None
    colorimetry: str | None = None
    caps_string: str | None = None  # original caps string when created by parse()

    # The names below are deliberately NOT annotated: the dataclass machinery
    # only turns annotated names into fields, so these remain plain class-level
    # constants that are built once instead of on every call.

    # The "(type)" tag is optional so that both the serialized form
    # ("width=(int)640") and the hand-written form ("width=640") are accepted.
    _WIDTH_RE = re.compile(r"width=(?:\(int\))?(\d+)")
    _HEIGHT_RE = re.compile(r"height=(?:\(int\))?(\d+)")
    _FORMAT_RE = re.compile(r"format=(?:\(string\))?(\w+)")
    _FRAMERATE_RE = re.compile(r"framerate=(?:\(fraction\))?(\d+)/(\d+)")
    _INTERLACE_RE = re.compile(r"interlace-mode=(?:\(string\))?(\w+)")
    _COLORIMETRY_RE = re.compile(r"colorimetry=(?:\(string\))?([\w:]+)")

    # format name -> (numpy scalar type, channels, bytes_per_pixel)
    _FORMAT_MAP = {
        # RGB formats
        "RGB": (np.uint8, 3, 3),
        "BGR": (np.uint8, 3, 3),
        "RGBA": (np.uint8, 4, 4),
        "BGRA": (np.uint8, 4, 4),
        "ARGB": (np.uint8, 4, 4),
        "ABGR": (np.uint8, 4, 4),
        "RGBX": (np.uint8, 4, 4),  # RGB with padding
        "BGRX": (np.uint8, 4, 4),  # BGR with padding
        "XRGB": (np.uint8, 4, 4),  # RGB with padding
        "XBGR": (np.uint8, 4, 4),  # BGR with padding
        # 16-bit RGB formats
        "RGB16": (np.uint16, 3, 6),
        "BGR16": (np.uint16, 3, 6),
        # Grayscale formats
        "GRAY8": (np.uint8, 1, 1),
        "GRAY16_LE": (np.uint16, 1, 2),
        "GRAY16_BE": (np.uint16, 1, 2),
        # YUV planar formats (Y plane only for simplicity)
        "I420": (np.uint8, 1, 1),
        "YV12": (np.uint8, 1, 1),
        "NV12": (np.uint8, 1, 1),
        "NV21": (np.uint8, 1, 1),
        # YUV packed formats
        "YUY2": (np.uint8, 2, 2),
        "UYVY": (np.uint8, 2, 2),
        "YVYU": (np.uint8, 2, 2),
        # Bayer formats (raw sensor data)
        "BGGR": (np.uint8, 1, 1),
        "RGGB": (np.uint8, 1, 1),
        "GRBG": (np.uint8, 1, 1),
        "GBRG": (np.uint8, 1, 1),
    }

    # Unknown formats are treated as 8-bit, 3-channel RGB.
    _DEFAULT_FORMAT_INFO = (np.uint8, 3, 3)

    @property
    def frame_size(self) -> int:
        """Calculate the expected frame size in bytes."""
        return self.width * self.height * self.bytes_per_pixel

    @property
    def framerate(self) -> float | None:
        """Get framerate as double (FPS), or None when unknown or the denominator is not positive."""
        if (
            self.framerate_num is not None
            and self.framerate_den is not None
            and self.framerate_den > 0
        ):
            return self.framerate_num / self.framerate_den
        return None

    @classmethod
    def parse(cls, caps_string: str) -> GstCaps:
        """
        Parse GStreamer caps string.

        Both typed and untyped field values are accepted, e.g.
        "video/x-raw, format=(string)RGB, width=(int)640, height=(int)480, framerate=(fraction)30/1"
        and "video/x-raw,format=RGB,width=640,height=480,framerate=30/1".

        Args:
            caps_string: GStreamer caps string

        Returns:
            GstCaps instance (format defaults to "RGB" when the string has no format field)

        Raises:
            ValueError: If caps string is empty, is not video/x-raw, or lacks width/height
        """
        if not caps_string or not caps_string.strip():
            raise ValueError("Empty caps string")

        caps_string = caps_string.strip()

        # Only raw video caps are supported
        if not caps_string.startswith("video/x-raw"):
            raise ValueError(f"Not a video/x-raw caps string: {caps_string}")

        try:
            width_match = cls._WIDTH_RE.search(caps_string)
            if not width_match:
                raise ValueError("Missing width in caps string")
            width = int(width_match.group(1))

            height_match = cls._HEIGHT_RE.search(caps_string)
            if not height_match:
                raise ValueError("Missing height in caps string")
            height = int(height_match.group(1))

            format_match = cls._FORMAT_RE.search(caps_string)
            format_str = format_match.group(1) if format_match else "RGB"

            # Optional fields stay None when absent
            framerate_num = None
            framerate_den = None
            framerate_match = cls._FRAMERATE_RE.search(caps_string)
            if framerate_match:
                framerate_num = int(framerate_match.group(1))
                framerate_den = int(framerate_match.group(2))

            interlace_match = cls._INTERLACE_RE.search(caps_string)
            interlace_mode = interlace_match.group(1) if interlace_match else None

            colorimetry_match = cls._COLORIMETRY_RE.search(caps_string)
            colorimetry = colorimetry_match.group(1) if colorimetry_match else None

            # Map format to numpy dtype and get channel info
            depth_type, channels, bytes_per_pixel = cls._map_gstreamer_format_to_numpy(format_str)

            return cls(
                width=width,
                height=height,
                format=format_str,
                depth_type=depth_type,
                channels=channels,
                bytes_per_pixel=bytes_per_pixel,
                framerate_num=framerate_num,
                framerate_den=framerate_den,
                interlace_mode=interlace_mode,
                colorimetry=colorimetry,
                caps_string=caps_string,
            )
        except Exception as e:
            # Every failure surfaces as one ValueError; the specific cause is chained.
            raise ValueError(f"Failed to parse caps string: {caps_string}") from e

    @classmethod
    def from_simple(cls, width: int, height: int, format: str = "RGB") -> GstCaps:
        """
        Create GstCaps from simple parameters.

        Args:
            width: Frame width
            height: Frame height
            format: Pixel format (default: "RGB")

        Returns:
            GstCaps instance (without framerate, interlace mode, colorimetry or caps string)
        """
        depth_type, channels, bytes_per_pixel = cls._map_gstreamer_format_to_numpy(format)
        return cls(
            width=width,
            height=height,
            format=format,
            depth_type=depth_type,
            channels=channels,
            bytes_per_pixel=bytes_per_pixel,
        )

    @staticmethod
    def _map_gstreamer_format_to_numpy(
        format: str,
    ) -> tuple[type[np.uint8] | type[np.uint16], int, int]:
        """
        Map GStreamer format strings to numpy dtype.
        Reference: https://gstreamer.freedesktop.org/documentation/video/video-format.html

        Args:
            format: GStreamer format string (case-insensitive; empty means "RGB")

        Returns:
            Tuple of (numpy dtype, channels, bytes_per_pixel); unknown formats map to RGB
        """
        format_upper = format.upper() if format else "RGB"
        return GstCaps._FORMAT_MAP.get(format_upper, GstCaps._DEFAULT_FORMAT_INFO)

    def _array_shape(self) -> tuple[int, ...]:
        """Return (height, width) for single-channel caps, else (height, width, channels)."""
        if self.channels == 1:
            return (self.height, self.width)
        return (self.height, self.width, self.channels)

    def create_array(
        self, data: bytes | memoryview | npt.NDArray[np.uint8] | npt.NDArray[np.uint16]
    ) -> npt.NDArray[np.uint8] | npt.NDArray[np.uint16]:
        """
        Create numpy array with proper format from data.

        Args:
            data: Frame data as bytes, memoryview, or existing numpy array

        Returns:
            Numpy array with proper shape and dtype

        Raises:
            ValueError: If data size doesn't match expected frame size
        """
        # A memoryview is copied into bytes so the result never aliases the caller's buffer
        if isinstance(data, memoryview):
            data = bytes(data)

        if isinstance(data, np.ndarray):
            actual_size = data.size * data.itemsize
            if actual_size != self.frame_size:
                raise ValueError(
                    f"Data size mismatch. Expected {self.frame_size} bytes for "
                    f"{self.width}x{self.height} {self.format}, got {actual_size}"
                )
            # The size check above is in bytes, so an array whose element width
            # differs from depth_type (e.g. a raw uint8 buffer holding a 16-bit
            # frame) has the wrong element count for reshape. Reinterpret the
            # same bytes as depth_type first.
            if data.itemsize != np.dtype(self.depth_type).itemsize:
                data = np.ascontiguousarray(data).reshape(-1).view(self.depth_type)
            return data.reshape(self._array_shape())

        if len(data) != self.frame_size:
            raise ValueError(
                f"Data size mismatch. Expected {self.frame_size} bytes for "
                f"{self.width}x{self.height} {self.format}, got {len(data)}"
            )

        return np.frombuffer(data, dtype=self.depth_type).reshape(self._array_shape())

    def create_array_from_pointer(
        self, ptr: int, copy: bool = False
    ) -> npt.NDArray[np.uint8] | npt.NDArray[np.uint16]:
        """
        Create numpy array from memory pointer (zero-copy by default).

        The caller must keep the memory behind ``ptr`` valid for as long as a
        non-copied result is in use.

        Args:
            ptr: Memory pointer as integer
            copy: If True, make a copy of the data; if False, create a view

        Returns:
            Numpy array with proper shape and dtype

        Raises:
            ValueError: If ptr is null (0)
        """
        import ctypes

        # Dereferencing address 0 would crash the interpreter instead of raising
        if not ptr:
            raise ValueError("Null pointer")

        # Calculate total elements based on depth type
        if self.depth_type == np.uint16:
            total_elements = self.width * self.height * self.channels
        else:
            total_elements = self.frame_size

        # itemsize must come from a dtype instance: on the scalar *class*
        # (np.uint8 / np.uint16) the attribute is a descriptor, not an int.
        buffer_size = total_elements * np.dtype(self.depth_type).itemsize
        c_buffer = (ctypes.c_byte * buffer_size).from_address(ptr)
        arr = np.frombuffer(c_buffer, dtype=self.depth_type)

        shaped = arr.reshape(self._array_shape())
        return shaped.copy() if copy else shaped

    def __str__(self) -> str:
        """String representation."""
        # If we have the original caps string, return it for perfect round-tripping
        if self.caps_string:
            return self.caps_string

        # Otherwise build a simple display string
        fps = f" @ {self.framerate:.2f}fps" if self.framerate else ""
        return f"{self.width}x{self.height} {self.format}{fps}"

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary representation; unset optional fields are omitted."""
        result: dict[str, Any] = {
            "width": self.width,
            "height": self.height,
            "format": self.format,
            "channels": self.channels,
            "bytes_per_pixel": self.bytes_per_pixel,
        }

        if self.framerate_num is not None:
            result["framerate_num"] = self.framerate_num
        if self.framerate_den is not None:
            result["framerate_den"] = self.framerate_den
        if self.interlace_mode:
            result["interlace_mode"] = self.interlace_mode
        if self.colorimetry:
            result["colorimetry"] = self.colorimetry

        return result
329
+
330
+
331
@dataclass
class GstMetadata:
    """
    GStreamer metadata structure.

    Matches the JSON structure written by GStreamer plugins.
    Compatible with C# GstMetadata record.
    """

    type: str
    version: str
    caps: GstCaps
    element_name: str

    @classmethod
    def from_json(cls, json_data: str | bytes | dict[str, Any]) -> GstMetadata:
        """
        Create GstMetadata from JSON data.

        Args:
            json_data: JSON string, bytes (UTF-8), or an already-decoded dictionary

        Returns:
            GstMetadata instance; missing "type", "version" and "element_name"
            default to empty strings

        Raises:
            ValueError: If JSON is invalid, is not an object, or "caps" is
                neither a string nor a dictionary
        """
        # Parse JSON if needed
        if isinstance(json_data, (str, bytes)):
            if isinstance(json_data, bytes):
                json_data = json_data.decode("utf-8")
            try:
                data = json.loads(json_data)
            except json.JSONDecodeError as e:
                raise ValueError(f"Invalid JSON: {e}") from e
        else:
            data = json_data

        if not isinstance(data, dict):
            raise ValueError("JSON must be an object/dictionary")

        type_str = data.get("type", "")
        version = data.get("version", "")
        element_name = data.get("element_name", "")

        # Parse caps - it's a STRING in the JSON!
        caps_data = data.get("caps")
        if isinstance(caps_data, str):
            # This is the normal case - caps is a string that needs parsing
            caps = GstCaps.parse(caps_data)
        elif isinstance(caps_data, dict):
            # Fallback for dict format (shouldn't happen with real GStreamer)
            width = caps_data.get("width", 640)
            height = caps_data.get("height", 480)
            format_str = caps_data.get("format", "RGB")
            caps = GstCaps.from_simple(width, height, format_str)
        else:
            raise ValueError(f"Invalid caps data type: {type(caps_data)}")

        return cls(type=type_str, version=version, caps=caps, element_name=element_name)

    def to_json(self) -> str:
        """Convert to JSON string."""
        return json.dumps(self.to_dict())

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary representation; "caps" is a parseable caps string."""
        return {
            "type": self.type,
            "version": self.version,
            "caps": self._caps_to_string(),  # Caps as string for C# compatibility
            "element_name": self.element_name,
        }

    def _caps_to_string(self) -> str:
        """
        Return a GStreamer caps string for ``self.caps``.

        The original caps string is reused when available. Otherwise one is
        built from the individual fields, because the plain display form of
        caps created without a caps string (e.g. "640x480 RGB") is not a
        video/x-raw caps string and could not be read back by from_json().
        """
        caps = self.caps
        if caps.caps_string:
            return caps.caps_string

        parts = [
            "video/x-raw",
            f"format=(string){caps.format}",
            f"width=(int){caps.width}",
            f"height=(int){caps.height}",
        ]
        if caps.framerate_num is not None and caps.framerate_den is not None:
            parts.append(f"framerate=(fraction){caps.framerate_num}/{caps.framerate_den}")
        if caps.interlace_mode:
            parts.append(f"interlace-mode=(string){caps.interlace_mode}")
        if caps.colorimetry:
            parts.append(f"colorimetry=(string){caps.colorimetry}")
        return ", ".join(parts)

    def __str__(self) -> str:
        """String representation."""
        return f"GstMetadata(type={self.type}, element={self.element_name}, caps={self.caps})"
@@ -0,0 +1,54 @@
1
+ """
2
+ High-level API for RocketWelder SDK.
3
+
4
+ Mirrors C# RocketWelder.SDK API for consistent developer experience.
5
+
6
+ Example:
7
+ from rocket_welder_sdk.high_level import RocketWelderClient
8
+
9
+ with RocketWelderClient.from_environment() as client:
10
+ nose = client.keypoints.define_point("nose")
11
+ person = client.segmentation.define_class(1, "person")
12
+ client.start(process_frame)
13
+ """
14
+
15
+ from .client import RocketWelderClient, RocketWelderClientOptions
16
+ from .connection_strings import (
17
+ KeyPointsConnectionString,
18
+ SegmentationConnectionString,
19
+ VideoSourceConnectionString,
20
+ VideoSourceType,
21
+ )
22
+ from .data_context import (
23
+ IKeyPointsDataContext,
24
+ ISegmentationDataContext,
25
+ )
26
+ from .frame_sink_factory import FrameSinkFactory
27
+ from .schema import (
28
+ IKeyPointsSchema,
29
+ ISegmentationSchema,
30
+ KeyPointDefinition,
31
+ SegmentClass,
32
+ )
33
+ from .transport_protocol import (
34
+ TransportKind,
35
+ TransportProtocol,
36
+ )
37
+
38
+ __all__ = [
39
+ "FrameSinkFactory",
40
+ "IKeyPointsDataContext",
41
+ "IKeyPointsSchema",
42
+ "ISegmentationDataContext",
43
+ "ISegmentationSchema",
44
+ "KeyPointDefinition",
45
+ "KeyPointsConnectionString",
46
+ "RocketWelderClient",
47
+ "RocketWelderClientOptions",
48
+ "SegmentClass",
49
+ "SegmentationConnectionString",
50
+ "TransportKind",
51
+ "TransportProtocol",
52
+ "VideoSourceConnectionString",
53
+ "VideoSourceType",
54
+ ]
@@ -0,0 +1,235 @@
1
+ """
2
+ RocketWelderClient - High-level API matching C# RocketWelder.SDK.
3
+
4
+ Usage:
5
+ with RocketWelderClient.from_environment() as client:
6
+ # Define schema
7
+ nose = client.keypoints.define_point("nose")
8
+ person = client.segmentation.define_class(1, "person")
9
+
10
+ # Start processing
11
+ client.start(process_frame)
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import logging
17
+ from dataclasses import dataclass, field
18
+ from typing import TYPE_CHECKING, Any, Callable, Optional
19
+
20
+ import numpy as np
21
+ import numpy.typing as npt
22
+ from typing_extensions import TypeAlias
23
+
24
+ from .connection_strings import (
25
+ KeyPointsConnectionString,
26
+ SegmentationConnectionString,
27
+ VideoSourceConnectionString,
28
+ )
29
+ from .data_context import (
30
+ IKeyPointsDataContext,
31
+ ISegmentationDataContext,
32
+ KeyPointsDataContext,
33
+ SegmentationDataContext,
34
+ )
35
+ from .frame_sink_factory import FrameSinkFactory
36
+ from .schema import (
37
+ IKeyPointsSchema,
38
+ ISegmentationSchema,
39
+ KeyPointsSchema,
40
+ SegmentationSchema,
41
+ )
42
+
43
+ if TYPE_CHECKING:
44
+ from rocket_welder_sdk.keypoints_protocol import KeyPointsSink
45
+ from rocket_welder_sdk.transport.frame_sink import IFrameSink
46
+
47
+ # Type alias for OpenCV Mat (numpy array)
48
+ Mat: TypeAlias = npt.NDArray[np.uint8]
49
+
50
+ logger = logging.getLogger(__name__)
51
+
52
+
53
@dataclass
class RocketWelderClientOptions:
    """
    Connection settings consumed by RocketWelderClient.

    Each field falls back to the ``default()`` factory of its connection
    string type when it is not supplied explicitly.
    """

    # Where video frames are read from.
    video_source: VideoSourceConnectionString = field(
        default_factory=VideoSourceConnectionString.default
    )
    # Where keypoints output is sent.
    keypoints: KeyPointsConnectionString = field(
        default_factory=KeyPointsConnectionString.default
    )
    # Where segmentation output is sent.
    segmentation: SegmentationConnectionString = field(
        default_factory=SegmentationConnectionString.default
    )

    @classmethod
    def from_environment(cls) -> RocketWelderClientOptions:
        """Build options by asking each connection string type to read the environment."""
        video_source = VideoSourceConnectionString.from_environment()
        keypoints = KeyPointsConnectionString.from_environment()
        segmentation = SegmentationConnectionString.from_environment()
        return cls(
            video_source=video_source,
            keypoints=keypoints,
            segmentation=segmentation,
        )
73
+
74
+
75
class RocketWelderClient:
    """
    High-level client for RocketWelder SDK.

    Mirrors C# RocketWelder.SDK.IRocketWelderClient interface.

    The client owns the frame sinks it creates and releases them in close();
    it can be used as a context manager to guarantee that.
    """

    def __init__(self, options: RocketWelderClientOptions) -> None:
        """Store the options and create empty schemas; no sinks are opened yet."""
        self._options = options
        self._keypoints_schema = KeyPointsSchema()
        self._segmentation_schema = SegmentationSchema()
        self._keypoints_sink: Optional[KeyPointsSink] = None
        self._keypoints_frame_sink: Optional[IFrameSink] = None
        self._segmentation_frame_sink: Optional[IFrameSink] = None
        self._closed = False
        logger.debug("RocketWelderClient created with options: %s", options)

    @classmethod
    def from_environment(cls) -> RocketWelderClient:
        """Create client from environment variables."""
        logger.info("Creating RocketWelderClient from environment variables")
        return cls(RocketWelderClientOptions.from_environment())

    @classmethod
    def create(cls, options: Optional[RocketWelderClientOptions] = None) -> RocketWelderClient:
        """Create client with explicit options (default options when omitted)."""
        return cls(options or RocketWelderClientOptions())

    @property
    def keypoints(self) -> IKeyPointsSchema:
        """Schema for defining keypoints."""
        return self._keypoints_schema

    @property
    def segmentation(self) -> ISegmentationSchema:
        """Schema for defining segmentation classes."""
        return self._segmentation_schema

    def start(
        self,
        process_frame: Callable[[Mat, ISegmentationDataContext, IKeyPointsDataContext, Mat], None],
    ) -> None:
        """Start with both keypoints and segmentation."""
        self._run_loop(process_frame, use_keypoints=True, use_segmentation=True)

    def start_keypoints(
        self,
        process_frame: Callable[[Mat, IKeyPointsDataContext, Mat], None],
    ) -> None:
        """Start with keypoints only."""
        self._run_loop(process_frame, use_keypoints=True, use_segmentation=False)

    def start_segmentation(
        self,
        process_frame: Callable[[Mat, ISegmentationDataContext, Mat], None],
    ) -> None:
        """Start with segmentation only."""
        self._run_loop(process_frame, use_keypoints=False, use_segmentation=True)

    def _run_loop(
        self,
        process_frame: Callable[..., None],
        use_keypoints: bool,
        use_segmentation: bool,
    ) -> None:
        """
        Initialize the requested sinks, then run the processing loop.

        The capture loop itself is not implemented yet, so this always ends by
        raising NotImplementedError. The sinks created before that point stay
        attached to the client and are released by close().

        Raises:
            NotImplementedError: Always, after the sinks have been initialized
        """
        from rocket_welder_sdk.keypoints_protocol import KeyPointsSink

        logger.info(
            "Starting processing loop (keypoints=%s, segmentation=%s)",
            use_keypoints,
            use_segmentation,
        )

        # Initialize sinks
        if use_keypoints:
            cs = self._options.keypoints
            logger.info("Initializing keypoints sink: %s -> %s", cs.protocol, cs.address)

            # A repeated start must not leak the previously opened frame sink
            stale_keypoints = self._keypoints_frame_sink
            if stale_keypoints is not None:
                self._keypoints_sink = None
                self._keypoints_frame_sink = None
                stale_keypoints.close()

            self._keypoints_frame_sink = self._create_frame_sink(cs.protocol, cs.address)
            # The client now holds a live resource again, so close() must not
            # short-circuit even if it was called earlier.
            self._closed = False
            self._keypoints_sink = KeyPointsSink(
                frame_sink=self._keypoints_frame_sink,
                master_frame_interval=cs.master_frame_interval,
                owns_sink=False,  # We manage frame sink lifecycle in close()
            )
            logger.debug(
                "KeyPointsSink created with master_frame_interval=%d", cs.master_frame_interval
            )

        if use_segmentation:
            seg_cs = self._options.segmentation
            logger.info("Initializing segmentation sink: %s -> %s", seg_cs.protocol, seg_cs.address)

            # Same as above: release the sink from an earlier start first
            stale_segmentation = self._segmentation_frame_sink
            if stale_segmentation is not None:
                self._segmentation_frame_sink = None
                stale_segmentation.close()

            self._segmentation_frame_sink = self._create_frame_sink(seg_cs.protocol, seg_cs.address)
            self._closed = False
            logger.debug("Segmentation frame sink created")

        # TODO: Video capture loop - for now raise NotImplementedError
        raise NotImplementedError(
            "Video capture not implemented. Use process_frame_sync() or low-level API."
        )

    def process_frame_sync(
        self,
        frame_id: int,
        input_frame: Mat,
        output_frame: Mat,
        width: int,
        height: int,
    ) -> tuple[Optional[IKeyPointsDataContext], Optional[ISegmentationDataContext]]:
        """
        Process a single frame synchronously.

        Returns (keypoints_context, segmentation_context) for the caller to use.
        Caller must call commit() on contexts when done.

        A context is None when the corresponding sink has not been initialized.
        ``input_frame`` and ``output_frame`` are accepted but not read here.
        """
        from rocket_welder_sdk.segmentation_result import SegmentationResultWriter

        kp_ctx: Optional[IKeyPointsDataContext] = None
        seg_ctx: Optional[ISegmentationDataContext] = None

        if self._keypoints_sink is not None:
            kp_writer = self._keypoints_sink.create_writer(frame_id)
            kp_ctx = KeyPointsDataContext(frame_id, kp_writer)

        if self._segmentation_frame_sink is not None:
            seg_writer = SegmentationResultWriter(
                frame_id, width, height, frame_sink=self._segmentation_frame_sink
            )
            seg_ctx = SegmentationDataContext(frame_id, seg_writer)

        return kp_ctx, seg_ctx

    def _create_frame_sink(self, protocol: Any, address: str) -> IFrameSink:
        """Create frame sink from protocol using FrameSinkFactory."""
        return FrameSinkFactory.create(protocol, address, logger_instance=logger)

    def close(self) -> None:
        """
        Release resources.

        Safe to call more than once. Both frame sinks are always given the
        chance to close: if closing the keypoints sink raises, the
        segmentation sink is still closed before the error propagates.
        """
        if self._closed:
            return

        logger.info("Closing RocketWelderClient")

        # Detach everything and mark the client closed up front, so a failing
        # sink can neither leave stale references nor block later close() calls.
        # (KeyPointsSink has owns_sink=False, so we manage lifecycle)
        self._keypoints_sink = None
        keypoints_frame_sink = self._keypoints_frame_sink
        segmentation_frame_sink = self._segmentation_frame_sink
        self._keypoints_frame_sink = None
        self._segmentation_frame_sink = None
        self._closed = True

        try:
            if keypoints_frame_sink is not None:
                logger.debug("Closing keypoints frame sink")
                keypoints_frame_sink.close()
        finally:
            if segmentation_frame_sink is not None:
                logger.debug("Closing segmentation frame sink")
                segmentation_frame_sink.close()

        logger.info("RocketWelderClient closed")

    def __enter__(self) -> RocketWelderClient:
        """Return the client itself; nothing is opened on entry."""
        return self

    def __exit__(self, *args: object) -> None:
        """Close the client when leaving the ``with`` block."""
        self.close()