micro-sam 1.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75) hide show
  1. micro_sam/__init__.py +19 -0
  2. micro_sam/__version__.py +1 -0
  3. micro_sam/_model_settings.py +142 -0
  4. micro_sam/_test_util.py +22 -0
  5. micro_sam/_vendored.py +152 -0
  6. micro_sam/automatic_segmentation.py +551 -0
  7. micro_sam/bioimageio/__init__.py +1 -0
  8. micro_sam/bioimageio/bioengine_export.py +269 -0
  9. micro_sam/bioimageio/model_export.py +527 -0
  10. micro_sam/bioimageio/predictor_adaptor.py +129 -0
  11. micro_sam/evaluation/__init__.py +23 -0
  12. micro_sam/evaluation/benchmark_datasets.py +896 -0
  13. micro_sam/evaluation/evaluation.py +256 -0
  14. micro_sam/evaluation/experiments.py +82 -0
  15. micro_sam/evaluation/inference.py +767 -0
  16. micro_sam/evaluation/instance_segmentation.py +514 -0
  17. micro_sam/evaluation/livecell.py +479 -0
  18. micro_sam/evaluation/model_comparison.py +503 -0
  19. micro_sam/evaluation/multi_dimensional_segmentation.py +376 -0
  20. micro_sam/inference.py +538 -0
  21. micro_sam/instance_segmentation.py +1670 -0
  22. micro_sam/models/__init__.py +2 -0
  23. micro_sam/models/build_sam.py +142 -0
  24. micro_sam/models/peft_sam.py +493 -0
  25. micro_sam/models/sam_3d_wrapper.py +250 -0
  26. micro_sam/models/simple_sam_3d_wrapper.py +176 -0
  27. micro_sam/multi_dimensional_segmentation.py +764 -0
  28. micro_sam/napari.yaml +94 -0
  29. micro_sam/object_classification.py +261 -0
  30. micro_sam/precompute_state.py +340 -0
  31. micro_sam/prompt_based_segmentation.py +506 -0
  32. micro_sam/prompt_generators.py +377 -0
  33. micro_sam/sam_annotator/__init__.py +8 -0
  34. micro_sam/sam_annotator/_annotator.py +225 -0
  35. micro_sam/sam_annotator/_state.py +264 -0
  36. micro_sam/sam_annotator/_tooltips.py +94 -0
  37. micro_sam/sam_annotator/_widgets.py +2100 -0
  38. micro_sam/sam_annotator/annotator_2d.py +144 -0
  39. micro_sam/sam_annotator/annotator_3d.py +155 -0
  40. micro_sam/sam_annotator/annotator_tracking.py +394 -0
  41. micro_sam/sam_annotator/image_series_annotator.py +588 -0
  42. micro_sam/sam_annotator/object_classifier.py +524 -0
  43. micro_sam/sam_annotator/training_ui.py +262 -0
  44. micro_sam/sam_annotator/util.py +796 -0
  45. micro_sam/sample_data.py +410 -0
  46. micro_sam/training/__init__.py +12 -0
  47. micro_sam/training/joint_sam_trainer.py +223 -0
  48. micro_sam/training/sam_trainer.py +520 -0
  49. micro_sam/training/semantic_sam_trainer.py +180 -0
  50. micro_sam/training/simple_sam_trainer.py +70 -0
  51. micro_sam/training/trainable_sam.py +114 -0
  52. micro_sam/training/training.py +1204 -0
  53. micro_sam/training/util.py +355 -0
  54. micro_sam/util.py +1957 -0
  55. micro_sam/visualization.py +176 -0
  56. micro_sam-1.8.0.dist-info/METADATA +123 -0
  57. micro_sam-1.8.0.dist-info/RECORD +75 -0
  58. micro_sam-1.8.0.dist-info/WHEEL +5 -0
  59. micro_sam-1.8.0.dist-info/entry_points.txt +14 -0
  60. micro_sam-1.8.0.dist-info/licenses/LICENSE +21 -0
  61. micro_sam-1.8.0.dist-info/top_level.txt +2 -0
  62. test/__init__.py +0 -0
  63. test/test_automatic_segmentation.py +207 -0
  64. test/test_cli.py +132 -0
  65. test/test_instance_segmentation.py +153 -0
  66. test/test_models/__init__.py +0 -0
  67. test/test_models/test_peft_sam.py +71 -0
  68. test/test_models/test_sam_3d_wrapper.py +27 -0
  69. test/test_models/test_simple_sam_3d_wrapper.py +29 -0
  70. test/test_multi_dimensional_segmentation.py +99 -0
  71. test/test_prompt_based_segmentation.py +295 -0
  72. test/test_prompt_generators.py +185 -0
  73. test/test_training.py +269 -0
  74. test/test_util.py +280 -0
  75. test/test_vendored.py +82 -0
