rocket-welder-sdk 1.1.32__py3-none-any.whl → 1.1.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,262 @@
1
+ """
2
+ RocketWelderClient - High-level API matching C# RocketWelder.SDK.
3
+
4
+ Usage:
5
+ with RocketWelderClient.from_environment() as client:
6
+ # Define schema
7
+ nose = client.keypoints.define_point("nose")
8
+ person = client.segmentation.define_class(1, "person")
9
+
10
+ # Start processing
11
+ client.start(process_frame)
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import logging
17
+ from dataclasses import dataclass, field
18
+ from typing import TYPE_CHECKING, Any, Callable, Optional
19
+
20
+ import numpy as np
21
+ import numpy.typing as npt
22
+ from typing_extensions import TypeAlias
23
+
24
+ from .connection_strings import (
25
+ KeyPointsConnectionString,
26
+ SegmentationConnectionString,
27
+ VideoSourceConnectionString,
28
+ )
29
+ from .data_context import (
30
+ IKeyPointsDataContext,
31
+ ISegmentationDataContext,
32
+ KeyPointsDataContext,
33
+ SegmentationDataContext,
34
+ )
35
+ from .schema import (
36
+ IKeyPointsSchema,
37
+ ISegmentationSchema,
38
+ KeyPointsSchema,
39
+ SegmentationSchema,
40
+ )
41
+ from .transport_protocol import TransportKind
42
+
43
+ if TYPE_CHECKING:
44
+ from rocket_welder_sdk.keypoints_protocol import KeyPointsSink
45
+ from rocket_welder_sdk.transport.frame_sink import IFrameSink
46
+
47
+ # Type alias for OpenCV Mat (numpy array)
48
+ Mat: TypeAlias = npt.NDArray[np.uint8]
49
+
50
+ logger = logging.getLogger(__name__)
51
+
52
+
53
@dataclass
class RocketWelderClientOptions:
    """
    Connection settings used by RocketWelderClient.

    Holds one connection string for the video input and one for each output
    stream (keypoints, segmentation). Any field left unspecified falls back to
    that connection string type's ``default()``.
    """

    video_source: VideoSourceConnectionString = field(
        default_factory=VideoSourceConnectionString.default
    )
    keypoints: KeyPointsConnectionString = field(
        default_factory=KeyPointsConnectionString.default
    )
    segmentation: SegmentationConnectionString = field(
        default_factory=SegmentationConnectionString.default
    )

    @classmethod
    def from_environment(cls) -> RocketWelderClientOptions:
        """Build options by letting each connection string read its own environment variable."""
        source = VideoSourceConnectionString.from_environment()
        keypoints_cs = KeyPointsConnectionString.from_environment()
        segmentation_cs = SegmentationConnectionString.from_environment()
        return cls(
            video_source=source,
            keypoints=keypoints_cs,
            segmentation=segmentation_cs,
        )
73
+
74
+
75
class RocketWelderClient:
    """
    High-level client for RocketWelder SDK.

    Mirrors C# RocketWelder.SDK.IRocketWelderClient interface.

    Output sinks are opened by ``start()``, ``start_keypoints()`` or
    ``start_segmentation()`` and released by ``close()`` (also invoked when a
    ``with`` block exits). A closed client cannot be started again.
    """

    def __init__(self, options: RocketWelderClientOptions) -> None:
        self._options = options
        self._keypoints_schema = KeyPointsSchema()
        self._segmentation_schema = SegmentationSchema()
        # Sinks are created by _run_loop() and released by close().
        self._keypoints_sink: Optional[KeyPointsSink] = None
        self._keypoints_frame_sink: Optional[IFrameSink] = None
        self._segmentation_frame_sink: Optional[IFrameSink] = None
        self._closed = False
        logger.debug("RocketWelderClient created with options: %s", options)

    @classmethod
    def from_environment(cls) -> RocketWelderClient:
        """Create client from environment variables."""
        logger.info("Creating RocketWelderClient from environment variables")
        return cls(RocketWelderClientOptions.from_environment())

    @classmethod
    def create(cls, options: Optional[RocketWelderClientOptions] = None) -> RocketWelderClient:
        """Create client with explicit options, or default options when ``options`` is None."""
        return cls(options if options is not None else RocketWelderClientOptions())

    @property
    def keypoints(self) -> IKeyPointsSchema:
        """Schema for defining keypoints."""
        return self._keypoints_schema

    @property
    def segmentation(self) -> ISegmentationSchema:
        """Schema for defining segmentation classes."""
        return self._segmentation_schema

    def start(
        self,
        process_frame: Callable[[Mat, ISegmentationDataContext, IKeyPointsDataContext, Mat], None],
    ) -> None:
        """Start with both keypoints and segmentation."""
        self._run_loop(process_frame, use_keypoints=True, use_segmentation=True)

    def start_keypoints(
        self,
        process_frame: Callable[[Mat, IKeyPointsDataContext, Mat], None],
    ) -> None:
        """Start with keypoints only."""
        self._run_loop(process_frame, use_keypoints=True, use_segmentation=False)

    def start_segmentation(
        self,
        process_frame: Callable[[Mat, ISegmentationDataContext, Mat], None],
    ) -> None:
        """Start with segmentation only."""
        self._run_loop(process_frame, use_keypoints=False, use_segmentation=True)

    def _run_loop(
        self,
        process_frame: Callable[..., None],
        use_keypoints: bool,
        use_segmentation: bool,
    ) -> None:
        """
        Initialize the requested output sinks, then run the processing loop.

        The video capture loop is not implemented yet. After the sinks are set
        up this always raises NotImplementedError and leaves the sinks open, so
        ``process_frame_sync()`` can be used with them.

        Raises:
            RuntimeError: if the client has already been closed.
            NotImplementedError: always, once the sinks are initialized.
        """
        if self._closed:
            # close() is a no-op once _closed is set, so sinks created now would leak.
            raise RuntimeError("RocketWelderClient is closed")

        from rocket_welder_sdk.keypoints_protocol import KeyPointsSink

        logger.info(
            "Starting processing loop (keypoints=%s, segmentation=%s)",
            use_keypoints,
            use_segmentation,
        )

        # Initialize sinks
        if use_keypoints:
            # Release sinks from an earlier start*() call instead of overwriting (leaking) them.
            self._close_keypoints_sinks()
            cs = self._options.keypoints
            logger.info("Initializing keypoints sink: %s -> %s", cs.protocol, cs.address)
            # Stored before KeyPointsSink is built so close() still releases it
            # if the KeyPointsSink constructor raises.
            self._keypoints_frame_sink = self._create_frame_sink(cs.protocol, cs.address)
            self._keypoints_sink = KeyPointsSink(
                frame_sink=self._keypoints_frame_sink,
                master_frame_interval=cs.master_frame_interval,
                owns_sink=False,  # We manage frame sink lifecycle in close()
            )
            logger.debug(
                "KeyPointsSink created with master_frame_interval=%d", cs.master_frame_interval
            )

        if use_segmentation:
            self._close_segmentation_sink()
            seg_cs = self._options.segmentation
            logger.info("Initializing segmentation sink: %s -> %s", seg_cs.protocol, seg_cs.address)
            self._segmentation_frame_sink = self._create_frame_sink(seg_cs.protocol, seg_cs.address)
            logger.debug("Segmentation frame sink created")

        # TODO: Video capture loop - for now raise NotImplementedError
        raise NotImplementedError(
            "Video capture not implemented. Use process_frame_sync() or low-level API."
        )

    def process_frame_sync(
        self,
        frame_id: int,
        input_frame: Mat,
        output_frame: Mat,
        width: int,
        height: int,
    ) -> tuple[Optional[IKeyPointsDataContext], Optional[ISegmentationDataContext]]:
        """
        Process a single frame synchronously.

        A keypoints context is created only if the keypoints sink has been
        initialized. A segmentation context is created only if the segmentation
        sink has been initialized. Sinks are set up by the ``start*()`` methods.
        If a sink is missing, the corresponding entry is None. ``input_frame``
        and ``output_frame`` are not used by this method.

        Returns (keypoints_context, segmentation_context) for the caller to use.
        Caller must call commit() on contexts when done.
        """
        from rocket_welder_sdk.segmentation_result import SegmentationResultWriter

        kp_ctx: Optional[IKeyPointsDataContext] = None
        seg_ctx: Optional[ISegmentationDataContext] = None

        if self._keypoints_sink is not None:
            kp_writer = self._keypoints_sink.create_writer(frame_id)
            kp_ctx = KeyPointsDataContext(frame_id, kp_writer)

        if self._segmentation_frame_sink is not None:
            seg_writer = SegmentationResultWriter(
                frame_id, width, height, frame_sink=self._segmentation_frame_sink
            )
            seg_ctx = SegmentationDataContext(frame_id, seg_writer)

        return kp_ctx, seg_ctx

    def _create_frame_sink(self, protocol: Any, address: str) -> IFrameSink:
        """
        Create frame sink from protocol.

        Raises:
            TypeError: if ``protocol`` is not a TransportProtocol.
            ValueError: if the protocol's transport kind is not supported.
        """
        from rocket_welder_sdk.transport import NngFrameSink
        from rocket_welder_sdk.transport.stream_transport import StreamFrameSink
        from rocket_welder_sdk.transport.unix_socket_transport import UnixSocketFrameSink

        from .transport_protocol import TransportProtocol

        if not isinstance(protocol, TransportProtocol):
            raise TypeError(f"Expected TransportProtocol, got {type(protocol)}")

        if protocol.kind == TransportKind.FILE:
            logger.debug("Creating file sink: %s", address)
            file_handle = open(address, "wb")
            try:
                # On success the StreamFrameSink takes over the open handle.
                return StreamFrameSink(file_handle)
            except Exception:
                file_handle.close()
                raise
        elif protocol.kind == TransportKind.SOCKET:
            logger.debug("Creating Unix socket sink: %s", address)
            return UnixSocketFrameSink.connect(address)
        elif protocol.kind in (TransportKind.NNG_PUSH_IPC, TransportKind.NNG_PUSH_TCP):
            logger.debug("Creating NNG pusher: %s", address)
            return NngFrameSink.create_pusher(address)
        elif protocol.kind in (TransportKind.NNG_PUB_IPC, TransportKind.NNG_PUB_TCP):
            logger.debug("Creating NNG publisher: %s", address)
            return NngFrameSink.create_publisher(address)
        else:
            raise ValueError(f"Unsupported protocol: {protocol}")

    def _close_keypoints_sinks(self) -> None:
        """Drop the KeyPointsSink and close the keypoints frame sink under it."""
        # The KeyPointsSink is built with owns_sink=False, so the frame sink is closed here.
        self._keypoints_sink = None
        frame_sink, self._keypoints_frame_sink = self._keypoints_frame_sink, None
        if frame_sink is not None:
            logger.debug("Closing keypoints frame sink")
            frame_sink.close()

    def _close_segmentation_sink(self) -> None:
        """Close the segmentation frame sink, if one is open."""
        frame_sink, self._segmentation_frame_sink = self._segmentation_frame_sink, None
        if frame_sink is not None:
            logger.debug("Closing segmentation frame sink")
            frame_sink.close()

    def close(self) -> None:
        """
        Release resources.

        Idempotent. The segmentation sink is still closed if closing the
        keypoints sink raises, and the client is marked closed either way.
        """
        if self._closed:
            return

        logger.info("Closing RocketWelderClient")
        self._closed = True
        try:
            self._close_keypoints_sinks()
        finally:
            self._close_segmentation_sink()
        logger.info("RocketWelderClient closed")

    def __enter__(self) -> RocketWelderClient:
        return self

    def __exit__(self, *args: object) -> None:
        self.close()
@@ -0,0 +1,331 @@
1
+ """
2
+ Strongly-typed connection strings with parsing support.
3
+
4
+ Connection string format: protocol://path?param1=value1&param2=value2
5
+
6
+ Examples:
7
+ nng+push+ipc://tmp/keypoints?masterFrameInterval=300
8
+ nng+pub+tcp://localhost:5555
9
+ file:///path/to/output.bin
10
+ socket:///tmp/my.sock
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import contextlib
16
+ import os
17
+ from dataclasses import dataclass, field
18
+ from enum import Enum, auto
19
+ from typing import Dict, Optional
20
+ from urllib.parse import parse_qs
21
+
22
+ from .transport_protocol import TransportProtocol
23
+
24
+
25
class VideoSourceType(Enum):
    """Type of video source."""

    CAMERA = auto()
    FILE = auto()
    SHARED_MEMORY = auto()
    RTSP = auto()
    HTTP = auto()


