rocket-welder-sdk 1.1.31__py3-none-any.whl → 1.1.33__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,330 @@
1
+ """
2
+ Strongly-typed connection strings with parsing support.
3
+
4
+ Connection string format: protocol://path?param1=value1&param2=value2
5
+
6
+ Examples:
7
+ nng+push+ipc://tmp/keypoints?masterFrameInterval=300
8
+ nng+pub+tcp://localhost:5555
9
+ file://path/to/output.bin
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import contextlib
15
+ import os
16
+ from dataclasses import dataclass, field
17
+ from enum import Enum, auto
18
+ from typing import Dict, Optional
19
+ from urllib.parse import parse_qs
20
+
21
+ from .transport_protocol import TransportProtocol
22
+
23
+
24
class VideoSourceType(Enum):
    """Kinds of input a video source connection string can describe."""

    CAMERA = auto()  # numeric device index such as "0"
    FILE = auto()  # file:// URI or a bare path without a scheme
    SHARED_MEMORY = auto()  # shm://buffer_name
    RTSP = auto()  # rtsp:// stream URL
    HTTP = auto()  # http:// or https:// stream URL
32
+
33
+
34
@dataclass(frozen=True)
class VideoSourceConnectionString:
    """
    Strongly-typed connection string for video source input.

    Supported formats:
        - "0", "1", etc.            - Camera device index
        - file://path/to/video.mp4  - Video file (path is made absolute)
        - shm://buffer_name         - Shared memory buffer
        - rtsp://host/stream        - RTSP stream
        - http(s)://host/stream     - HTTP stream
        - path/to/video.mp4         - Any string without "://" is a file path

    Scheme prefixes are matched case-insensitively. An optional
    ``?key=value&...`` suffix is parsed into ``parameters`` (keys lower-cased,
    first value wins) and is not included in ``path``. ``value`` keeps the
    full trimmed input, so ``str()`` output can be fed back into ``parse()``.
    """

    value: str
    source_type: VideoSourceType
    camera_index: Optional[int] = None
    path: Optional[str] = None
    parameters: Dict[str, str] = field(default_factory=dict)

    @classmethod
    def default(cls) -> VideoSourceConnectionString:
        """Default video source (camera 0)."""
        return cls.parse("0")

    @classmethod
    def from_environment(cls, variable_name: str = "VIDEO_SOURCE") -> VideoSourceConnectionString:
        """
        Create from an environment variable, falling back to the default.

        Reads ``variable_name`` first, then ``CONNECTION_STRING``. If both are
        unset or empty, camera 0 is used.

        Raises:
            ValueError: If a variable is set but cannot be parsed.
        """
        value = os.environ.get(variable_name) or os.environ.get("CONNECTION_STRING")
        return cls.parse(value) if value else cls.default()

    @classmethod
    def parse(cls, s: str) -> VideoSourceConnectionString:
        """
        Parse a connection string.

        Raises:
            ValueError: If ``s`` is empty or uses an unsupported scheme.
        """
        result = cls.try_parse(s)
        if result is None:
            raise ValueError(f"Invalid video source connection string: {s}")
        return result

    @classmethod
    def try_parse(cls, s: str) -> Optional[VideoSourceConnectionString]:
        """Parse a connection string, returning None instead of raising."""
        if not s or not s.strip():
            return None

        original = s.strip()
        endpoint = original
        parameters: Dict[str, str] = {}

        # Extract query parameters; keys are case-insensitive.
        if "?" in original:
            endpoint, query = original.split("?", 1)
            for key, values in parse_qs(query).items():
                parameters[key.lower()] = values[0] if values else ""

        # isdecimal(), not isdigit(): isdigit() also accepts characters such as
        # "²" that int() rejects, which would make try_parse raise.
        if endpoint.isdecimal():
            return cls(
                value=original,
                source_type=VideoSourceType.CAMERA,
                camera_index=int(endpoint),
                parameters=parameters,
            )

        # URI schemes are case-insensitive; match on a lower-cased copy but
        # slice/return the original text so paths keep their case.
        lowered = endpoint.lower()
        if lowered.startswith("file://"):
            rest = endpoint[len("file://"):]
            # Restore the absolute path without doubling an existing leading
            # slash (file:///tmp/x -> /tmp/x, file://tmp/x -> /tmp/x).
            path = rest if rest.startswith("/") else "/" + rest
            return cls(
                value=original,
                source_type=VideoSourceType.FILE,
                path=path,
                parameters=parameters,
            )
        if lowered.startswith("shm://"):
            return cls(
                value=original,
                source_type=VideoSourceType.SHARED_MEMORY,
                path=endpoint[len("shm://"):],
                parameters=parameters,
            )
        if lowered.startswith("rtsp://"):
            return cls(
                value=original,
                source_type=VideoSourceType.RTSP,
                path=endpoint,
                parameters=parameters,
            )
        if lowered.startswith(("http://", "https://")):
            return cls(
                value=original,
                source_type=VideoSourceType.HTTP,
                path=endpoint,
                parameters=parameters,
            )
        if "://" not in endpoint:
            # No scheme at all: treat as a plain file path.
            return cls(
                value=original,
                source_type=VideoSourceType.FILE,
                path=endpoint,
                parameters=parameters,
            )

        # Unknown scheme.
        return None

    def __str__(self) -> str:
        return self.value
140
+
141
+
142
@dataclass(frozen=True)
class KeyPointsConnectionString:
    """
    Strongly-typed connection string for KeyPoints output.

    Supported forms:
        - nng+push+ipc://tmp/keypoints  - NNG transport (scheme parsed by TransportProtocol)
        - nng+push+tcp://host:port      - NNG transport over TCP
        - file://path/to/file.bin       - File output (path is made absolute)
        - /absolute/path/to/file.bin    - Bare absolute file path

    Supported query parameters (keys are matched case-insensitively):
        - masterFrameInterval: Interval between master frames (default: 300).
          A non-integer value is ignored and the default is kept.
    """

    value: str
    protocol: Optional[TransportProtocol] = None
    is_file: bool = False
    address: str = ""
    master_frame_interval: int = 300
    parameters: Dict[str, str] = field(default_factory=dict)

    @classmethod
    def default(cls) -> KeyPointsConnectionString:
        """Default connection string for KeyPoints."""
        return cls.parse("nng+push+ipc://tmp/rocket-welder-keypoints?masterFrameInterval=300")

    @classmethod
    def from_environment(
        cls, variable_name: str = "KEYPOINTS_CONNECTION_STRING"
    ) -> KeyPointsConnectionString:
        """
        Create from an environment variable, falling back to the default.

        Raises:
            ValueError: If the variable is set but cannot be parsed.
        """
        value = os.environ.get(variable_name)
        return cls.parse(value) if value else cls.default()

    @classmethod
    def parse(cls, s: str) -> KeyPointsConnectionString:
        """
        Parse a connection string.

        Raises:
            ValueError: If ``s`` is empty, relative, or uses an unknown protocol.
        """
        result = cls.try_parse(s)
        if result is None:
            raise ValueError(f"Invalid KeyPoints connection string: {s}")
        return result

    @classmethod
    def try_parse(cls, s: str) -> Optional[KeyPointsConnectionString]:
        """Parse a connection string, returning None instead of raising."""
        if not s or not s.strip():
            return None

        s = s.strip()
        parameters: Dict[str, str] = {}

        # Split off query parameters; the endpoint alone determines the address.
        endpoint_part = s
        if "?" in s:
            endpoint_part, query = s.split("?", 1)
            for key, values in parse_qs(query).items():
                parameters[key.lower()] = values[0] if values else ""

        protocol: Optional[TransportProtocol] = None
        scheme_end = endpoint_part.find("://")
        if scheme_end > 0:
            protocol_str = endpoint_part[:scheme_end]
            path_part = endpoint_part[scheme_end + 3 :]  # skip "://"

            if protocol_str.lower() == "file":
                # Restore the absolute path without doubling an existing
                # leading slash (file:///tmp/x -> /tmp/x).
                address = path_part if path_part.startswith("/") else "/" + path_part
                is_file = True
            else:
                protocol = TransportProtocol.try_parse(protocol_str)
                if protocol is None:
                    return None
                address = protocol.create_nng_address(path_part)
                is_file = False
        elif endpoint_part.startswith("/"):
            # Bare absolute file path; the query string is not part of it.
            address = endpoint_part
            is_file = True
        else:
            return None

        # Best-effort: an unparsable masterFrameInterval keeps the default.
        master_frame_interval = 300
        if "masterframeinterval" in parameters:
            with contextlib.suppress(ValueError):
                master_frame_interval = int(parameters["masterframeinterval"])

        return cls(
            value=s,
            protocol=protocol,
            is_file=is_file,
            address=address,
            master_frame_interval=master_frame_interval,
            parameters=parameters,
        )

    def __str__(self) -> str:
        return self.value
241
+
242
+
243
@dataclass(frozen=True)
class SegmentationConnectionString:
    """
    Strongly-typed connection string for Segmentation output.

    Supported forms:
        - nng+push+ipc://tmp/segmentation  - NNG transport (scheme parsed by TransportProtocol)
        - nng+push+tcp://host:port         - NNG transport over TCP
        - file://path/to/file.bin          - File output (path is made absolute)
        - /absolute/path/to/file.bin       - Bare absolute file path

    Query parameters are collected into ``parameters`` with lower-cased keys.
    """

    value: str
    protocol: Optional[TransportProtocol] = None
    is_file: bool = False
    address: str = ""
    parameters: Dict[str, str] = field(default_factory=dict)

    @classmethod
    def default(cls) -> SegmentationConnectionString:
        """Default connection string for Segmentation."""
        return cls.parse("nng+push+ipc://tmp/rocket-welder-segmentation")

    @classmethod
    def from_environment(
        cls, variable_name: str = "SEGMENTATION_CONNECTION_STRING"
    ) -> SegmentationConnectionString:
        """
        Create from an environment variable, falling back to the default.

        Raises:
            ValueError: If the variable is set but cannot be parsed.
        """
        value = os.environ.get(variable_name)
        return cls.parse(value) if value else cls.default()

    @classmethod
    def parse(cls, s: str) -> SegmentationConnectionString:
        """
        Parse a connection string.

        Raises:
            ValueError: If ``s`` is empty, relative, or uses an unknown protocol.
        """
        result = cls.try_parse(s)
        if result is None:
            raise ValueError(f"Invalid Segmentation connection string: {s}")
        return result

    @classmethod
    def try_parse(cls, s: str) -> Optional[SegmentationConnectionString]:
        """Parse a connection string, returning None instead of raising."""
        if not s or not s.strip():
            return None

        s = s.strip()
        parameters: Dict[str, str] = {}

        # Split off query parameters; the endpoint alone determines the address.
        endpoint_part = s
        if "?" in s:
            endpoint_part, query = s.split("?", 1)
            for key, values in parse_qs(query).items():
                parameters[key.lower()] = values[0] if values else ""

        protocol: Optional[TransportProtocol] = None
        scheme_end = endpoint_part.find("://")
        if scheme_end > 0:
            protocol_str = endpoint_part[:scheme_end]
            path_part = endpoint_part[scheme_end + 3 :]  # skip "://"

            if protocol_str.lower() == "file":
                # Restore the absolute path without doubling an existing
                # leading slash (file:///tmp/x -> /tmp/x).
                address = path_part if path_part.startswith("/") else "/" + path_part
                is_file = True
            else:
                protocol = TransportProtocol.try_parse(protocol_str)
                if protocol is None:
                    return None
                address = protocol.create_nng_address(path_part)
                is_file = False
        elif endpoint_part.startswith("/"):
            # Bare absolute file path; the query string is not part of it.
            address = endpoint_part
            is_file = True
        else:
            return None

        return cls(
            value=s,
            protocol=protocol,
            is_file=is_file,
            address=address,
            parameters=parameters,
        )

    def __str__(self) -> str:
        return self.value
@@ -0,0 +1,163 @@
1
+ """
2
+ Data context types for per-frame keypoints and segmentation data.
3
+
4
+ Implements the Unit of Work pattern - contexts are created per-frame
5
+ and auto-commit when the processing delegate returns.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from abc import ABC, abstractmethod
11
+ from typing import TYPE_CHECKING, Sequence, Tuple, Union
12
+
13
+ import numpy as np
14
+ import numpy.typing as npt
15
+
16
+ if TYPE_CHECKING:
17
+ from rocket_welder_sdk.keypoints_protocol import IKeyPointsWriter
18
+ from rocket_welder_sdk.segmentation_result import SegmentationResultWriter
19
+
20
+ from .schema import KeyPoint, SegmentClass
21
+
22
# Type aliases
# Pixel position as an (x, y) tuple of integers.
Point = Tuple[int, int]
24
+
25
+
26
class IKeyPointsDataContext(ABC):
    """
    Per-frame unit of work that collects keypoint detections.

    One instance covers exactly one frame and is committed automatically
    once the processing delegate returns.
    """

    @property
    @abstractmethod
    def frame_id(self) -> int:
        """Identifier of the frame this context belongs to."""

    @abstractmethod
    def add(self, point: KeyPoint, x: int, y: int, confidence: float) -> None:
        """
        Record one keypoint detection for the current frame.

        Args:
            point: Keypoint handle obtained from the schema.
            x: Horizontal coordinate in pixels.
            y: Vertical coordinate in pixels.
            confidence: Detection confidence in the range 0.0 to 1.0.
        """

    @abstractmethod
    def add_point(self, point: KeyPoint, position: Point, confidence: float) -> None:
        """
        Record one keypoint detection given as a single (x, y) tuple.

        Args:
            point: Keypoint handle obtained from the schema.
            position: Pixel position as an (x, y) tuple.
            confidence: Detection confidence in the range 0.0 to 1.0.
        """