micro_sam/__init__.py ADDED
@@ -0,0 +1,19 @@
"""
.. include:: ../doc/start_page.md
.. include:: ../doc/installation.md
.. include:: ../doc/annotation_tools.md
.. include:: ../doc/cli_tools.md
.. include:: ../doc/python_library.md
.. include:: ../doc/finetuned_models.md
.. include:: ../doc/apg.md
.. include:: ../doc/data_submission.md
.. include:: ../doc/faq.md
.. include:: ../doc/contributing.md
.. include:: ../doc/band.md
"""

import os

# Re-export the package version as `micro_sam.__version__`.
from .__version__ import __version__

# Ask PyTorch to fall back to the CPU for operations the Apple MPS backend does not implement.
# This overwrites any value already present in the environment whenever micro_sam is imported.
os.environ.update(PYTORCH_ENABLE_MPS_FALLBACK="1")
@@ -0,0 +1 @@
1
# Version string of the micro_sam package; micro_sam/__init__.py re-exports it as `micro_sam.__version__`.
# NOTE(review): kept as a plain string-literal assignment in case build/release tooling parses it textually.
__version__ = "1.8.0"
@@ -0,0 +1,142 @@
1
# Per-model default settings for the annotator widgets.
# Each table maps a model name to its own (independent) dict of widget settings.

# The settings for the instance segmentation widget with ais.
AIS_SETTINGS = {
    "vit_t_lm": dict(
        center_distance_thresh=0.5, boundary_distance_thresh=0.5, distance_smoothing=2.0, min_size=100
    ),
    "vit_b_lm": dict(
        center_distance_thresh=0.4, boundary_distance_thresh=0.5, distance_smoothing=2.0, min_size=100
    ),
    "vit_l_lm": dict(
        center_distance_thresh=0.4, boundary_distance_thresh=0.4, distance_smoothing=1.6, min_size=100
    ),
    "vit_h_lm": dict(
        center_distance_thresh=0.5, boundary_distance_thresh=0.5, distance_smoothing=1.4, min_size=100
    ),
    "vit_t_em_organelles": dict(
        center_distance_thresh=0.4, boundary_distance_thresh=0.5, distance_smoothing=1.2, min_size=100
    ),
    "vit_b_em_organelles": dict(
        center_distance_thresh=0.3, boundary_distance_thresh=0.4, distance_smoothing=1.2, min_size=100
    ),
    "vit_l_em_organelles": dict(
        center_distance_thresh=0.3, boundary_distance_thresh=0.4, distance_smoothing=1.2, min_size=100
    ),
    "vit_h_em_organelles": dict(
        center_distance_thresh=0.3, boundary_distance_thresh=0.4, distance_smoothing=1.2, min_size=100
    ),
}

# The settings for the instance segmentation widget with amg.
AMG_SETTINGS = {
    "vit_t_lm": dict(pred_iou_thresh=0.6, stability_score_thresh=0.65, min_object_size=100),
    "vit_b_lm": dict(pred_iou_thresh=0.65, stability_score_thresh=0.7, min_object_size=100),
    "vit_l_lm": dict(pred_iou_thresh=0.65, stability_score_thresh=0.73, min_object_size=100),
    "vit_h_lm": dict(pred_iou_thresh=0.65, stability_score_thresh=0.7, min_object_size=100),
    "vit_t_em_organelles": dict(pred_iou_thresh=0.75, stability_score_thresh=0.75, min_object_size=100),
    "vit_b_em_organelles": dict(pred_iou_thresh=0.75, stability_score_thresh=0.75, min_object_size=100),
    "vit_l_em_organelles": dict(pred_iou_thresh=0.8, stability_score_thresh=0.8, min_object_size=100),
    "vit_h_em_organelles": dict(pred_iou_thresh=0.8, stability_score_thresh=0.8, min_object_size=100),
}

