rocket-welder-sdk 1.1.34__py3-none-any.whl → 1.1.34a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,262 +0,0 @@
1
- """
2
- RocketWelderClient - High-level API matching C# RocketWelder.SDK.
3
-
4
- Usage:
5
- with RocketWelderClient.from_environment() as client:
6
- # Define schema
7
- nose = client.keypoints.define_point("nose")
8
- person = client.segmentation.define_class(1, "person")
9
-
10
- # Start processing
11
- client.start(process_frame)
12
- """
13
-
14
- from __future__ import annotations
15
-
16
- import logging
17
- from dataclasses import dataclass, field
18
- from typing import TYPE_CHECKING, Any, Callable, Optional
19
-
20
- import numpy as np
21
- import numpy.typing as npt
22
- from typing_extensions import TypeAlias
23
-
24
- from .connection_strings import (
25
- KeyPointsConnectionString,
26
- SegmentationConnectionString,
27
- VideoSourceConnectionString,
28
- )
29
- from .data_context import (
30
- IKeyPointsDataContext,
31
- ISegmentationDataContext,
32
- KeyPointsDataContext,
33
- SegmentationDataContext,
34
- )
35
- from .schema import (
36
- IKeyPointsSchema,
37
- ISegmentationSchema,
38
- KeyPointsSchema,
39
- SegmentationSchema,
40
- )
41
- from .transport_protocol import TransportKind
42
-
43
- if TYPE_CHECKING:
44
- from rocket_welder_sdk.keypoints_protocol import KeyPointsSink
45
- from rocket_welder_sdk.transport.frame_sink import IFrameSink
46
-
47
- # Type alias for OpenCV Mat (numpy array)
48
- Mat: TypeAlias = npt.NDArray[np.uint8]
49
-
50
- logger = logging.getLogger(__name__)
51
-
52
-
53
@dataclass
class RocketWelderClientOptions:
    """Connection settings consumed by RocketWelderClient.

    Each field falls back to the ``default()`` factory of its connection-string
    type when no value is supplied.
    """

    # Connection string for the video source.
    video_source: VideoSourceConnectionString = field(
        default_factory=VideoSourceConnectionString.default
    )
    # Connection string for the keypoints output.
    keypoints: KeyPointsConnectionString = field(
        default_factory=KeyPointsConnectionString.default
    )
    # Connection string for the segmentation output.
    segmentation: SegmentationConnectionString = field(
        default_factory=SegmentationConnectionString.default
    )

    @classmethod
    def from_environment(cls) -> RocketWelderClientOptions:
        """Build the options by asking each connection-string type for its
        ``from_environment()`` value."""
        video_source = VideoSourceConnectionString.from_environment()
        keypoints = KeyPointsConnectionString.from_environment()
        segmentation = SegmentationConnectionString.from_environment()
        return cls(
            video_source=video_source,
            keypoints=keypoints,
            segmentation=segmentation,
        )