@dataclass(frozen=True)
class VideoSourceConnectionString:
    """
    Strongly-typed connection string for video source input.

    Supported formats:
    - "0", "1", etc. - Camera device index
    - file:///path/to/video.mp4 or file://path/to/video.mp4 - Video file (absolute path)
    - shm://buffer_name - Shared memory buffer
    - rtsp://host/stream - RTSP stream
    - http://... or https://... - HTTP stream
    - any string without "://" - Video file path used as-is

    An optional "?key=value&..." suffix is parsed into ``parameters``, with keys
    lower-cased. ``value`` keeps the full (stripped) input string, query included.
    """

    value: str
    source_type: VideoSourceType
    camera_index: Optional[int] = None
    path: Optional[str] = None
    parameters: Dict[str, str] = field(default_factory=dict)

    @classmethod
    def default(cls) -> VideoSourceConnectionString:
        """Default video source (camera 0)."""
        return cls.parse("0")

    @classmethod
    def from_environment(cls, variable_name: str = "VIDEO_SOURCE") -> VideoSourceConnectionString:
        """Create from environment variable (falling back to CONNECTION_STRING) or use default."""
        value = os.environ.get(variable_name) or os.environ.get("CONNECTION_STRING")
        return cls.parse(value) if value else cls.default()

    @classmethod
    def parse(cls, s: str) -> VideoSourceConnectionString:
        """
        Parse a connection string.

        Raises:
            ValueError: if ``s`` is not a recognized video source.
        """
        result = cls.try_parse(s)
        if result is None:
            raise ValueError(f"Invalid video source connection string: {s}")
        return result

    @classmethod
    def try_parse(cls, s: str) -> Optional[VideoSourceConnectionString]:
        """
        Try to parse a connection string.

        Returns None for empty input, for a missing path after ``file://`` or
        ``shm://``, and for unknown ``scheme://`` prefixes.
        """
        if not s or not s.strip():
            return None

        original = s.strip()
        base, _, query = original.partition("?")
        if not base:
            # Only a query string ("?a=b"): there is no source to open.
            return None

        # keep_blank_values so a bare flag such as "?debug" maps to "" instead of
        # being silently dropped.
        parameters: Dict[str, str] = {
            key.lower(): values[0]
            for key, values in parse_qs(query, keep_blank_values=True).items()
        }

        # Camera index. isdecimal() rather than isdigit(): isdigit() also accepts
        # characters such as "²" that int() rejects.
        if base.isdecimal():
            return cls(
                value=original,
                source_type=VideoSourceType.CAMERA,
                camera_index=int(base),
                parameters=parameters,
            )

        if base.startswith("file://"):
            file_path = base[len("file://") :]
            if not file_path:
                return None
            # file:///abs/path already carries the leading slash; file://abs/path
            # does not. Same rule as the KeyPoints/Segmentation parsers.
            if not file_path.startswith("/"):
                file_path = "/" + file_path
            return cls(
                value=original,
                source_type=VideoSourceType.FILE,
                path=file_path,
                parameters=parameters,
            )

        if base.startswith("shm://"):
            buffer_name = base[len("shm://") :]
            if not buffer_name:
                return None
            return cls(
                value=original,
                source_type=VideoSourceType.SHARED_MEMORY,
                path=buffer_name,
                parameters=parameters,
            )

        # NOTE(review): for RTSP/HTTP the query string goes into ``parameters`` and
        # is not part of ``path``; confirm stream servers do not need it in the URL.
        if base.startswith("rtsp://"):
            return cls(
                value=original,
                source_type=VideoSourceType.RTSP,
                path=base,
                parameters=parameters,
            )

        if base.startswith(("http://", "https://")):
            return cls(
                value=original,
                source_type=VideoSourceType.HTTP,
                path=base,
                parameters=parameters,
            )

        if "://" not in base:
            # Assume file path
            return cls(
                value=original,
                source_type=VideoSourceType.FILE,
                path=base,
                parameters=parameters,
            )

        return None

    def __str__(self) -> str:
        return self.value
