dora-sam2 0.0.0__tar.gz → 0.3.10rc1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in those registries.
@@ -1,12 +1,12 @@
  Metadata-Version: 2.2
  Name: dora-sam2
- Version: 0.0.0
+ Version: 0.3.10rc1
  Summary: dora-sam2
  Author-email: Your Name <email@email.com>
  License: MIT
  Requires-Python: >=3.10
  Description-Content-Type: text/markdown
- Requires-Dist: dora-rs>=0.3.6
+ Requires-Dist: dora-rs>=0.3.10rc1
  Requires-Dist: huggingface-hub>=0.29.0
  Requires-Dist: opencv-python>=4.11.0.86
  Requires-Dist: sam2>=1.1.0
@@ -0,0 +1,208 @@
+ import cv2
+ import numpy as np
+ import pyarrow as pa
+ import torch
+ from dora import Node
+ from PIL import Image
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
+
+ predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
+
+
+ def main():
+     pa.array([])  # initialize pyarrow array
+     node = Node()
+     frames = {}
+     last_pred = None
+     labels = None
+     return_type = pa.Array
+     image_id = None
+     for event in node:
+         event_type = event["type"]
+
+         if event_type == "INPUT":
+             event_id = event["id"]
+
+             if "image" in event_id:
+                 storage = event["value"]
+                 metadata = event["metadata"]
+                 encoding = metadata["encoding"]
+                 width = metadata["width"]
+                 height = metadata["height"]
+
+                 if (
+                     encoding == "bgr8"
+                     or encoding == "rgb8"
+                     or encoding in ["jpeg", "jpg", "jpe", "bmp", "webp", "png"]
+                 ):
+                     channels = 3
+                     storage_type = np.uint8
+                 else:
+                     error = f"Unsupported image encoding: {encoding}"
+                     raise RuntimeError(error)
+
+                 if encoding == "bgr8":
+                     frame = (
+                         storage.to_numpy()
+                         .astype(storage_type)
+                         .reshape((height, width, channels))
+                     )
+                     frame = frame[:, :, ::-1]  # OpenCV image (BGR to RGB)
+                 elif encoding == "rgb8":
+                     frame = (
+                         storage.to_numpy()
+                         .astype(storage_type)
+                         .reshape((height, width, channels))
+                     )
+                 elif encoding in ["jpeg", "jpg", "jpe", "bmp", "webp", "png"]:
+                     storage = storage.to_numpy()
+                     frame = cv2.imdecode(storage, cv2.IMREAD_COLOR)
+                     frame = frame[:, :, ::-1]  # OpenCV image (BGR to RGB)
+                 else:
+                     raise RuntimeError(f"Unsupported image encoding: {encoding}")
+                 image = Image.fromarray(frame)
+                 frames[event_id] = image
+
+                 # TODO: Fix the tracking code for SAM2.
+                 continue
+                 if last_pred is not None:
+                     with (
+                         torch.inference_mode(),
+                         torch.autocast(
+                             "cuda",
+                             dtype=torch.bfloat16,
+                         ),
+                     ):
+                         predictor.set_image(frames[image_id])
+
+                         new_logits = []
+                         new_masks = []
+
+                         if len(last_pred.shape) < 3:
+                             last_pred = np.expand_dims(last_pred, 0)
+
+                         for mask in last_pred:
+                             mask = np.expand_dims(mask, 0)  # Make shape: 1x256x256
+                             masks, _, new_logit = predictor.predict(
+                                 mask_input=mask,
+                                 multimask_output=False,
+                             )
+                             if len(masks.shape) == 4:
+                                 masks = masks[:, 0, :, :]
+                             else:
+                                 masks = masks[0, :, :]
+
+                             masks = masks > 0
+                             new_masks.append(masks)
+                             new_logits.append(new_logit)
+                         ## Mask to 3 channel image
+
+                         last_pred = np.concatenate(new_logits, axis=0)
+                         masks = np.concatenate(new_masks, axis=0)
+
+                         match return_type:
+                             case pa.Array:
+                                 node.send_output(
+                                     "masks",
+                                     pa.array(masks.ravel()),
+                                     metadata={
+                                         "image_id": image_id,
+                                         "width": frames[image_id].width,
+                                         "height": frames[image_id].height,
+                                     },
+                                 )
+                             case pa.StructArray:
+                                 node.send_output(
+                                     "masks",
+                                     pa.array(
+                                         [
+                                             {
+                                                 "masks": masks.ravel(),
+                                                 "labels": event["value"]["labels"],
+                                             }
+                                         ]
+                                     ),
+                                     metadata={
+                                         "image_id": image_id,
+                                         "width": frames[image_id].width,
+                                         "height": frames[image_id].height,
+                                     },
+                                 )
+
+             elif "boxes2d" in event_id:
+
+                 if isinstance(event["value"], pa.StructArray):
+                     boxes2d = event["value"][0].get("bbox").values.to_numpy()
+                     labels = (
+                         event["value"][0]
+                         .get("labels")
+                         .values.to_numpy(zero_copy_only=False)
+                     )
+                     return_type = pa.Array
+                 else:
+                     boxes2d = event["value"].to_numpy()
+                     labels = None
+                     return_type = pa.Array
+
+                 metadata = event["metadata"]
+                 encoding = metadata["encoding"]
+                 if encoding != "xyxy":
+                     raise RuntimeError(f"Unsupported boxes2d encoding: {encoding}")
+                 boxes2d = boxes2d.reshape(-1, 4)
+                 image_id = metadata["image_id"]
+                 with (
+                     torch.inference_mode(),
+                     torch.autocast(
+                         "cuda",
+                         dtype=torch.bfloat16,
+                     ),
+                 ):
+                     predictor.set_image(frames[image_id])
+                     masks, _scores, last_pred = predictor.predict(
+                         box=boxes2d, point_labels=labels, multimask_output=False
+                     )
+
+                     if len(masks.shape) == 4:
+                         masks = masks[:, 0, :, :]
+                         last_pred = last_pred[:, 0, :, :]
+                     else:
+                         masks = masks[0, :, :]
+                         last_pred = last_pred[0, :, :]
+
+                     masks = masks > 0
+                     ## Mask to 3 channel image
+                     match return_type:
+                         case pa.Array:
+                             node.send_output(
+                                 "masks",
+                                 pa.array(masks.ravel()),
+                                 metadata={
+                                     "image_id": image_id,
+                                     "width": frames[image_id].width,
+                                     "height": frames[image_id].height,
+                                 },
+                             )
+                         case pa.StructArray:
+                             node.send_output(
+                                 "masks",
+                                 pa.array(
+                                     [
+                                         {
+                                             "masks": masks.ravel(),
+                                             "labels": event["value"]["labels"],
+                                         }
+                                     ]
+                                 ),
+                                 metadata={
+                                     "image_id": image_id,
+                                     "width": frames[image_id].width,
+                                     "height": frames[image_id].height,
+                                 },
+                             )
+
+         elif event_type == "ERROR":
+             print("Event Error:" + event["error"])
+
+
+ if __name__ == "__main__":
+     main()
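
The added node ravel()s each batch of boolean masks into a single flat array before sending and puts the image dimensions in the output metadata, so a downstream consumer has to undo the flattening itself. A minimal consumer-side sketch, not part of this package (masks_from_event is a hypothetical helper, assuming a dora event carrying the "masks" output shown above):

    def masks_from_event(event):
        # dora-sam2 sends one or more HxW boolean masks ravel()ed into a
        # single flat pyarrow array; width/height travel in the metadata.
        meta = event["metadata"]
        flat = event["value"].to_numpy(zero_copy_only=False)
        return flat.reshape((-1, meta["height"], meta["width"]))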
@@ -1,12 +1,12 @@
  Metadata-Version: 2.2
  Name: dora-sam2
- Version: 0.0.0
+ Version: 0.3.10rc1
  Summary: dora-sam2
  Author-email: Your Name <email@email.com>
  License: MIT
  Requires-Python: >=3.10
  Description-Content-Type: text/markdown
- Requires-Dist: dora-rs>=0.3.6
+ Requires-Dist: dora-rs>=0.3.10rc1
  Requires-Dist: huggingface-hub>=0.29.0
  Requires-Dist: opencv-python>=4.11.0.86
  Requires-Dist: sam2>=1.1.0
@@ -1,4 +1,4 @@
- dora-rs>=0.3.6
+ dora-rs>=0.3.10rc1
  huggingface-hub>=0.29.0
  opencv-python>=4.11.0.86
  sam2>=1.1.0
@@ -1,6 +1,6 @@
  [project]
  name = "dora-sam2"
- version = "0.0.0"
+ version = "0.3.10-rc1"
  authors = [{ name = "Your Name", email = "email@email.com" }]
  description = "dora-sam2"
  license = { text = "MIT" }
@@ -8,7 +8,7 @@ readme = "README.md"
  requires-python = ">=3.10"

  dependencies = [
-     "dora-rs >= 0.3.6",
+     "dora-rs >= 0.3.10rc1",
      "huggingface-hub>=0.29.0",
      "opencv-python>=4.11.0.86",
      "sam2>=1.1.0",
@@ -1,94 +0,0 @@
- import cv2
- import numpy as np
- import pyarrow as pa
- import torch
- from dora import Node
- from PIL import Image
- from sam2.sam2_image_predictor import SAM2ImagePredictor
-
- predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
-
-
- def main():
-     pa.array([])  # initialize pyarrow array
-     node = Node()
-     frames = {}
-     for event in node:
-         event_type = event["type"]
-
-         if event_type == "INPUT":
-             event_id = event["id"]
-
-             if "image" in event_id:
-                 storage = event["value"]
-                 metadata = event["metadata"]
-                 encoding = metadata["encoding"]
-                 width = metadata["width"]
-                 height = metadata["height"]
-
-                 if (
-                     encoding == "bgr8"
-                     or encoding == "rgb8"
-                     or encoding in ["jpeg", "jpg", "jpe", "bmp", "webp", "png"]
-                 ):
-                     channels = 3
-                     storage_type = np.uint8
-                 else:
-                     error = f"Unsupported image encoding: {encoding}"
-                     raise RuntimeError(error)
-
-                 if encoding == "bgr8":
-                     frame = (
-                         storage.to_numpy()
-                         .astype(storage_type)
-                         .reshape((height, width, channels))
-                     )
-                     frame = frame[:, :, ::-1]  # OpenCV image (BGR to RGB)
-                 elif encoding == "rgb8":
-                     frame = (
-                         storage.to_numpy()
-                         .astype(storage_type)
-                         .reshape((height, width, channels))
-                     )
-                 elif encoding in ["jpeg", "jpg", "jpe", "bmp", "webp", "png"]:
-                     storage = storage.to_numpy()
-                     frame = cv2.imdecode(storage, cv2.IMREAD_COLOR)
-                     frame = frame[:, :, ::-1]  # OpenCV image (BGR to RGB)
-                 else:
-                     raise RuntimeError(f"Unsupported image encoding: {encoding}")
-                 image = Image.fromarray(frame)
-                 frames[event_id] = image
-
-             elif "boxes2d" in event_id:
-                 boxes2d = event["value"].to_numpy()
-                 metadata = event["metadata"]
-                 encoding = metadata["encoding"]
-                 if encoding != "xyxy":
-                     raise RuntimeError(f"Unsupported boxes2d encoding: {encoding}")
-
-                 image_id = metadata["image_id"]
-                 with torch.inference_mode(), torch.autocast(
-                     "cuda",
-                     dtype=torch.bfloat16,
-                 ):
-                     predictor.set_image(frames[image_id])
-                     masks, _, _ = predictor.predict(box=boxes2d)
-                     masks = masks[0]
-                     ## Mask to 3 channel image
-
-                     node.send_output(
-                         "masks",
-                         pa.array(masks.ravel()),
-                         metadata={
-                             "image_id": image_id,
-                             "width": frames[image_id].width,
-                             "height": frames[image_id].height,
-                         },
-                     )
-
-         elif event_type == "ERROR":
-             print("Event Error:" + event["error"])
-
-
- if __name__ == "__main__":
-     main()
File without changes
File without changes