73
-
74
-
75
class RocketWelderClient:
    """
    High-level client for RocketWelder SDK.

    Mirrors C# RocketWelder.SDK.IRocketWelderClient interface.

    The client opens its frame sinks in the ``start*`` methods and releases
    them in :meth:`close`. It can be used as a context manager, in which case
    :meth:`close` runs on exit.
    """

    def __init__(self, options: RocketWelderClientOptions) -> None:
        """Store *options* and create empty schemas; no sink is opened here."""
        self._options = options
        self._keypoints_schema = KeyPointsSchema()
        self._segmentation_schema = SegmentationSchema()
        self._keypoints_sink: Optional[KeyPointsSink] = None
        self._keypoints_frame_sink: Optional[IFrameSink] = None
        self._segmentation_frame_sink: Optional[IFrameSink] = None
        self._closed = False
        logger.debug("RocketWelderClient created with options: %s", options)

    @classmethod
    def from_environment(cls) -> RocketWelderClient:
        """Create client from environment variables."""
        logger.info("Creating RocketWelderClient from environment variables")
        return cls(RocketWelderClientOptions.from_environment())

    @classmethod
    def create(cls, options: Optional[RocketWelderClientOptions] = None) -> RocketWelderClient:
        """Create client with explicit options (defaults are used when omitted)."""
        return cls(options or RocketWelderClientOptions())

    @property
    def keypoints(self) -> IKeyPointsSchema:
        """Schema for defining keypoints."""
        return self._keypoints_schema

    @property
    def segmentation(self) -> ISegmentationSchema:
        """Schema for defining segmentation classes."""
        return self._segmentation_schema

    def start(
        self,
        process_frame: Callable[[Mat, ISegmentationDataContext, IKeyPointsDataContext, Mat], None],
    ) -> None:
        """Start with both keypoints and segmentation.

        Raises:
            NotImplementedError: always, once the sinks have been initialized
                (the video capture loop is not implemented).
        """
        self._run_loop(process_frame, use_keypoints=True, use_segmentation=True)

    def start_keypoints(
        self,
        process_frame: Callable[[Mat, IKeyPointsDataContext, Mat], None],
    ) -> None:
        """Start with keypoints only.

        Raises:
            NotImplementedError: always, once the keypoints sink has been
                initialized (the video capture loop is not implemented).
        """
        self._run_loop(process_frame, use_keypoints=True, use_segmentation=False)

    def start_segmentation(
        self,
        process_frame: Callable[[Mat, ISegmentationDataContext, Mat], None],
    ) -> None:
        """Start with segmentation only.

        Raises:
            NotImplementedError: always, once the segmentation sink has been
                initialized (the video capture loop is not implemented).
        """
        self._run_loop(process_frame, use_keypoints=False, use_segmentation=True)

    def _run_loop(
        self,
        process_frame: Callable[..., None],
        use_keypoints: bool,
        use_segmentation: bool,
    ) -> None:
        """Initialize the requested sinks, then fail because capture is missing.

        Args:
            process_frame: Per-frame callback; not invoked yet because the
                capture loop does not exist.
            use_keypoints: Open the keypoints frame sink and wrap it in a
                ``KeyPointsSink``.
            use_segmentation: Open the segmentation frame sink.

        Raises:
            NotImplementedError: always, after the sinks have been created.
        """
        from rocket_welder_sdk.keypoints_protocol import KeyPointsSink

        logger.info(
            "Starting processing loop (keypoints=%s, segmentation=%s)",
            use_keypoints,
            use_segmentation,
        )

        # close() returns early while _closed is set. Clear the flag before
        # acquiring anything so sinks opened on a previously closed client can
        # still be released by a later close().
        self._closed = False

        # Initialize sinks
        if use_keypoints:
            cs = self._options.keypoints
            logger.info("Initializing keypoints sink: %s -> %s", cs.protocol, cs.address)
            self._keypoints_frame_sink = self._create_frame_sink(cs.protocol, cs.address)
            self._keypoints_sink = KeyPointsSink(
                frame_sink=self._keypoints_frame_sink,
                master_frame_interval=cs.master_frame_interval,
                owns_sink=False,  # We manage frame sink lifecycle in close()
            )
            logger.debug(
                "KeyPointsSink created with master_frame_interval=%d", cs.master_frame_interval
            )

        if use_segmentation:
            seg_cs = self._options.segmentation
            logger.info("Initializing segmentation sink: %s -> %s", seg_cs.protocol, seg_cs.address)
            self._segmentation_frame_sink = self._create_frame_sink(seg_cs.protocol, seg_cs.address)
            logger.debug("Segmentation frame sink created")

        # TODO: Video capture loop - for now raise NotImplementedError
        raise NotImplementedError(
            "Video capture not implemented. Use process_frame_sync() or low-level API."
        )

    def process_frame_sync(
        self,
        frame_id: int,
        input_frame: Mat,
        output_frame: Mat,
        width: int,
        height: int,
    ) -> tuple[Optional[IKeyPointsDataContext], Optional[ISegmentationDataContext]]:
        """
        Process a single frame synchronously.

        Args:
            frame_id: Identifier passed to the writers and data contexts.
            input_frame: Accepted for interface symmetry; not read here.
            output_frame: Accepted for interface symmetry; not read here.
            width: Frame width handed to the segmentation writer.
            height: Frame height handed to the segmentation writer.

        Returns (keypoints_context, segmentation_context) for the caller to use.
        A context is ``None`` when the corresponding sink has not been
        initialized. Caller must call commit() on contexts when done.
        """
        from rocket_welder_sdk.segmentation_result import SegmentationResultWriter

        kp_ctx: Optional[IKeyPointsDataContext] = None
        seg_ctx: Optional[ISegmentationDataContext] = None

        if self._keypoints_sink is not None:
            kp_writer = self._keypoints_sink.create_writer(frame_id)
            kp_ctx = KeyPointsDataContext(frame_id, kp_writer)

        if self._segmentation_frame_sink is not None:
            seg_writer = SegmentationResultWriter(
                frame_id, width, height, frame_sink=self._segmentation_frame_sink
            )
            seg_ctx = SegmentationDataContext(frame_id, seg_writer)

        return kp_ctx, seg_ctx

    def _create_frame_sink(self, protocol: Any, address: str) -> IFrameSink:
        """Create the frame sink matching ``protocol.kind`` for *address*.

        Raises:
            TypeError: if *protocol* is not a ``TransportProtocol``.
            ValueError: if the protocol kind has no matching sink.
        """
        from rocket_welder_sdk.transport import NngFrameSink
        from rocket_welder_sdk.transport.stream_transport import StreamFrameSink
        from rocket_welder_sdk.transport.unix_socket_transport import UnixSocketFrameSink

        from .transport_protocol import TransportProtocol

        if not isinstance(protocol, TransportProtocol):
            raise TypeError(f"Expected TransportProtocol, got {type(protocol)}")

        if protocol.kind == TransportKind.FILE:
            logger.debug("Creating file sink: %s", address)
            file_handle = open(address, "wb")
            try:
                return StreamFrameSink(file_handle)
            except Exception:
                # Do not leak the file handle when the sink cannot be built.
                file_handle.close()
                raise
        elif protocol.kind == TransportKind.SOCKET:
            logger.debug("Creating Unix socket sink: %s", address)
            return UnixSocketFrameSink.connect(address)
        elif protocol.kind in (TransportKind.NNG_PUSH_IPC, TransportKind.NNG_PUSH_TCP):
            logger.debug("Creating NNG pusher: %s", address)
            return NngFrameSink.create_pusher(address)
        elif protocol.kind in (TransportKind.NNG_PUB_IPC, TransportKind.NNG_PUB_TCP):
            logger.debug("Creating NNG publisher: %s", address)
            return NngFrameSink.create_publisher(address)
        else:
            raise ValueError(f"Unsupported protocol: {protocol}")

    def close(self) -> None:
        """Release resources.

        Calling this on an already closed client does nothing. Both frame
        sinks are closed even if closing the first one raises; the exception
        still propagates to the caller afterwards.
        """
        if self._closed:
            return

        logger.info("Closing RocketWelderClient")

        # Detach everything up front: whatever happens below, the client must
        # not keep references to sinks that are closed or failed to close.
        self._closed = True
        # KeyPointsSink has owns_sink=False, so the frame sink is closed here.
        self._keypoints_sink = None
        keypoints_frame_sink = self._keypoints_frame_sink
        segmentation_frame_sink = self._segmentation_frame_sink
        self._keypoints_frame_sink = None
        self._segmentation_frame_sink = None

        try:
            if keypoints_frame_sink is not None:
                logger.debug("Closing keypoints frame sink")
                keypoints_frame_sink.close()
        finally:
            # Runs even when the keypoints sink failed to close.
            if segmentation_frame_sink is not None:
                logger.debug("Closing segmentation frame sink")
                segmentation_frame_sink.close()

        logger.info("RocketWelderClient closed")

    def __enter__(self) -> RocketWelderClient:
        """Return the client itself for use in a ``with`` block."""
        return self

    def __exit__(self, *args: object) -> None:
        """Close the client when leaving the ``with`` block."""
        self.close()