141
+
142
+
143
@dataclass(frozen=True)
class KeyPointsConnectionString:
    """
    Strongly-typed connection string for KeyPoints output.

    Supported protocols:
    - file:///path/to/file.bin - File output (absolute path)
    - socket:///tmp/socket.sock - Unix domain socket
    - nng+push+ipc://tmp/keypoints - NNG Push over IPC
    - nng+push+tcp://host:port - NNG Push over TCP

    Supported parameters (names are matched case-insensitively):
    - masterFrameInterval: Interval between master frames (default: 300)
    """

    value: str
    protocol: TransportProtocol
    address: str
    master_frame_interval: int = 300
    parameters: Dict[str, str] = field(default_factory=dict)

    @classmethod
    def default(cls) -> KeyPointsConnectionString:
        """Default connection string for KeyPoints."""
        return cls.parse("nng+push+ipc://tmp/rocket-welder-keypoints?masterFrameInterval=300")

    @classmethod
    def from_environment(
        cls, variable_name: str = "KEYPOINTS_CONNECTION_STRING"
    ) -> KeyPointsConnectionString:
        """Create from environment variable or use default."""
        value = os.environ.get(variable_name)
        return cls.parse(value) if value else cls.default()

    @classmethod
    def parse(cls, s: str) -> KeyPointsConnectionString:
        """
        Parse a connection string.

        Raises:
            ValueError: if ``s`` is not a valid KeyPoints connection string.
        """
        result = cls.try_parse(s)
        if result is None:
            raise ValueError(f"Invalid KeyPoints connection string: {s}")
        return result

    @classmethod
    def try_parse(cls, s: str) -> Optional[KeyPointsConnectionString]:
        """
        Try to parse a connection string.

        Returns None for empty input, a missing ``scheme://`` prefix, an unknown
        transport, an empty address after the scheme, or a transport that is
        not file, socket or NNG.
        """
        if not s or not s.strip():
            return None

        s = s.strip()
        endpoint_part, _, query = s.partition("?")
        # keep_blank_values so a bare flag such as "?debug" maps to "" instead of
        # being silently dropped.
        parameters: Dict[str, str] = {
            key.lower(): values[0]
            for key, values in parse_qs(query, keep_blank_values=True).items()
        }

        # Parse protocol and address
        scheme_end = endpoint_part.find("://")
        if scheme_end <= 0:
            return None

        protocol = TransportProtocol.try_parse(endpoint_part[:scheme_end])
        if protocol is None:
            return None

        path_part = endpoint_part[scheme_end + 3 :]  # skip "://"
        if not path_part:
            # e.g. "file://" or "nng+push+ipc://": no address to write to.
            return None

        # Build address based on protocol type
        if protocol.is_file or protocol.is_socket:
            # file:///abs/path and socket:///tmp/sock both map to an absolute path;
            # a missing leading slash (file://abs/path) is restored.
            address = path_part if path_part.startswith("/") else "/" + path_part
        elif protocol.is_nng:
            # NNG protocols need proper address format
            address = protocol.create_nng_address(path_part)
        else:
            return None

        # Non-integer values fall back to the default of 300.
        # NOTE(review): zero/negative intervals are accepted as-is; confirm KeyPointsSink handles them.
        master_frame_interval = 300
        if "masterframeinterval" in parameters:
            with contextlib.suppress(ValueError):
                master_frame_interval = int(parameters["masterframeinterval"])

        return cls(
            value=s,
            protocol=protocol,
            address=address,
            master_frame_interval=master_frame_interval,
            parameters=parameters,
        )

    def __str__(self) -> str:
        return self.value