63
+
64
+
65
class ISegmentationDataContext(ABC):
    """
    Per-frame unit of work that collects segmentation instances.

    One instance covers exactly one frame and is committed automatically
    once the processing delegate returns.
    """

    @property
    @abstractmethod
    def frame_id(self) -> int:
        """Identifier of the frame this context belongs to."""

    @abstractmethod
    def add(
        self,
        segment_class: SegmentClass,
        instance_id: int,
        points: Union[Sequence[Point], npt.NDArray[np.int32]],
    ) -> None:
        """
        Record one segmented instance for the current frame.

        Args:
            segment_class: Class handle obtained from the schema.
            instance_id: Distinguishes several instances of the same class (0-255).
            points: Contour points outlining the instance.
        """
94
+
95
+
96
class KeyPointsDataContext(IKeyPointsDataContext):
    """
    Keypoints data context that forwards every detection to a writer.

    Detections are appended to the writer as they are added; ``commit()``
    closes the writer.
    """

    def __init__(
        self,
        frame_id: int,
        writer: IKeyPointsWriter,
    ) -> None:
        """
        Args:
            frame_id: Identifier of the frame this context covers.
            writer: Destination for this frame's keypoints; closed by commit().
        """
        self._frame_id = frame_id
        self._writer = writer

    @property
    def frame_id(self) -> int:
        """Identifier of the frame this context covers."""
        return self._frame_id

    def add(self, point: KeyPoint, x: int, y: int, confidence: float) -> None:
        """Append a keypoint detection, identified by ``point.id``, to the writer."""
        self._writer.append(point.id, x, y, confidence)

    def add_point(self, point: KeyPoint, position: Point, confidence: float) -> None:
        """Append a keypoint detection given as an (x, y) tuple to the writer."""
        self._writer.append_point(point.id, position, confidence)

    def commit(self) -> None:
        """Commit the context (called automatically when delegate returns)."""
        self._writer.close()
