birder 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/adversarial/base.py +1 -1
- birder/adversarial/simba.py +4 -4
- birder/common/cli.py +1 -1
- birder/common/fs_ops.py +11 -11
- birder/common/lib.py +2 -2
- birder/common/masking.py +3 -3
- birder/common/training_cli.py +24 -2
- birder/common/training_utils.py +28 -4
- birder/data/collators/detection.py +9 -1
- birder/data/transforms/detection.py +27 -8
- birder/data/transforms/mosaic.py +1 -1
- birder/datahub/classification.py +3 -3
- birder/inference/classification.py +3 -3
- birder/inference/data_parallel.py +1 -1
- birder/inference/detection.py +5 -5
- birder/inference/wbf.py +1 -1
- birder/introspection/attention_rollout.py +5 -5
- birder/introspection/feature_pca.py +4 -4
- birder/introspection/gradcam.py +1 -1
- birder/introspection/guided_backprop.py +2 -2
- birder/introspection/transformer_attribution.py +3 -3
- birder/layers/attention_pool.py +2 -2
- birder/model_registry/model_registry.py +2 -1
- birder/net/__init__.py +2 -0
- birder/net/_rope_vit_configs.py +5 -0
- birder/net/_vit_configs.py +0 -13
- birder/net/alexnet.py +5 -5
- birder/net/base.py +28 -3
- birder/net/biformer.py +17 -17
- birder/net/cait.py +2 -2
- birder/net/cas_vit.py +1 -1
- birder/net/coat.py +15 -15
- birder/net/convnext_v1.py +2 -10
- birder/net/convnext_v1_iso.py +198 -0
- birder/net/convnext_v2.py +2 -10
- birder/net/crossformer.py +9 -9
- birder/net/crossvit.py +1 -1
- birder/net/cspnet.py +1 -1
- birder/net/cswin_transformer.py +10 -10
- birder/net/davit.py +10 -10
- birder/net/deit.py +56 -3
- birder/net/deit3.py +27 -15
- birder/net/detection/__init__.py +4 -0
- birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
- birder/net/detection/base.py +6 -5
- birder/net/detection/deformable_detr.py +26 -28
- birder/net/detection/detr.py +9 -9
- birder/net/detection/efficientdet.py +9 -28
- birder/net/detection/faster_rcnn.py +22 -22
- birder/net/detection/fcos.py +8 -8
- birder/net/detection/plain_detr.py +852 -0
- birder/net/detection/retinanet.py +4 -4
- birder/net/detection/rt_detr_v1.py +81 -25
- birder/net/detection/rt_detr_v2.py +1147 -0
- birder/net/detection/ssd.py +5 -5
- birder/net/detection/yolo_v2.py +12 -12
- birder/net/detection/yolo_v3.py +19 -19
- birder/net/detection/yolo_v4.py +16 -16
- birder/net/detection/yolo_v4_tiny.py +3 -3
- birder/net/edgenext.py +3 -3
- birder/net/edgevit.py +10 -14
- birder/net/efficientformer_v1.py +1 -1
- birder/net/efficientvim.py +9 -9
- birder/net/efficientvit_mit.py +2 -2
- birder/net/efficientvit_msft.py +3 -3
- birder/net/fasternet.py +1 -1
- birder/net/fastvit.py +5 -12
- birder/net/flexivit.py +28 -15
- birder/net/focalnet.py +5 -9
- birder/net/gc_vit.py +11 -11
- birder/net/ghostnet_v1.py +1 -1
- birder/net/ghostnet_v2.py +1 -1
- birder/net/groupmixformer.py +12 -12
- birder/net/hgnet_v1.py +1 -1
- birder/net/hgnet_v2.py +4 -4
- birder/net/hiera.py +6 -6
- birder/net/hieradet.py +9 -9
- birder/net/hornet.py +3 -3
- birder/net/iformer.py +4 -4
- birder/net/inception_next.py +4 -14
- birder/net/levit.py +3 -3
- birder/net/lit_v1.py +13 -15
- birder/net/lit_v1_tiny.py +9 -9
- birder/net/lit_v2.py +14 -15
- birder/net/maxvit.py +10 -22
- birder/net/metaformer.py +2 -2
- birder/net/mim/crossmae.py +5 -5
- birder/net/mim/fcmae.py +3 -5
- birder/net/mim/mae_hiera.py +7 -7
- birder/net/mim/mae_vit.py +3 -5
- birder/net/mim/simmim.py +2 -3
- birder/net/mobilenet_v4_hybrid.py +4 -4
- birder/net/mobileone.py +5 -12
- birder/net/mobilevit_v1.py +2 -2
- birder/net/mobilevit_v2.py +5 -9
- birder/net/mvit_v2.py +24 -24
- birder/net/nextvit.py +2 -2
- birder/net/pit.py +11 -26
- birder/net/pvt_v1.py +4 -4
- birder/net/pvt_v2.py +5 -11
- birder/net/regionvit.py +15 -15
- birder/net/regnet.py +1 -1
- birder/net/repghost.py +4 -5
- birder/net/repvgg.py +3 -5
- birder/net/repvit.py +2 -2
- birder/net/resnest.py +1 -1
- birder/net/rope_deit3.py +29 -15
- birder/net/rope_flexivit.py +28 -15
- birder/net/rope_vit.py +41 -23
- birder/net/sequencer2d.py +3 -4
- birder/net/shufflenet_v1.py +1 -1
- birder/net/shufflenet_v2.py +1 -1
- birder/net/simple_vit.py +47 -5
- birder/net/smt.py +7 -7
- birder/net/ssl/barlow_twins.py +1 -1
- birder/net/ssl/byol.py +2 -2
- birder/net/ssl/capi.py +3 -3
- birder/net/ssl/data2vec2.py +1 -1
- birder/net/ssl/dino_v2.py +11 -1
- birder/net/ssl/franca.py +26 -2
- birder/net/ssl/i_jepa.py +4 -4
- birder/net/ssl/mmcr.py +1 -1
- birder/net/swiftformer.py +1 -1
- birder/net/swin_transformer_v1.py +4 -5
- birder/net/swin_transformer_v2.py +4 -7
- birder/net/tiny_vit.py +3 -3
- birder/net/transnext.py +19 -19
- birder/net/uniformer.py +4 -4
- birder/net/vgg.py +1 -10
- birder/net/vit.py +38 -25
- birder/net/vit_parallel.py +35 -20
- birder/net/vit_sam.py +10 -10
- birder/net/vovnet_v2.py +1 -1
- birder/net/xcit.py +9 -7
- birder/ops/msda.py +4 -4
- birder/ops/swattention.py +10 -10
- birder/results/classification.py +3 -3
- birder/results/gui.py +8 -8
- birder/scripts/benchmark.py +37 -12
- birder/scripts/evaluate.py +1 -1
- birder/scripts/predict.py +3 -3
- birder/scripts/predict_detection.py +2 -2
- birder/scripts/train.py +63 -15
- birder/scripts/train_barlow_twins.py +10 -7
- birder/scripts/train_byol.py +10 -7
- birder/scripts/train_capi.py +15 -10
- birder/scripts/train_data2vec.py +10 -7
- birder/scripts/train_data2vec2.py +10 -7
- birder/scripts/train_detection.py +29 -14
- birder/scripts/train_dino_v1.py +13 -9
- birder/scripts/train_dino_v2.py +27 -14
- birder/scripts/train_dino_v2_dist.py +28 -15
- birder/scripts/train_franca.py +16 -9
- birder/scripts/train_i_jepa.py +12 -9
- birder/scripts/train_ibot.py +15 -11
- birder/scripts/train_kd.py +64 -17
- birder/scripts/train_mim.py +11 -8
- birder/scripts/train_mmcr.py +11 -8
- birder/scripts/train_rotnet.py +11 -7
- birder/scripts/train_simclr.py +10 -7
- birder/scripts/train_vicreg.py +10 -7
- birder/tools/adversarial.py +4 -4
- birder/tools/auto_anchors.py +5 -5
- birder/tools/avg_model.py +1 -1
- birder/tools/convert_model.py +30 -22
- birder/tools/det_results.py +1 -1
- birder/tools/download_model.py +1 -1
- birder/tools/ensemble_model.py +1 -1
- birder/tools/introspection.py +11 -2
- birder/tools/labelme_to_coco.py +2 -2
- birder/tools/model_info.py +12 -14
- birder/tools/pack.py +8 -8
- birder/tools/quantize_model.py +53 -4
- birder/tools/results.py +2 -2
- birder/tools/show_det_iterator.py +19 -6
- birder/tools/show_iterator.py +2 -2
- birder/tools/similarity.py +5 -5
- birder/tools/stats.py +4 -6
- birder/tools/voc_to_coco.py +1 -1
- birder/version.py +1 -1
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
- birder-0.4.1.dist-info/RECORD +300 -0
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
- birder-0.4.0.dist-info/RECORD +0 -297
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
birder/scripts/train_simclr.py
CHANGED
|
@@ -67,7 +67,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
67
67
|
#
|
|
68
68
|
# Initialize
|
|
69
69
|
#
|
|
70
|
-
|
|
70
|
+
device, device_id, disable_tqdm = training_utils.init_training(args, logger)
|
|
71
71
|
|
|
72
72
|
if args.size is None:
|
|
73
73
|
args.size = registry.get_default_size(args.network)
|
|
@@ -90,11 +90,11 @@ def train(args: argparse.Namespace) -> None:
|
|
|
90
90
|
elif args.wds is True:
|
|
91
91
|
wds_path: str | list[str]
|
|
92
92
|
if args.wds_info is not None:
|
|
93
|
-
|
|
93
|
+
wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
|
|
94
94
|
if args.wds_size is not None:
|
|
95
95
|
dataset_size = args.wds_size
|
|
96
96
|
else:
|
|
97
|
-
|
|
97
|
+
wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
|
|
98
98
|
|
|
99
99
|
training_dataset = make_wds_dataset(
|
|
100
100
|
wds_path,
|
|
@@ -124,7 +124,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
124
124
|
|
|
125
125
|
# Data loaders and samplers
|
|
126
126
|
virtual_epoch_mode = args.steps_per_epoch is not None
|
|
127
|
-
|
|
127
|
+
train_sampler, _ = training_utils.get_samplers(
|
|
128
128
|
args, training_dataset, validation_dataset=None, infinite=virtual_epoch_mode
|
|
129
129
|
)
|
|
130
130
|
|
|
@@ -187,7 +187,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
187
187
|
|
|
188
188
|
network_name = get_mim_network_name("simclr", encoder=args.network, tag=args.tag)
|
|
189
189
|
|
|
190
|
-
backbone = registry.net_factory(args.network, sample_shape[1],
|
|
190
|
+
backbone = registry.net_factory(args.network, 0, sample_shape[1], config=args.model_config, size=args.size)
|
|
191
191
|
net = SimCLR(
|
|
192
192
|
backbone,
|
|
193
193
|
config={
|
|
@@ -199,7 +199,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
199
199
|
|
|
200
200
|
if args.resume_epoch is not None:
|
|
201
201
|
begin_epoch = args.resume_epoch + 1
|
|
202
|
-
|
|
202
|
+
net, training_states = fs_ops.load_simple_checkpoint(
|
|
203
203
|
device, net, network_name, epoch=args.resume_epoch, strict=not args.non_strict_weights
|
|
204
204
|
)
|
|
205
205
|
|
|
@@ -258,7 +258,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
258
258
|
optimizer.step = torch.compile(optimizer.step, fullgraph=False)
|
|
259
259
|
|
|
260
260
|
# Gradient scaler and AMP related tasks
|
|
261
|
-
|
|
261
|
+
scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
|
|
262
262
|
|
|
263
263
|
# Load states
|
|
264
264
|
if args.load_states is True:
|
|
@@ -370,6 +370,9 @@ def train(args: argparse.Namespace) -> None:
|
|
|
370
370
|
tic = time.time()
|
|
371
371
|
net.train()
|
|
372
372
|
|
|
373
|
+
# Clear metrics
|
|
374
|
+
running_loss.clear()
|
|
375
|
+
|
|
373
376
|
if args.distributed is True or virtual_epoch_mode is True:
|
|
374
377
|
train_sampler.set_epoch(epoch)
|
|
375
378
|
|
birder/scripts/train_vicreg.py
CHANGED
|
@@ -70,7 +70,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
70
70
|
#
|
|
71
71
|
# Initialize
|
|
72
72
|
#
|
|
73
|
-
|
|
73
|
+
device, device_id, disable_tqdm = training_utils.init_training(args, logger)
|
|
74
74
|
|
|
75
75
|
if args.size is None:
|
|
76
76
|
args.size = registry.get_default_size(args.network)
|
|
@@ -93,11 +93,11 @@ def train(args: argparse.Namespace) -> None:
|
|
|
93
93
|
elif args.wds is True:
|
|
94
94
|
wds_path: str | list[str]
|
|
95
95
|
if args.wds_info is not None:
|
|
96
|
-
|
|
96
|
+
wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
|
|
97
97
|
if args.wds_size is not None:
|
|
98
98
|
dataset_size = args.wds_size
|
|
99
99
|
else:
|
|
100
|
-
|
|
100
|
+
wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
|
|
101
101
|
|
|
102
102
|
training_dataset = make_wds_dataset(
|
|
103
103
|
wds_path,
|
|
@@ -127,7 +127,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
127
127
|
|
|
128
128
|
# Data loaders and samplers
|
|
129
129
|
virtual_epoch_mode = args.steps_per_epoch is not None
|
|
130
|
-
|
|
130
|
+
train_sampler, _ = training_utils.get_samplers(
|
|
131
131
|
args, training_dataset, validation_dataset=None, infinite=virtual_epoch_mode
|
|
132
132
|
)
|
|
133
133
|
|
|
@@ -190,7 +190,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
190
190
|
|
|
191
191
|
network_name = get_mim_network_name("vicreg", encoder=args.network, tag=args.tag)
|
|
192
192
|
|
|
193
|
-
backbone = registry.net_factory(args.network, sample_shape[1],
|
|
193
|
+
backbone = registry.net_factory(args.network, 0, sample_shape[1], config=args.model_config, size=args.size)
|
|
194
194
|
net = VICReg(
|
|
195
195
|
backbone,
|
|
196
196
|
config={
|
|
@@ -205,7 +205,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
205
205
|
|
|
206
206
|
if args.resume_epoch is not None:
|
|
207
207
|
begin_epoch = args.resume_epoch + 1
|
|
208
|
-
|
|
208
|
+
net, training_states = fs_ops.load_simple_checkpoint(
|
|
209
209
|
device, net, network_name, epoch=args.resume_epoch, strict=not args.non_strict_weights
|
|
210
210
|
)
|
|
211
211
|
|
|
@@ -264,7 +264,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
264
264
|
optimizer.step = torch.compile(optimizer.step, fullgraph=False)
|
|
265
265
|
|
|
266
266
|
# Gradient scaler and AMP related tasks
|
|
267
|
-
|
|
267
|
+
scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
|
|
268
268
|
|
|
269
269
|
# Load states
|
|
270
270
|
if args.load_states is True:
|
|
@@ -376,6 +376,9 @@ def train(args: argparse.Namespace) -> None:
|
|
|
376
376
|
tic = time.time()
|
|
377
377
|
net.train()
|
|
378
378
|
|
|
379
|
+
# Clear metrics
|
|
380
|
+
running_loss.clear()
|
|
381
|
+
|
|
379
382
|
if args.distributed is True or virtual_epoch_mode is True:
|
|
380
383
|
train_sampler.set_epoch(epoch)
|
|
381
384
|
|
birder/tools/adversarial.py
CHANGED
|
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
|
|
|
29
29
|
def _load_model_and_transform(
|
|
30
30
|
args: argparse.Namespace, device: torch.device
|
|
31
31
|
) -> tuple[torch.nn.Module, dict[str, int], RGBType, Callable[..., torch.Tensor], Callable[..., torch.Tensor]]:
|
|
32
|
-
|
|
32
|
+
net, model_info = fs_ops.load_model(
|
|
33
33
|
device, args.network, tag=args.tag, epoch=args.epoch, inference=True, reparameterized=args.reparameterized
|
|
34
34
|
)
|
|
35
35
|
|
|
@@ -105,8 +105,8 @@ def _display_results(
|
|
|
105
105
|
success: Optional[bool],
|
|
106
106
|
result: AttackResult,
|
|
107
107
|
) -> None:
|
|
108
|
-
|
|
109
|
-
|
|
108
|
+
orig_label, orig_prob = original_pred
|
|
109
|
+
adv_label, adv_prob = adv_pred
|
|
110
110
|
|
|
111
111
|
# Log results
|
|
112
112
|
logger.info(f"Original: {orig_label} ({orig_prob * 100:.2f}%)")
|
|
@@ -139,7 +139,7 @@ def run_attack(args: argparse.Namespace) -> None:
|
|
|
139
139
|
|
|
140
140
|
logger.info(f"Using device {device}")
|
|
141
141
|
|
|
142
|
-
|
|
142
|
+
net, class_to_idx, rgb_stats, transform, reverse_transform = _load_model_and_transform(args, device)
|
|
143
143
|
label_names = [name for name, _idx in sorted(class_to_idx.items(), key=lambda item: item[1])]
|
|
144
144
|
img = Image.open(args.image_path)
|
|
145
145
|
input_tensor = transform(img).unsqueeze(dim=0).to(device)
|
birder/tools/auto_anchors.py
CHANGED
|
@@ -92,7 +92,7 @@ def _load_coco_boxes(
|
|
|
92
92
|
stats["missing_images"] += 1
|
|
93
93
|
continue
|
|
94
94
|
|
|
95
|
-
|
|
95
|
+
img_w, img_h, file_name = images[image_id]
|
|
96
96
|
if file_name in ignore_list:
|
|
97
97
|
stats["ignored_images"] += 1
|
|
98
98
|
continue
|
|
@@ -219,7 +219,7 @@ def _validate_args(
|
|
|
219
219
|
output_format = args.format if args.format is not None else (preset["format"] if preset else None)
|
|
220
220
|
if num_scales is None or num_anchors is None or output_format is None:
|
|
221
221
|
raise cli.ValidationError(
|
|
222
|
-
"Missing configuration. Provide --num-scales, --num-anchors
|
|
222
|
+
"Missing configuration. Provide --num-scales, --num-anchors and --format or use a --preset"
|
|
223
223
|
)
|
|
224
224
|
if num_scales < 1:
|
|
225
225
|
raise cli.ValidationError("--num-scales must be >= 1")
|
|
@@ -244,10 +244,10 @@ def _validate_args(
|
|
|
244
244
|
|
|
245
245
|
# pylint: disable=too-many-locals
|
|
246
246
|
def auto_anchors(args: argparse.Namespace) -> None:
|
|
247
|
-
|
|
247
|
+
size, num_scales, num_anchors, output_format, strides = _validate_args(args)
|
|
248
248
|
|
|
249
249
|
ignore_list = _load_ignore_list(args.ignore_file)
|
|
250
|
-
|
|
250
|
+
boxes, stats = _load_coco_boxes(
|
|
251
251
|
args.coco_json_path, size, ignore_list, args.min_size, ignore_crowd=not args.include_crowd
|
|
252
252
|
)
|
|
253
253
|
|
|
@@ -262,7 +262,7 @@ def auto_anchors(args: argparse.Namespace) -> None:
|
|
|
262
262
|
f"missing_size={stats['missing_size']}, too_small={stats['too_small']}"
|
|
263
263
|
)
|
|
264
264
|
|
|
265
|
-
|
|
265
|
+
anchors, _assignments = _kmeans_anchors(boxes, num_anchors, args.seed, args.max_iter)
|
|
266
266
|
areas = anchors.prod(dim=1)
|
|
267
267
|
anchors = anchors[torch.argsort(areas)]
|
|
268
268
|
anchors_per_scale = num_anchors // num_scales
|
birder/tools/avg_model.py
CHANGED
|
@@ -44,7 +44,7 @@ def avg_models(
|
|
|
44
44
|
num_classes = lib.get_num_labels_from_signature(signature)
|
|
45
45
|
size = lib.get_size_from_signature(signature)
|
|
46
46
|
|
|
47
|
-
net = registry.net_factory(network,
|
|
47
|
+
net = registry.net_factory(network, num_classes, input_channels, size=size)
|
|
48
48
|
if reparameterized is True:
|
|
49
49
|
net.reparameterize_model()
|
|
50
50
|
|
birder/tools/convert_model.py
CHANGED
|
@@ -74,6 +74,7 @@ def onnx_export(
|
|
|
74
74
|
net: torch.nn.Module,
|
|
75
75
|
signature: SignatureType | DetectionSignatureType,
|
|
76
76
|
class_to_idx: dict[str, int],
|
|
77
|
+
rgb_stats: RGBType,
|
|
77
78
|
model_path: str | Path,
|
|
78
79
|
dynamo: bool,
|
|
79
80
|
trace: bool,
|
|
@@ -117,9 +118,19 @@ def onnx_export(
|
|
|
117
118
|
|
|
118
119
|
signature["inputs"][0]["data_shape"][0] = 0
|
|
119
120
|
|
|
120
|
-
logger.info("Saving
|
|
121
|
-
with open(f"{model_path}
|
|
122
|
-
json.dump(
|
|
121
|
+
logger.info("Saving model data json...")
|
|
122
|
+
with open(f"{model_path}_data.json", "w", encoding="utf-8") as handle:
|
|
123
|
+
json.dump(
|
|
124
|
+
{
|
|
125
|
+
"birder_version": __version__,
|
|
126
|
+
"task": net.task,
|
|
127
|
+
"class_to_idx": class_to_idx,
|
|
128
|
+
"signature": signature,
|
|
129
|
+
"rgb_stats": rgb_stats,
|
|
130
|
+
},
|
|
131
|
+
handle,
|
|
132
|
+
indent=2,
|
|
133
|
+
)
|
|
123
134
|
|
|
124
135
|
# Test exported model
|
|
125
136
|
onnx_model = onnx.load(str(model_path))
|
|
@@ -238,7 +249,7 @@ def main(args: argparse.Namespace) -> None:
|
|
|
238
249
|
signature: SignatureType | DetectionSignatureType
|
|
239
250
|
backbone_custom_config = None
|
|
240
251
|
if args.backbone is None:
|
|
241
|
-
|
|
252
|
+
net, (class_to_idx, signature, rgb_stats, custom_config) = fs_ops.load_model(
|
|
242
253
|
device,
|
|
243
254
|
args.network,
|
|
244
255
|
config=args.model_config,
|
|
@@ -251,22 +262,20 @@ def main(args: argparse.Namespace) -> None:
|
|
|
251
262
|
network_name = lib.get_network_name(args.network, tag=args.tag)
|
|
252
263
|
|
|
253
264
|
else:
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
export_mode=True,
|
|
269
|
-
)
|
|
265
|
+
net, (class_to_idx, signature, rgb_stats, custom_config, backbone_custom_config) = fs_ops.load_detection_model(
|
|
266
|
+
device,
|
|
267
|
+
args.network,
|
|
268
|
+
config=args.model_config,
|
|
269
|
+
tag=args.tag,
|
|
270
|
+
reparameterized=args.reparameterized,
|
|
271
|
+
backbone=args.backbone,
|
|
272
|
+
backbone_config=args.backbone_model_config,
|
|
273
|
+
backbone_tag=args.backbone_tag,
|
|
274
|
+
backbone_reparameterized=args.backbone_reparameterized,
|
|
275
|
+
epoch=args.epoch,
|
|
276
|
+
new_size=args.resize,
|
|
277
|
+
inference=True,
|
|
278
|
+
export_mode=True,
|
|
270
279
|
)
|
|
271
280
|
network_name = lib.get_detection_network_name(
|
|
272
281
|
args.network, tag=args.tag, backbone=args.backbone, backbone_tag=args.backbone_tag
|
|
@@ -407,8 +416,7 @@ def main(args: argparse.Namespace) -> None:
|
|
|
407
416
|
)
|
|
408
417
|
|
|
409
418
|
elif args.onnx is True or args.onnx_dynamo is True:
|
|
410
|
-
|
|
411
|
-
onnx_export(net, signature, class_to_idx, model_path, args.onnx_dynamo, args.trace)
|
|
419
|
+
onnx_export(net, signature, class_to_idx, rgb_stats, model_path, args.onnx_dynamo, args.trace)
|
|
412
420
|
|
|
413
421
|
elif args.config is True:
|
|
414
422
|
config_export(net, signature, rgb_stats, model_path)
|
birder/tools/det_results.py
CHANGED
|
@@ -239,7 +239,7 @@ def main(args: argparse.Namespace) -> None:
|
|
|
239
239
|
logger.warning("Cannot compare confusion matrix, processing only the first file")
|
|
240
240
|
|
|
241
241
|
results = next(iter(results_dict.values()))
|
|
242
|
-
|
|
242
|
+
cnf_matrix, label_names = confusion_matrix_data(
|
|
243
243
|
results, args.cnf_score_threshold, args.cnf_iou_threshold, args.classes, args.cnf_errors_only
|
|
244
244
|
)
|
|
245
245
|
title = f"Confusion matrix (score >= {args.cnf_score_threshold:.2f}, IoU >= {args.cnf_iou_threshold:.2f})"
|
birder/tools/download_model.py
CHANGED
|
@@ -52,7 +52,7 @@ def main(args: argparse.Namespace) -> None:
|
|
|
52
52
|
)
|
|
53
53
|
raise SystemExit(1)
|
|
54
54
|
|
|
55
|
-
|
|
55
|
+
model_file, url = get_pretrained_model_url(args.model_name, args.format)
|
|
56
56
|
dst = settings.MODELS_DIR.joinpath(model_file)
|
|
57
57
|
if dst.exists() is True and args.force is False:
|
|
58
58
|
logger.warning(f"File {model_file} already exists... aborting")
|
birder/tools/ensemble_model.py
CHANGED
|
@@ -58,7 +58,7 @@ def main(args: argparse.Namespace) -> None:
|
|
|
58
58
|
signature_list = []
|
|
59
59
|
rgb_stats_list = []
|
|
60
60
|
for network in args.networks:
|
|
61
|
-
|
|
61
|
+
net, model_info = fs_ops.load_model(device, network, inference=True, pts=args.pts, pt2=args.pt2)
|
|
62
62
|
nets.append(net)
|
|
63
63
|
class_to_idx_list.append(model_info.class_to_idx)
|
|
64
64
|
signature_list.append(model_info.signature)
|
birder/tools/introspection.py
CHANGED
|
@@ -126,6 +126,14 @@ def set_parser(subparsers: Any) -> None:
|
|
|
126
126
|
formatter_class=cli.ArgumentHelpFormatter,
|
|
127
127
|
)
|
|
128
128
|
subparser.add_argument("-n", "--network", type=str, required=True, help="the neural network to use")
|
|
129
|
+
subparser.add_argument(
|
|
130
|
+
"--model-config",
|
|
131
|
+
action=cli.FlexibleDictAction,
|
|
132
|
+
help=(
|
|
133
|
+
"override the model default configuration, accepts key-value pairs or JSON "
|
|
134
|
+
"('drop_path_rate=0.2' or '{\"units\": [3, 24, 36, 3], \"dropout\": 0.2}'"
|
|
135
|
+
),
|
|
136
|
+
)
|
|
129
137
|
subparser.add_argument("-e", "--epoch", type=int, metavar="N", help="model checkpoint to load")
|
|
130
138
|
subparser.add_argument("-t", "--tag", type=str, help="model tag (from the training phase)")
|
|
131
139
|
subparser.add_argument(
|
|
@@ -145,7 +153,7 @@ def set_parser(subparsers: Any) -> None:
|
|
|
145
153
|
subparser.add_argument(
|
|
146
154
|
"--target",
|
|
147
155
|
type=str,
|
|
148
|
-
help="target class, leave empty to use predicted class (gradcam, guided-backprop
|
|
156
|
+
help="target class, leave empty to use predicted class (gradcam, guided-backprop and transformer-attribution)",
|
|
149
157
|
)
|
|
150
158
|
subparser.add_argument("--block-name", type=str, default="body", help="target block (gradcam only)")
|
|
151
159
|
subparser.add_argument(
|
|
@@ -203,9 +211,10 @@ def main(args: argparse.Namespace) -> None:
|
|
|
203
211
|
|
|
204
212
|
logger.info(f"Using device {device}")
|
|
205
213
|
|
|
206
|
-
|
|
214
|
+
net, model_info = fs_ops.load_model(
|
|
207
215
|
device,
|
|
208
216
|
args.network,
|
|
217
|
+
config=args.model_config,
|
|
209
218
|
tag=args.tag,
|
|
210
219
|
epoch=args.epoch,
|
|
211
220
|
new_size=args.size,
|
birder/tools/labelme_to_coco.py
CHANGED
birder/tools/model_info.py
CHANGED
|
@@ -73,7 +73,7 @@ def main(args: argparse.Namespace) -> None:
|
|
|
73
73
|
signature: SignatureType | DetectionSignatureType
|
|
74
74
|
backbone_custom_config = None
|
|
75
75
|
if args.backbone is None:
|
|
76
|
-
|
|
76
|
+
net, (class_to_idx, signature, rgb_stats, custom_config) = fs_ops.load_model(
|
|
77
77
|
device,
|
|
78
78
|
args.network,
|
|
79
79
|
tag=args.tag,
|
|
@@ -86,19 +86,17 @@ def main(args: argparse.Namespace) -> None:
|
|
|
86
86
|
)
|
|
87
87
|
|
|
88
88
|
else:
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
st=args.st,
|
|
101
|
-
)
|
|
89
|
+
net, (class_to_idx, signature, rgb_stats, custom_config, backbone_custom_config) = fs_ops.load_detection_model(
|
|
90
|
+
device,
|
|
91
|
+
args.network,
|
|
92
|
+
tag=args.tag,
|
|
93
|
+
backbone=args.backbone,
|
|
94
|
+
backbone_tag=args.backbone_tag,
|
|
95
|
+
epoch=args.epoch,
|
|
96
|
+
inference=True,
|
|
97
|
+
pts=args.pts,
|
|
98
|
+
pt2=args.pt2,
|
|
99
|
+
st=args.st,
|
|
102
100
|
)
|
|
103
101
|
|
|
104
102
|
model_info = get_model_info(net)
|
birder/tools/pack.py
CHANGED
|
@@ -114,7 +114,7 @@ def read_worker(q_in: Any, q_out: Any, error_event: Any, size: Optional[int], fi
|
|
|
114
114
|
break
|
|
115
115
|
|
|
116
116
|
try:
|
|
117
|
-
|
|
117
|
+
idx, path, target = deq
|
|
118
118
|
if size is None:
|
|
119
119
|
suffix = Path(path).suffix[1:]
|
|
120
120
|
if file_format != suffix:
|
|
@@ -172,7 +172,7 @@ def wds_write_worker(
|
|
|
172
172
|
while more:
|
|
173
173
|
deq: Optional[tuple[int, bytes, str, int]] = q_out.get()
|
|
174
174
|
if deq is not None:
|
|
175
|
-
|
|
175
|
+
idx, sample, suffix, target = deq
|
|
176
176
|
buf[idx] = (sample, suffix, target)
|
|
177
177
|
|
|
178
178
|
else:
|
|
@@ -180,7 +180,7 @@ def wds_write_worker(
|
|
|
180
180
|
|
|
181
181
|
# Ensures ordered write
|
|
182
182
|
while count in buf:
|
|
183
|
-
|
|
183
|
+
sample, suffix, target = buf[count]
|
|
184
184
|
del buf[count]
|
|
185
185
|
|
|
186
186
|
if args.no_cls is True:
|
|
@@ -238,7 +238,7 @@ def directory_write_worker(
|
|
|
238
238
|
while more:
|
|
239
239
|
deq: Optional[tuple[int, bytes, str, int]] = q_out.get()
|
|
240
240
|
if deq is not None:
|
|
241
|
-
|
|
241
|
+
idx, sample, suffix, target = deq
|
|
242
242
|
buf[idx] = (sample, suffix, target)
|
|
243
243
|
|
|
244
244
|
else:
|
|
@@ -246,7 +246,7 @@ def directory_write_worker(
|
|
|
246
246
|
|
|
247
247
|
# Ensures ordered write
|
|
248
248
|
while count in buf:
|
|
249
|
-
|
|
249
|
+
sample, suffix, target = buf[count]
|
|
250
250
|
del buf[count]
|
|
251
251
|
with open(
|
|
252
252
|
pack_path.joinpath(idx_to_class[target]).joinpath(f"{count:06d}.{suffix}"), "wb"
|
|
@@ -274,7 +274,7 @@ def pack(args: argparse.Namespace, pack_path: Path) -> None:
|
|
|
274
274
|
if len(line.strip()) == 0 or line.strip().startswith("#") is True:
|
|
275
275
|
continue
|
|
276
276
|
|
|
277
|
-
|
|
277
|
+
data_path, r = line.split()
|
|
278
278
|
data_path = os.path.expanduser(data_path)
|
|
279
279
|
repeats = int(r)
|
|
280
280
|
for _ in range(repeats):
|
|
@@ -391,7 +391,7 @@ def pack(args: argparse.Namespace, pack_path: Path) -> None:
|
|
|
391
391
|
cleanup_processes()
|
|
392
392
|
raise RuntimeError()
|
|
393
393
|
|
|
394
|
-
|
|
394
|
+
path, target = dataset[sample_idx]
|
|
395
395
|
|
|
396
396
|
while True:
|
|
397
397
|
try:
|
|
@@ -430,7 +430,7 @@ def pack(args: argparse.Namespace, pack_path: Path) -> None:
|
|
|
430
430
|
raise RuntimeError()
|
|
431
431
|
|
|
432
432
|
if args.type == "wds":
|
|
433
|
-
|
|
433
|
+
wds_path, num_shards = fs_ops.wds_braces_from_path(pack_path, prefix=f"{args.suffix}-{args.split}")
|
|
434
434
|
logger.info(f"Packed {len(dataset):,} samples into {num_shards} shards at {wds_path}")
|
|
435
435
|
elif args.type == "directory":
|
|
436
436
|
logger.info(f"Packed {len(dataset):,} samples")
|
birder/tools/quantize_model.py
CHANGED
|
@@ -1,7 +1,9 @@
|
|
|
1
1
|
import argparse
|
|
2
2
|
import itertools
|
|
3
|
+
import json
|
|
3
4
|
import logging
|
|
4
5
|
import time
|
|
6
|
+
from pathlib import Path
|
|
5
7
|
from typing import Any
|
|
6
8
|
|
|
7
9
|
import torch
|
|
@@ -15,7 +17,11 @@ from birder.common import fs_ops
|
|
|
15
17
|
from birder.common import lib
|
|
16
18
|
from birder.common.lib import get_network_name
|
|
17
19
|
from birder.conf import settings
|
|
20
|
+
from birder.data.transforms.classification import RGBType
|
|
18
21
|
from birder.data.transforms.classification import inference_preset
|
|
22
|
+
from birder.net.base import SignatureType
|
|
23
|
+
from birder.net.detection.base import DetectionSignatureType
|
|
24
|
+
from birder.version import __version__
|
|
19
25
|
|
|
20
26
|
try:
|
|
21
27
|
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e
|
|
@@ -28,8 +34,10 @@ except ImportError:
|
|
|
28
34
|
_HAS_TORCHAO = False
|
|
29
35
|
|
|
30
36
|
try:
|
|
37
|
+
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
|
|
31
38
|
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import XNNPACKQuantizer
|
|
32
39
|
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import get_symmetric_quantization_config
|
|
40
|
+
from executorch.exir import to_edge_transform_and_lower
|
|
33
41
|
|
|
34
42
|
_HAS_EXECUTORCH = True
|
|
35
43
|
except ImportError:
|
|
@@ -54,6 +62,33 @@ def _build_quantizer(backend: str) -> Any:
|
|
|
54
62
|
raise ValueError(f"Unsupported backend: {backend}")
|
|
55
63
|
|
|
56
64
|
|
|
65
|
+
def _save_pte(
|
|
66
|
+
exported_net: torch.export.ExportedProgram,
|
|
67
|
+
dst: str | Path,
|
|
68
|
+
task: str,
|
|
69
|
+
class_to_idx: dict[str, int],
|
|
70
|
+
signature: SignatureType | DetectionSignatureType,
|
|
71
|
+
rgb_stats: RGBType,
|
|
72
|
+
) -> None:
|
|
73
|
+
edge_program = to_edge_transform_and_lower(exported_net, partitioner=[XnnpackPartitioner()])
|
|
74
|
+
executorch_program = edge_program.to_executorch()
|
|
75
|
+
with open(dst, "wb") as f:
|
|
76
|
+
f.write(executorch_program.buffer)
|
|
77
|
+
|
|
78
|
+
with open(f"{dst}_data.json", "w", encoding="utf-8") as handle:
|
|
79
|
+
json.dump(
|
|
80
|
+
{
|
|
81
|
+
"birder_version": __version__,
|
|
82
|
+
"task": task,
|
|
83
|
+
"class_to_idx": class_to_idx,
|
|
84
|
+
"signature": signature,
|
|
85
|
+
"rgb_stats": rgb_stats,
|
|
86
|
+
},
|
|
87
|
+
handle,
|
|
88
|
+
indent=2,
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
|
|
57
92
|
def set_parser(subparsers: Any) -> None:
|
|
58
93
|
subparser = subparsers.add_parser(
|
|
59
94
|
"quantize-model",
|
|
@@ -65,6 +100,7 @@ def set_parser(subparsers: Any) -> None:
|
|
|
65
100
|
"python -m birder.tools quantize-model -n convnext_v2_tiny -t eu-common\n"
|
|
66
101
|
"python -m birder.tools quantize-model --network densenet_121 -e 100 --num-calibration-batches 256\n"
|
|
67
102
|
"python -m birder.tools quantize-model -n efficientnet_v2_s -e 200 --qbackend xnnpack --batch-size 1\n"
|
|
103
|
+
"python -m birder.tools quantize-model -n hgnet_v2_b4 --qbackend xnnpack --pte\n"
|
|
68
104
|
),
|
|
69
105
|
formatter_class=cli.ArgumentHelpFormatter,
|
|
70
106
|
)
|
|
@@ -81,6 +117,9 @@ def set_parser(subparsers: Any) -> None:
|
|
|
81
117
|
subparser.add_argument(
|
|
82
118
|
"--qbackend", type=str, choices=["x86", "xnnpack"], default="x86", help="quantization backend"
|
|
83
119
|
)
|
|
120
|
+
subparser.add_argument(
|
|
121
|
+
"--pte", default=False, action="store_true", help="lower quantized model to ExecuTorch PTE format"
|
|
122
|
+
)
|
|
84
123
|
subparser.add_argument("--batch-size", type=int, default=1, metavar="N", help="the batch size")
|
|
85
124
|
subparser.add_argument(
|
|
86
125
|
"--num-calibration-batches",
|
|
@@ -96,8 +135,13 @@ def set_parser(subparsers: Any) -> None:
|
|
|
96
135
|
|
|
97
136
|
# pylint: disable=too-many-locals
|
|
98
137
|
def main(args: argparse.Namespace) -> None:
|
|
138
|
+
if args.pte is True and args.qbackend != "xnnpack":
|
|
139
|
+
raise cli.ValidationError("--pte requires --qbackend xnnpack")
|
|
140
|
+
|
|
99
141
|
network_name = get_network_name(args.network, tag=args.tag)
|
|
100
142
|
model_path = fs_ops.model_path(network_name, epoch=args.epoch, quantized=True, pt2=True)
|
|
143
|
+
if args.pte is True:
|
|
144
|
+
model_path = model_path.with_suffix(".pte")
|
|
101
145
|
if model_path.exists() is True and args.force is False:
|
|
102
146
|
logger.warning("Quantized model already exists... aborting")
|
|
103
147
|
raise SystemExit(1)
|
|
@@ -105,7 +149,7 @@ def main(args: argparse.Namespace) -> None:
|
|
|
105
149
|
device = torch.device("cpu")
|
|
106
150
|
|
|
107
151
|
# Load model
|
|
108
|
-
|
|
152
|
+
net, (class_to_idx, signature, rgb_stats, *_) = fs_ops.load_model(
|
|
109
153
|
device, args.network, tag=args.tag, epoch=args.epoch, inference=True, reparameterized=args.reparameterized
|
|
110
154
|
)
|
|
111
155
|
net.eval()
|
|
@@ -154,9 +198,14 @@ def main(args: argparse.Namespace) -> None:
|
|
|
154
198
|
exported_quantized_net = torch.export.export(quantized_net, example_inputs)
|
|
155
199
|
|
|
156
200
|
toc = time.time()
|
|
157
|
-
|
|
201
|
+
minutes, seconds = divmod(toc - tic, 60)
|
|
158
202
|
logger.info(f"{int(minutes):0>2}m{seconds:04.1f}s to quantize model")
|
|
159
203
|
|
|
160
204
|
model_path = fs_ops.model_path(network_name, epoch=args.epoch, quantized=True, pt2=True)
|
|
161
|
-
|
|
162
|
-
|
|
205
|
+
if args.pte is True:
|
|
206
|
+
model_path = model_path.with_suffix(".pte")
|
|
207
|
+
logger.info(f"Lowering quantized model to PTE {model_path}...")
|
|
208
|
+
_save_pte(exported_quantized_net, model_path, task, class_to_idx, signature, rgb_stats)
|
|
209
|
+
else:
|
|
210
|
+
logger.info(f"Saving quantized PT2 model {model_path}...")
|
|
211
|
+
fs_ops.save_pt2(exported_quantized_net, model_path, task, class_to_idx, signature, rgb_stats)
|
birder/tools/results.py
CHANGED
|
@@ -125,7 +125,7 @@ def print_most_confused_pairs(most_confused_df: pl.DataFrame) -> None:
|
|
|
125
125
|
|
|
126
126
|
def convert_to_sparse(results_file: str, sparse_k: int) -> None:
|
|
127
127
|
logger.info(f"Converting {results_file} to sparse format (k={sparse_k})...")
|
|
128
|
-
|
|
128
|
+
_, detected_sparse_k = detect_file_format(results_file)
|
|
129
129
|
|
|
130
130
|
if detected_sparse_k is not None:
|
|
131
131
|
logger.info(f"File is already in sparse format (with k={detected_sparse_k}). Skipping conversion.")
|
|
@@ -233,7 +233,7 @@ def main(args: argparse.Namespace) -> None:
|
|
|
233
233
|
logger.warning("Cannot print mistakes in compare mode. processing only the first file")
|
|
234
234
|
|
|
235
235
|
if args.imperfect_only is True:
|
|
236
|
-
|
|
236
|
+
result_name, results = next(iter(results_dict.items()))
|
|
237
237
|
mistake_prediction_indices = results.mistakes["prediction"].unique().to_numpy().tolist()
|
|
238
238
|
mistake_label_indices = results.mistakes["label"].unique().to_numpy().tolist()
|
|
239
239
|
imperfect_class_indices = np.unique(mistake_prediction_indices + mistake_label_indices).tolist()
|