242
+
243
+
244
@dataclass(frozen=True)
class SegmentationConnectionString:
    """
    Strongly-typed connection string for Segmentation output.

    Supported protocols:
    - file:///path/to/file.bin - File output (absolute path)
    - socket:///tmp/socket.sock - Unix domain socket
    - nng+push+ipc://tmp/segmentation - NNG Push over IPC
    - nng+push+tcp://host:port - NNG Push over TCP

    Query parameters are stored in ``parameters`` with lower-cased keys.
    """

    value: str
    protocol: TransportProtocol
    address: str
    parameters: Dict[str, str] = field(default_factory=dict)

    @classmethod
    def default(cls) -> SegmentationConnectionString:
        """Default connection string for Segmentation."""
        return cls.parse("nng+push+ipc://tmp/rocket-welder-segmentation")

    @classmethod
    def from_environment(
        cls, variable_name: str = "SEGMENTATION_CONNECTION_STRING"
    ) -> SegmentationConnectionString:
        """Create from environment variable or use default."""
        value = os.environ.get(variable_name)
        return cls.parse(value) if value else cls.default()

    @classmethod
    def parse(cls, s: str) -> SegmentationConnectionString:
        """
        Parse a connection string.

        Raises:
            ValueError: if ``s`` is not a valid Segmentation connection string.
        """
        result = cls.try_parse(s)
        if result is None:
            raise ValueError(f"Invalid Segmentation connection string: {s}")
        return result

    @classmethod
    def try_parse(cls, s: str) -> Optional[SegmentationConnectionString]:
        """
        Try to parse a connection string.

        Returns None for empty input, a missing ``scheme://`` prefix, an unknown
        transport, an empty address after the scheme, or a transport that is
        not file, socket or NNG.
        """
        if not s or not s.strip():
            return None

        s = s.strip()
        endpoint_part, _, query = s.partition("?")
        # keep_blank_values so a bare flag such as "?debug" maps to "" instead of
        # being silently dropped.
        parameters: Dict[str, str] = {
            key.lower(): values[0]
            for key, values in parse_qs(query, keep_blank_values=True).items()
        }

        # Parse protocol and address
        scheme_end = endpoint_part.find("://")
        if scheme_end <= 0:
            return None

        protocol = TransportProtocol.try_parse(endpoint_part[:scheme_end])
        if protocol is None:
            return None

        path_part = endpoint_part[scheme_end + 3 :]  # skip "://"
        if not path_part:
            # e.g. "file://" or "nng+push+ipc://": no address to write to.
            return None

        # Build address based on protocol type
        if protocol.is_file or protocol.is_socket:
            # file:///abs/path and socket:///tmp/sock both map to an absolute path;
            # a missing leading slash (file://abs/path) is restored.
            address = path_part if path_part.startswith("/") else "/" + path_part
        elif protocol.is_nng:
            # NNG protocols need proper address format
            address = protocol.create_nng_address(path_part)
        else:
            return None

        return cls(
            value=s,
            protocol=protocol,
            address=address,
            parameters=parameters,
        )

    def __str__(self) -> str:
        return self.value