dora-sam2 0.3.11__py3-none-any.whl → 0.3.12__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- dora_sam2/main.py +56 -2
- {dora_sam2-0.3.11.dist-info → dora_sam2-0.3.12.dist-info}/METADATA +1 -1
- dora_sam2-0.3.12.dist-info/RECORD +8 -0
- {dora_sam2-0.3.11.dist-info → dora_sam2-0.3.12.dist-info}/WHEEL +1 -1
- dora_sam2-0.3.11.dist-info/RECORD +0 -8
- {dora_sam2-0.3.11.dist-info → dora_sam2-0.3.12.dist-info}/entry_points.txt +0 -0
- {dora_sam2-0.3.11.dist-info → dora_sam2-0.3.12.dist-info}/top_level.txt +0 -0
dora_sam2/main.py
CHANGED
@@ -133,7 +133,9 @@ def main():
|
|
133
133
|
)
|
134
134
|
|
135
135
|
if "boxes2d" in event_id:
|
136
|
-
|
136
|
+
if len(event["value"]) == 0:
|
137
|
+
node.send_output("masks", pa.array([]))
|
138
|
+
continue
|
137
139
|
if isinstance(event["value"], pa.StructArray):
|
138
140
|
boxes2d = event["value"][0].get("bbox").values.to_numpy()
|
139
141
|
labels = (
|
@@ -162,7 +164,59 @@ def main():
|
|
162
164
|
):
|
163
165
|
predictor.set_image(frames[image_id])
|
164
166
|
masks, _scores, last_pred = predictor.predict(
|
165
|
-
box=boxes2d,
|
167
|
+
box=boxes2d,
|
168
|
+
point_labels=labels,
|
169
|
+
multimask_output=False,
|
170
|
+
)
|
171
|
+
|
172
|
+
if len(masks.shape) == 4:
|
173
|
+
masks = masks[:, 0, :, :]
|
174
|
+
last_pred = last_pred[:, 0, :, :]
|
175
|
+
else:
|
176
|
+
masks = masks[0, :, :]
|
177
|
+
last_pred = last_pred[0, :, :]
|
178
|
+
|
179
|
+
masks = masks > 0
|
180
|
+
metadata["image_id"] = image_id
|
181
|
+
metadata["width"] = frames[image_id].width
|
182
|
+
metadata["height"] = frames[image_id].height
|
183
|
+
## Mask to 3 channel image
|
184
|
+
match return_type:
|
185
|
+
case pa.Array:
|
186
|
+
node.send_output("masks", pa.array(masks.ravel()), metadata)
|
187
|
+
case pa.StructArray:
|
188
|
+
node.send_output(
|
189
|
+
"masks",
|
190
|
+
pa.array(
|
191
|
+
[
|
192
|
+
{
|
193
|
+
"masks": masks.ravel(),
|
194
|
+
"labels": event["value"]["labels"],
|
195
|
+
},
|
196
|
+
],
|
197
|
+
),
|
198
|
+
metadata,
|
199
|
+
)
|
200
|
+
elif "points" in event_id:
|
201
|
+
points = event["value"].to_numpy().reshape((-1, 2))
|
202
|
+
return_type = pa.Array
|
203
|
+
if len(frames) == 0:
|
204
|
+
continue
|
205
|
+
first_image = next(iter(frames.keys()))
|
206
|
+
image_id = event["metadata"].get("image_id", first_image)
|
207
|
+
with (
|
208
|
+
torch.inference_mode(),
|
209
|
+
torch.autocast(
|
210
|
+
"cuda",
|
211
|
+
dtype=torch.bfloat16,
|
212
|
+
),
|
213
|
+
):
|
214
|
+
predictor.set_image(frames[image_id])
|
215
|
+
labels = [i for i in range(len(points))]
|
216
|
+
masks, _scores, last_pred = predictor.predict(
|
217
|
+
points,
|
218
|
+
point_labels=labels,
|
219
|
+
multimask_output=False,
|
166
220
|
)
|
167
221
|
|
168
222
|
if len(masks.shape) == 4:
|
@@ -0,0 +1,8 @@
|
|
1
|
+
dora_sam2/__init__.py,sha256=tF7WHhHiDweUUzyHsbmFe_ktphE08aA5j33E4ja1udA,381
|
2
|
+
dora_sam2/__main__.py,sha256=NlAb6Jbmmn82K8Ahdi12sliZYdzyY7QaoCCHRuoR_Hg,90
|
3
|
+
dora_sam2/main.py,sha256=Jo6y5vY8LafpdUfb_gHgy7zLJMZz_8sLBw0gL3HeG0M,10873
|
4
|
+
dora_sam2-0.3.12.dist-info/METADATA,sha256=IW0SXlpL6v3CPkV9Rf7rNdeB0RUDq1TmTFwJ-tcOS1w,820
|
5
|
+
dora_sam2-0.3.12.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
6
|
+
dora_sam2-0.3.12.dist-info/entry_points.txt,sha256=eObMaDQauVA_sv4J6fsNjn8V-8syGJK7mO-LrsBu1aA,50
|
7
|
+
dora_sam2-0.3.12.dist-info/top_level.txt,sha256=IgKcOITGe2Nlyc79J6dwh3dcp3Wsf-IipIy-1h9GcPE,10
|
8
|
+
dora_sam2-0.3.12.dist-info/RECORD,,
|
@@ -1,8 +0,0 @@
|
|
1
|
-
dora_sam2/__init__.py,sha256=tF7WHhHiDweUUzyHsbmFe_ktphE08aA5j33E4ja1udA,381
|
2
|
-
dora_sam2/__main__.py,sha256=NlAb6Jbmmn82K8Ahdi12sliZYdzyY7QaoCCHRuoR_Hg,90
|
3
|
-
dora_sam2/main.py,sha256=aQj1zZTtpzebwaayyPqD4B82iFscLKuZ5v7COs3dTNE,8487
|
4
|
-
dora_sam2-0.3.11.dist-info/METADATA,sha256=jTkg2BE_lLdUTroTRkZ6hdsT2xwwtF9Z2D_SYtgG8_Y,820
|
5
|
-
dora_sam2-0.3.11.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
6
|
-
dora_sam2-0.3.11.dist-info/entry_points.txt,sha256=eObMaDQauVA_sv4J6fsNjn8V-8syGJK7mO-LrsBu1aA,50
|
7
|
-
dora_sam2-0.3.11.dist-info/top_level.txt,sha256=IgKcOITGe2Nlyc79J6dwh3dcp3Wsf-IipIy-1h9GcPE,10
|
8
|
-
dora_sam2-0.3.11.dist-info/RECORD,,
|
File without changes
|
File without changes
|