dora-sam2 0.0.0__py3-none-any.whl → 0.3.10__py3-none-any.whl

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
dora_sam2/main.py CHANGED
@@ -13,6 +13,10 @@ def main():
13
13
  pa.array([]) # initialize pyarrow array
14
14
  node = Node()
15
15
  frames = {}
16
+ last_pred = None
17
+ labels = None
18
+ return_type = pa.Array
19
+ image_id = None
16
20
  for event in node:
17
21
  event_type = event["type"]
18
22
 
@@ -59,33 +63,143 @@ def main():
59
63
  image = Image.fromarray(frame)
60
64
  frames[event_id] = image
61
65
 
66
+ # TODO: Fix the tracking code for SAM2.
67
+ continue
68
+ if last_pred is not None:
69
+ with (
70
+ torch.inference_mode(),
71
+ torch.autocast(
72
+ "cuda",
73
+ dtype=torch.bfloat16,
74
+ ),
75
+ ):
76
+ predictor.set_image(frames[image_id])
77
+
78
+ new_logits = []
79
+ new_masks = []
80
+
81
+ if len(last_pred.shape) < 3:
82
+ last_pred = np.expand_dims(last_pred, 0)
83
+
84
+ for mask in last_pred:
85
+ mask = np.expand_dims(mask, 0) # Make shape: 1x256x256
86
+ masks, _, new_logit = predictor.predict(
87
+ mask_input=mask,
88
+ multimask_output=False,
89
+ )
90
+ if len(masks.shape) == 4:
91
+ masks = masks[:, 0, :, :]
92
+ else:
93
+ masks = masks[0, :, :]
94
+
95
+ masks = masks > 0
96
+ new_masks.append(masks)
97
+ new_logits.append(new_logit)
98
+ ## Mask to 3 channel image
99
+
100
+ last_pred = np.concatenate(new_logits, axis=0)
101
+ masks = np.concatenate(new_masks, axis=0)
102
+
103
+ match return_type:
104
+ case pa.Array:
105
+ node.send_output(
106
+ "masks",
107
+ pa.array(masks.ravel()),
108
+ metadata={
109
+ "image_id": image_id,
110
+ "width": frames[image_id].width,
111
+ "height": frames[image_id].height,
112
+ },
113
+ )
114
+ case pa.StructArray:
115
+ node.send_output(
116
+ "masks",
117
+ pa.array(
118
+ [
119
+ {
120
+ "masks": masks.ravel(),
121
+ "labels": event["value"]["labels"],
122
+ }
123
+ ]
124
+ ),
125
+ metadata={
126
+ "image_id": image_id,
127
+ "width": frames[image_id].width,
128
+ "height": frames[image_id].height,
129
+ },
130
+ )
131
+
62
132
  elif "boxes2d" in event_id:
63
- boxes2d = event["value"].to_numpy()
133
+
134
+ if isinstance(event["value"], pa.StructArray):
135
+ boxes2d = event["value"][0].get("bbox").values.to_numpy()
136
+ labels = (
137
+ event["value"][0]
138
+ .get("labels")
139
+ .values.to_numpy(zero_copy_only=False)
140
+ )
141
+ return_type = pa.Array
142
+ else:
143
+ boxes2d = event["value"].to_numpy()
144
+ labels = None
145
+ return_type = pa.Array
146
+
64
147
  metadata = event["metadata"]
65
148
  encoding = metadata["encoding"]
66
149
  if encoding != "xyxy":
67
150
  raise RuntimeError(f"Unsupported boxes2d encoding: {encoding}")
68
-
151
+ boxes2d = boxes2d.reshape(-1, 4)
69
152
  image_id = metadata["image_id"]