# The settings for the nd segment widget.
ND_SEGMENT_SETTINGS = {
    "vit_t_lm": dict(projection_mode="box", iou_threshold=0.8, box_extension=0.025),
    "vit_b_lm": dict(projection_mode="box", iou_threshold=0.8, box_extension=0.025),
    "vit_l_lm": dict(projection_mode="box", iou_threshold=0.8, box_extension=0.025),
    # NOTE(review): 0.0025 is ten times smaller than the box_extension used for every other model;
    # confirm this value is intended and not a typo for 0.025.
    "vit_h_lm": dict(projection_mode="box", iou_threshold=0.8, box_extension=0.0025),
    # All EM organelle models share identical settings; each entry still gets its own dict object.
    **{
        f"vit_{size}_em_organelles": dict(projection_mode="single_point", iou_threshold=0.6, box_extension=0.025)
        for size in ("t", "b", "l", "h")
    },
}
@@ -0,0 +1,22 @@
1
+ import numpy as np
2
+
3
+
4
def check_layer_initialization(viewer, expected_shape):
    """Assert that a freshly opened annotator viewer holds exactly the expected, empty layers.

    Args:
        viewer: Object whose ``layers`` support ``len``, ``in`` and lookup by layer name.
        expected_shape: Shape that the data of every segmentation layer must have.

    Raises:
        AssertionError: If the layer count, a layer name, a prompt layer or a segmentation layer
            does not match the expected initial state.
    """
    layers = viewer.layers
    assert len(layers) == 6

    required_names = (
        "image", "auto_segmentation", "committed_objects", "current_object", "point_prompts", "prompts",
    )
    for name in required_names:
        assert name in layers

    # Check prompt layers; the shapes layer keeps its data as a list, not a numpy array.
    assert layers["prompts"].data == []
    # NOTE(review): for an empty points array this comparison passes trivially; it only fails
    # when some stored coordinate is non-zero.
    np.testing.assert_equal(layers["point_prompts"].data, 0)

    # Check segmentation layers: each must have the expected shape and contain only zeros.
    for name in ("auto_segmentation", "committed_objects", "current_object"):
        seg_data = layers[name].data
        assert seg_data.shape == expected_shape
        np.testing.assert_equal(seg_data, 0)
micro_sam/_vendored.py ADDED
@@ -0,0 +1,152 @@
1
+ """Functions from other third party libraries.
2
+
3
+ We can remove these functions once the bugs affecting our code are fixed upstream.
4
+
5
+ The license of the third-party software project must be compatible with
6
+ the software license the micro-sam project is distributed under.
7
+ """
8
+
9
+ from typing import Any, Dict, List, Literal
10
+
11
+ import numpy as np
12
+
13
+ import torch
14
+
15
try:
    from numba import njit
except (ImportError, SystemError):
    # numba is optional: without it, `njit` becomes a no-op decorator so the decorated
    # functions still run, just as plain (uncompiled) Python.
    HAVE_NUMBA = False

    def njit(func):
        """Return ``func`` unchanged; stand-in for ``numba.njit`` when numba cannot be imported."""
        return func
else:
    HAVE_NUMBA = True
23
+
24
+ from bioimage_cpp.utils import compute_rle as _compute_rle_cpp
25
+
26
+
27
def _compute_rle_bioimage_cpp(mask):
    """Run-length encode a flattened mask with the bioimage-cpp ``compute_rle`` routine.

    bioimage-cpp hands back an int64 numpy array. It is converted to a plain Python list so
    the result has the same form as the numpy/numba encoders and the COCO/SAM RLE format.
    """
    encoded = _compute_rle_cpp(mask)
    return encoded.tolist()
31
+
32
+
33
def _first_last_occupied(occupied: torch.Tensor, length: int):
    """Return the (first, last) occupied index along the last axis of a boolean C x L tensor.

    Rows without any occupied entry yield first == length and last == 0. The caller uses this
    to recognise empty masks.
    """
    coords = occupied * torch.arange(length, dtype=torch.int, device=occupied.device)[None, :]
    last, _ = torch.max(coords, dim=-1)
    # Push unoccupied positions past the end so they can never be the minimum.
    shifted = (coords + length * (~occupied)).type(torch.int)
    first, _ = torch.min(shifted, dim=-1)
    return first, last


def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
    """Compute XYXY bounding boxes around boolean masks; an empty mask yields [0, 0, 0, 0].

    Input of shape C1xC2x...xHxW produces output of shape C1xC2x...x4.

    Adapted from https://github.com/facebookresearch/segment-anything/segment_anything/util/amg.py
    so that it also works for tensors on an MPS device. Boolean input is required, because other
    dtypes give wrong results; see https://github.com/facebookresearch/segment-anything/issues/552.
    """
    assert masks.dtype == torch.bool, masks.dtype

    # torch.max raises on empty inputs, so answer those directly.
    # NOTE(review): torch.zeros uses the default (floating point) dtype here, whereas the
    # non-empty path below returns integer boxes.
    if torch.numel(masks) == 0:
        return torch.zeros(*masks.shape[:-2], 4, device=masks.device)

    original_shape = masks.shape
    h, w = original_shape[-2:]
    has_leading_dims = len(original_shape) > 2
    # Work on a C x H x W tensor regardless of how many leading dimensions the input has.
    flat = masks.flatten(0, -3) if has_leading_dims else masks.unsqueeze(0)

    rows_occupied, _ = torch.max(flat, dim=-1)  # C x H
    cols_occupied, _ = torch.max(flat, dim=-2)  # C x W
    top, bottom = _first_last_occupied(rows_occupied, h)
    left, right = _first_last_occupied(cols_occupied, w)

    boxes = torch.stack([left, top, right, bottom], dim=-1)
    # An empty mask leaves the right edge left of the left edge (and bottom above top): zero it.
    is_empty = (right < left) | (bottom < top)
    boxes = boxes * (~is_empty).unsqueeze(-1)

    return boxes.reshape(*original_shape[:-2], 4) if has_leading_dims else boxes[0]