124
+
125
+
126
class SegmentationDataContext(ISegmentationDataContext):
    """
    Segmentation data context that forwards every instance to a writer.

    Instances are appended to the writer as they are added; ``commit()``
    closes the writer.
    """

    def __init__(
        self,
        frame_id: int,
        writer: SegmentationResultWriter,
    ) -> None:
        """
        Args:
            frame_id: Identifier of the frame this context covers.
            writer: Destination for this frame's instances; closed by commit().
        """
        self._frame_id = frame_id
        self._writer = writer

    @property
    def frame_id(self) -> int:
        """Identifier of the frame this context covers."""
        return self._frame_id

    def add(
        self,
        segment_class: SegmentClass,
        instance_id: int,
        points: Union[Sequence[Point], npt.NDArray[np.int32]],
    ) -> None:
        """
        Add a segmentation instance for this frame.

        Args:
            segment_class: Class handle; its ``class_id`` is written.
            instance_id: Instance ID within the class (0-255).
            points: Contour points defining the instance boundary.

        Raises:
            ValueError: If ``instance_id`` is outside 0-255.
        """
        if not 0 <= instance_id <= 255:
            raise ValueError(f"instance_id must be 0-255, got {instance_id}")

        # Sequences are converted to int32. NumPy arrays are forwarded unchanged,
        # so their dtype is not coerced.
        # NOTE(review): confirm the writer accepts non-int32 arrays.
        if isinstance(points, np.ndarray):
            points_array = points
        else:
            points_array = np.array(points, dtype=np.int32)

        self._writer.append(segment_class.class_id, instance_id, points_array)

    def commit(self) -> None:
        """Commit the context (called automatically when delegate returns)."""
        self._writer.close()