70
- with torch.inference_mode(), torch.autocast(
71
- "cuda",
72
- dtype=torch.bfloat16,
153
+ with (
154
+ torch.inference_mode(),
155
+ torch.autocast(
156
+ "cuda",
157
+ dtype=torch.bfloat16,
158
+ ),
73
159
  ):
74
160
  predictor.set_image(frames[image_id])
75
- masks, _, _ = predictor.predict(box=boxes2d)
76
- masks = masks[0]
77
- ## Mask to 3 channel image
78
-
79
- node.send_output(
80
- "masks",
81
- pa.array(masks.ravel()),
82
- metadata={
83
- "image_id": image_id,
84
- "width": frames[image_id].width,
85
- "height": frames[image_id].height,
86
- },
161
+ masks, _scores, last_pred = predictor.predict(
162
+ box=boxes2d, point_labels=labels, multimask_output=False
87
163
  )
88
164
 
165
+ if len(masks.shape) == 4:
166
+ masks = masks[:, 0, :, :]
167
+ last_pred = last_pred[:, 0, :, :]
168
+ else:
169
+ masks = masks[0, :, :]
170
+ last_pred = last_pred[0, :, :]
171
+
172
+ masks = masks > 0
173
+ ## Mask to 3 channel image
174
+ match return_type:
175
+ case pa.Array:
176
+ node.send_output(
177
+ "masks",
178
+ pa.array(masks.ravel()),
179
+ metadata={
180
+ "image_id": image_id,
181
+ "width": frames[image_id].width,
182
+ "height": frames[image_id].height,
183
+ },
184
+ )
185
+ case pa.StructArray:
186
+ node.send_output(
187
+ "masks",
188
+ pa.array(
189
+ [
190
+ {
191
+ "masks": masks.ravel(),
192
+ "labels": event["value"]["labels"],
193
+ }
194
+ ]
195
+ ),
196
+ metadata={
197
+ "image_id": image_id,
198
+ "width": frames[image_id].width,
199
+ "height": frames[image_id].height,
200
+ },
201
+ )
202
+
89
203
  elif event_type == "ERROR":
90
204
  print("Event Error:" + event["error"])
91
205
 
@@ -1,12 +1,12 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: dora-sam2
3
- Version: 0.0.0
3
+ Version: 0.3.10
4
4
  Summary: dora-sam2
5
5
  Author-email: Your Name <email@email.com>
6
6
  License: MIT
7
7
  Requires-Python: >=3.10
8
8
  Description-Content-Type: text/markdown
9
- Requires-Dist: dora-rs>=0.3.6
9
+ Requires-Dist: dora-rs>=0.3.9
10
10
  Requires-Dist: huggingface-hub>=0.29.0
11
11
  Requires-Dist: opencv-python>=4.11.0.86
12
12
  Requires-Dist: sam2>=1.1.0
@@ -0,0 +1,8 @@
1
+ dora_sam2/__init__.py,sha256=HuSK3dnyI9Pb5QAuaKFwQQ3J5SIZnLcKHPJO0norGzc,353
2
+ dora_sam2/__main__.py,sha256=Vdhw8YA1K3wPMlbJQYL5WqvRzAKVeZ16mZQFO9VRmCo,62
3
+ dora_sam2/main.py,sha256=zwMJwenWZdmpMfriS8C903g7QdO-qOJrn2lP4PCcPyo,8425
4
+ dora_sam2-0.3.10.dist-info/METADATA,sha256=bC15e7QBHnl8nyT-E5tQLoPdI9LAviaLNsUD7ZglJB0,820
5
+ dora_sam2-0.3.10.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
6
+ dora_sam2-0.3.10.dist-info/entry_points.txt,sha256=eObMaDQauVA_sv4J6fsNjn8V-8syGJK7mO-LrsBu1aA,50
7
+ dora_sam2-0.3.10.dist-info/top_level.txt,sha256=IgKcOITGe2Nlyc79J6dwh3dcp3Wsf-IipIy-1h9GcPE,10
8
+ dora_sam2-0.3.10.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.8.0)
2
+ Generator: setuptools (75.8.2)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,8 +0,0 @@
1
- dora_sam2/__init__.py,sha256=HuSK3dnyI9Pb5QAuaKFwQQ3J5SIZnLcKHPJO0norGzc,353
2
- dora_sam2/__main__.py,sha256=Vdhw8YA1K3wPMlbJQYL5WqvRzAKVeZ16mZQFO9VRmCo,62
3
- dora_sam2/main.py,sha256=7z7kZyI25eOSqnbwUEL8bGXOUnoa6UQ_AXThrjHUAB0,3380
4
- dora_sam2-0.0.0.dist-info/METADATA,sha256=If_OUNOaxv2wAK8BIze_gb07Bh9SpeMiMksV3N0a4qY,819
5
- dora_sam2-0.0.0.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
6
- dora_sam2-0.0.0.dist-info/entry_points.txt,sha256=eObMaDQauVA_sv4J6fsNjn8V-8syGJK7mO-LrsBu1aA,50
7
- dora_sam2-0.0.0.dist-info/top_level.txt,sha256=IgKcOITGe2Nlyc79J6dwh3dcp3Wsf-IipIy-1h9GcPE,10
8
- dora_sam2-0.0.0.dist-info/RECORD,,