86
+
87
+
88
@njit
def _compute_rle_numba(mask):
    """Run-length encode a flattened mask (numba-compiled when numba is available).

    Consecutive equal values form a run, and the returned list holds the run lengths in order.
    If the first value is non-zero, a leading count of 0 is emitted, so the first count always
    refers to a (possibly empty) run of zeros.

    Args:
        mask: 1D sequence of mask values. When called from mask_to_rle_pytorch, this is one mask
            flattened in Fortran (column-major) order.

    Returns:
        List of run lengths.
    """
    val = mask[0]
    # Start empty when the mask begins with 0, otherwise with a zero-length run of zeros.
    # The comprehension over an empty range appears to be there so numba can infer an
    # integer element type for the otherwise empty list.
    counts = [int(x) for x in range(0)] if val == 0 else [0]
    count = 0
    for m in mask:
        if val == m:
            count += 1
        else:
            # Value changed: close the current run and start a new one of length 1.
            val = m
            counts.append(count)
            count = 1
    # Close the final run.
    counts.append(count)
    return counts
102
+
103
+
104
+ def _compute_rle_numpy(mask):
105
+ diffs = mask[1:] != mask[:-1] # pairwise unequal (string safe)
106
+ indices = np.append(np.where(diffs), len(mask) - 1) # must include last element position
107
+ # count needs to start with 0 if the mask begins with 1
108
+ counts = [] if mask[0] == 0 else [0]
109
+ # compute the actual RLE
110
+ counts += np.diff(np.append(-1, indices)).tolist()
111
+ return counts
112
+
113
+
114
def mask_to_rle_pytorch(
    tensor: torch.Tensor, rle_implementation: Literal["default", "numpy", "numba", "bioimage_cpp"] = "default"
) -> List[Dict[str, Any]]:
    """Calculates the runlength encoding of binary input masks.

    This replaces the function in
    https://github.com/facebookresearch/segment-anything/segment_anything/util/amg.py
    with a version that computes the RLE purely on the CPU.
    This does not lead to any performance deficits when running on the GPU, but it speeds the computation
    up significantly compared to running this on an MPS device.
    The RLE implementation is based on
    https://stackoverflow.com/questions/1066758/find-length-of-sequences-of-identical-values-in-a-numpy-array-run-length-encodi

    Args:
        tensor: The masks, with shape B x H x W.
        rle_implementation: Which run-length encoder to use. 'default' selects 'bioimage_cpp';
            'numba' requires numba to be importable.

    Returns:
        One dict per mask, with 'size' set to [h, w] and 'counts' holding the run lengths of
        the mask flattened in Fortran (column-major) order.

    Raises:
        ValueError: If `rle_implementation` is not one of the supported values.
    """
    # Resolve the encoder first so an invalid choice fails before any device-to-CPU copy.
    if rle_implementation == "default":
        rle_implementation = "bioimage_cpp"

    if rle_implementation == "numba":
        assert HAVE_NUMBA, "rle_implementation='numba' requires numba, which could not be imported."
        rle_impl = _compute_rle_numba
    elif rle_implementation == "numpy":
        rle_impl = _compute_rle_numpy
    elif rle_implementation == "bioimage_cpp":
        rle_impl = _compute_rle_bioimage_cpp
    else:
        raise ValueError(
            f"RLE implementation {rle_implementation} is not available. "
            "Has to be one of 'default', 'numpy', 'numba' or 'bioimage_cpp'."
        )

    # Put in fortran order and flatten h, w; the encoders then receive numpy arrays on the CPU.
    _, h, w = tensor.shape
    flat_masks = tensor.permute(0, 2, 1).flatten(1).detach().cpu().numpy()

    return [{"size": [h, w], "counts": rle_impl(mask)} for mask in flat_masks]