@@ -0,0 +1,180 @@
1
+ """
2
+ Schema types for KeyPoints and Segmentation.
3
+
4
+ Provides type-safe definitions for keypoints and segmentation classes
5
+ that are defined at initialization time and used during processing.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import json
11
+ from abc import ABC, abstractmethod
12
+ from dataclasses import dataclass
13
+ from typing import Dict, List
14
+
15
+
16
@dataclass(frozen=True)
class KeyPoint:
    """
    Immutable, named keypoint handle.

    Obtained from IKeyPointsSchema.define_point() and passed back when
    recording detections in a keypoints data context.
    """

    id: int  # numeric identifier assigned by the schema
    name: str  # human-readable label, e.g. "nose"

    def __str__(self) -> str:
        return "KeyPoint(%s, '%s')" % (self.id, self.name)
30
+
31
+
32
@dataclass(frozen=True)
class SegmentClass:
    """
    Immutable segmentation class handle (numeric ID plus name).

    Obtained from ISegmentationSchema.define_class() and passed back when
    recording instances in a segmentation data context.
    """

    class_id: int  # identifier chosen by the caller when defining the class
    name: str  # human-readable label, e.g. "person"

    def __str__(self) -> str:
        return "SegmentClass(%s, '%s')" % (self.class_id, self.name)
46
+
47
+
48
class IKeyPointsSchema(ABC):
    """
    Contract for declaring the set of keypoints a module emits.

    Points are declared once during initialization; the returned KeyPoint
    handles are then passed to the data context while processing frames.
    """

    @abstractmethod
    def define_point(self, name: str) -> KeyPoint:
        """
        Declare a keypoint and return its handle.

        Args:
            name: Human-readable label, e.g. "nose" or "left_eye".

        Returns:
            KeyPoint handle to use with IKeyPointsDataContext.add().
        """

    @property
    @abstractmethod
    def defined_points(self) -> List[KeyPoint]:
        """All keypoints declared so far."""

    @abstractmethod
    def get_metadata_json(self) -> str:
        """Serialize the schema as a JSON metadata string."""
79
+
80
+
81
class ISegmentationSchema(ABC):
    """
    Contract for declaring the segmentation classes a module emits.

    Classes are declared once during initialization; the returned
    SegmentClass handles are then passed to the data context per frame.
    """

    @abstractmethod
    def define_class(self, class_id: int, name: str) -> SegmentClass:
        """
        Declare a segmentation class and return its handle.

        Args:
            class_id: Unique class identifier in the range 0-255.
            name: Human-readable label, e.g. "person" or "car".

        Returns:
            SegmentClass handle to use with ISegmentationDataContext.add().
        """

    @property
    @abstractmethod
    def defined_classes(self) -> List[SegmentClass]:
        """All classes declared so far."""

    @abstractmethod
    def get_metadata_json(self) -> str:
        """Serialize the schema as a JSON metadata string."""
113
+
114
+
115
class KeyPointsSchema(IKeyPointsSchema):
    """Keypoints schema that hands out sequential IDs starting at 0."""

    def __init__(self) -> None:
        # Keyed by name so duplicates can be rejected; dict order is insertion order.
        self._points: Dict[str, KeyPoint] = {}
        self._next_id = 0

    def define_point(self, name: str) -> KeyPoint:
        """
        Declare a keypoint under the next free ID.

        Raises:
            ValueError: If a keypoint with the same name already exists.
        """
        if name in self._points:
            raise ValueError(f"Keypoint '{name}' already defined")

        new_point = KeyPoint(id=self._next_id, name=name)
        self._next_id += 1
        self._points[name] = new_point
        return new_point

    @property
    def defined_points(self) -> List[KeyPoint]:
        """Keypoints in the order they were declared (a fresh list)."""
        return [*self._points.values()]

    def get_metadata_json(self) -> str:
        """Return schema metadata as indented JSON mapping names to IDs."""
        metadata = {
            "version": "1.0",
            "compute_module_name": "",
            "points": {kp.name: kp.id for kp in self._points.values()},
        }
        return json.dumps(metadata, indent=2)
147
+
148
+
149
class SegmentationSchema(ISegmentationSchema):
    """Segmentation schema whose class IDs are chosen by the caller."""

    def __init__(self) -> None:
        # Keyed by class ID; dict order is declaration order.
        self._classes: Dict[int, SegmentClass] = {}

    def define_class(self, class_id: int, name: str) -> SegmentClass:
        """
        Declare a segmentation class with an explicit ID.

        Raises:
            ValueError: If ``class_id`` is outside 0-255 or already declared.
        """
        if class_id < 0 or 255 < class_id:
            raise ValueError(f"class_id must be 0-255, got {class_id}")
        if class_id in self._classes:
            raise ValueError(f"Class ID {class_id} already defined")

        declared = SegmentClass(class_id=class_id, name=name)
        self._classes[class_id] = declared
        return declared

    @property
    def defined_classes(self) -> List[SegmentClass]:
        """Classes in the order they were declared (a fresh list)."""
        return [*self._classes.values()]

    def get_metadata_json(self) -> str:
        """Return schema metadata as indented JSON mapping stringified IDs to names."""
        classes = {str(cid): sc.name for cid, sc in self._classes.items()}
        return json.dumps({"version": "1.0", "classes": classes}, indent=2)