sleap_nn-0.1.0-py3-none-any.whl → sleap_nn-0.1.0a0-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (53)
  1. sleap_nn/__init__.py +2 -4
  2. sleap_nn/architectures/convnext.py +0 -5
  3. sleap_nn/architectures/encoder_decoder.py +6 -25
  4. sleap_nn/architectures/swint.py +0 -8
  5. sleap_nn/cli.py +60 -364
  6. sleap_nn/config/data_config.py +5 -11
  7. sleap_nn/config/get_config.py +4 -10
  8. sleap_nn/config/trainer_config.py +0 -76
  9. sleap_nn/data/augmentation.py +241 -50
  10. sleap_nn/data/custom_datasets.py +39 -411
  11. sleap_nn/data/instance_cropping.py +1 -1
  12. sleap_nn/data/resizing.py +2 -2
  13. sleap_nn/data/utils.py +17 -135
  14. sleap_nn/evaluation.py +22 -81
  15. sleap_nn/inference/bottomup.py +20 -86
  16. sleap_nn/inference/peak_finding.py +19 -88
  17. sleap_nn/inference/predictors.py +117 -224
  18. sleap_nn/legacy_models.py +11 -65
  19. sleap_nn/predict.py +9 -37
  20. sleap_nn/train.py +4 -74
  21. sleap_nn/training/callbacks.py +105 -1046
  22. sleap_nn/training/lightning_modules.py +65 -602
  23. sleap_nn/training/model_trainer.py +184 -211
  24. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/METADATA +3 -15
  25. sleap_nn-0.1.0a0.dist-info/RECORD +65 -0
  26. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/WHEEL +1 -1
  27. sleap_nn/data/skia_augmentation.py +0 -414
  28. sleap_nn/export/__init__.py +0 -21
  29. sleap_nn/export/cli.py +0 -1778
  30. sleap_nn/export/exporters/__init__.py +0 -51
  31. sleap_nn/export/exporters/onnx_exporter.py +0 -80
  32. sleap_nn/export/exporters/tensorrt_exporter.py +0 -291
  33. sleap_nn/export/metadata.py +0 -225
  34. sleap_nn/export/predictors/__init__.py +0 -63
  35. sleap_nn/export/predictors/base.py +0 -22
  36. sleap_nn/export/predictors/onnx.py +0 -154
  37. sleap_nn/export/predictors/tensorrt.py +0 -312
  38. sleap_nn/export/utils.py +0 -307
  39. sleap_nn/export/wrappers/__init__.py +0 -25
  40. sleap_nn/export/wrappers/base.py +0 -96
  41. sleap_nn/export/wrappers/bottomup.py +0 -243
  42. sleap_nn/export/wrappers/bottomup_multiclass.py +0 -195
  43. sleap_nn/export/wrappers/centered_instance.py +0 -56
  44. sleap_nn/export/wrappers/centroid.py +0 -58
  45. sleap_nn/export/wrappers/single_instance.py +0 -83
  46. sleap_nn/export/wrappers/topdown.py +0 -180
  47. sleap_nn/export/wrappers/topdown_multiclass.py +0 -304
  48. sleap_nn/inference/postprocessing.py +0 -284
  49. sleap_nn/training/schedulers.py +0 -191
  50. sleap_nn-0.1.0.dist-info/RECORD +0 -88
  51. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/entry_points.txt +0 -0
  52. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/licenses/LICENSE +0 -0
  53. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/top_level.txt +0 -0
sleap_nn/inference/predictors.py CHANGED
@@ -56,8 +56,6 @@ from rich.progress import (
     MofNCompleteColumn,
 )
 from time import time
-import json
-import sys
 
 
 def _filter_user_labeled_frames(
@@ -135,8 +133,6 @@ class Predictor(ABC):
             `backbone_config`. This determines the downsampling factor applied by the backbone,
             and is used to ensure that input images are padded or resized to be compatible
             with the model's architecture. Default: 16.
-        gui: If True, outputs JSON progress lines for GUI integration instead of
-            Rich progress bars. Default: False.
     """
 
     preprocess: bool = True
@@ -156,7 +152,6 @@ class Predictor(ABC):
     ] = None
     instances_key: bool = False
     max_stride: int = 16
-    gui: bool = False
 
     @classmethod
     def from_model_paths(
@@ -360,7 +355,7 @@ class Predictor(ABC):
     def make_pipeline(
         self,
         data_path: str,
-        queue_maxsize: int = 32,
+        queue_maxsize: int = 8,
         frames: Optional[list] = None,
         only_labeled_frames: bool = False,
         only_suggested_frames: bool = False,
@@ -386,102 +381,6 @@ class Predictor(ABC):
                     v[n] = v[n].cpu().numpy()
         return output
 
-    def _process_batch(self) -> tuple:
-        """Process a single batch of frames from the pipeline.
-
-        Returns:
-            Tuple of (imgs, fidxs, vidxs, org_szs, instances, eff_scales, done)
-            where done is True if the pipeline has finished.
-        """
-        imgs = []
-        fidxs = []
-        vidxs = []
-        org_szs = []
-        instances = []
-        eff_scales = []
-        done = False
-
-        for _ in range(self.batch_size):
-            frame = self.pipeline.frame_buffer.get()
-            if frame["image"] is None:
-                done = True
-                break
-            frame["image"], eff_scale = apply_sizematcher(
-                frame["image"],
-                self.preprocess_config["max_height"],
-                self.preprocess_config["max_width"],
-            )
-            if self.instances_key:
-                frame["instances"] = frame["instances"] * eff_scale
-            if self.preprocess_config["ensure_rgb"] and frame["image"].shape[-3] != 3:
-                frame["image"] = frame["image"].repeat(1, 3, 1, 1)
-            elif (
-                self.preprocess_config["ensure_grayscale"]
-                and frame["image"].shape[-3] != 1
-            ):
-                frame["image"] = F.rgb_to_grayscale(
-                    frame["image"], num_output_channels=1
-                )
-
-            eff_scales.append(torch.tensor(eff_scale))
-            imgs.append(frame["image"].unsqueeze(dim=0))
-            fidxs.append(frame["frame_idx"])
-            vidxs.append(frame["video_idx"])
-            org_szs.append(frame["orig_size"].unsqueeze(dim=0))
-            if self.instances_key:
-                instances.append(frame["instances"].unsqueeze(dim=0))
-
-        return imgs, fidxs, vidxs, org_szs, instances, eff_scales, done
-
-    def _run_inference_on_batch(
-        self, imgs, fidxs, vidxs, org_szs, instances, eff_scales
-    ) -> Iterator[Dict[str, np.ndarray]]:
-        """Run inference on a prepared batch of frames.
-
-        Args:
-            imgs: List of image tensors.
-            fidxs: List of frame indices.
-            vidxs: List of video indices.
-            org_szs: List of original sizes.
-            instances: List of instance tensors.
-            eff_scales: List of effective scales.
-
-        Yields:
-            Dictionaries containing inference results for each frame.
-        """
-        # TODO: all preprocessing should be moved into InferenceModels to be exportable.
-        imgs = torch.concatenate(imgs, dim=0)
-        fidxs = torch.tensor(fidxs, dtype=torch.int32)
-        vidxs = torch.tensor(vidxs, dtype=torch.int32)
-        org_szs = torch.concatenate(org_szs, dim=0)
-        eff_scales = torch.tensor(eff_scales, dtype=torch.float32)
-        if self.instances_key:
-            instances = torch.concatenate(instances, dim=0)
-        ex = {
-            "image": imgs,
-            "frame_idx": fidxs,
-            "video_idx": vidxs,
-            "orig_size": org_szs,
-            "eff_scale": eff_scales,
-        }
-        if self.instances_key:
-            ex["instances"] = instances
-        if self.preprocess:
-            scale = self.preprocess_config["scale"]
-            if scale != 1.0:
-                if self.instances_key:
-                    ex["image"], ex["instances"] = apply_resizer(
-                        ex["image"], ex["instances"]
-                    )
-                else:
-                    ex["image"] = resize_image(ex["image"], scale)
-            ex["image"] = apply_pad_to_stride(ex["image"], self.max_stride)
-        outputs_list = self.inference_model(ex)
-        if outputs_list is not None:
-            for output in outputs_list:
-                output = self._convert_tensors_to_numpy(output)
-                yield output
-
     def _predict_generator(self) -> Iterator[Dict[str, np.ndarray]]:
         """Create a generator that yields batches of inference results.
 
@@ -501,14 +400,114 @@ class Predictor(ABC):
         # Loop over data batches.
         self.pipeline.start()
         total_frames = self.pipeline.total_len()
+        done = False
 
         try:
-            if self.gui:
-                # GUI mode: emit JSON progress lines
-                yield from self._predict_generator_gui(total_frames)
-            else:
-                # Normal mode: use Rich progress bar
-                yield from self._predict_generator_rich(total_frames)
+            with Progress(
+                "{task.description}",
+                BarColumn(),
+                "[progress.percentage]{task.percentage:>3.0f}%",
+                MofNCompleteColumn(),
+                "ETA:",
+                TimeRemainingColumn(),
+                "Elapsed:",
+                TimeElapsedColumn(),
+                RateColumn(),
+                auto_refresh=False,
+                refresh_per_second=4,  # Change to self.report_rate if needed
+                speed_estimate_period=5,
+            ) as progress:
+                task = progress.add_task("Predicting...", total=total_frames)
+                last_report = time()
+
+                done = False
+                while not done:
+                    imgs = []
+                    fidxs = []
+                    vidxs = []
+                    org_szs = []
+                    instances = []
+                    eff_scales = []
+                    for _ in range(self.batch_size):
+                        frame = self.pipeline.frame_buffer.get()
+                        if frame["image"] is None:
+                            done = True
+                            break
+                        frame["image"], eff_scale = apply_sizematcher(
+                            frame["image"],
+                            self.preprocess_config["max_height"],
+                            self.preprocess_config["max_width"],
+                        )
+                        if self.instances_key:
+                            frame["instances"] = frame["instances"] * eff_scale
+                        if (
+                            self.preprocess_config["ensure_rgb"]
+                            and frame["image"].shape[-3] != 3
+                        ):
+                            frame["image"] = frame["image"].repeat(1, 3, 1, 1)
+                        elif (
+                            self.preprocess_config["ensure_grayscale"]
+                            and frame["image"].shape[-3] != 1
+                        ):
+                            frame["image"] = F.rgb_to_grayscale(
+                                frame["image"], num_output_channels=1
+                            )
+
+                        eff_scales.append(torch.tensor(eff_scale))
+                        imgs.append(frame["image"].unsqueeze(dim=0))
+                        fidxs.append(frame["frame_idx"])
+                        vidxs.append(frame["video_idx"])
+                        org_szs.append(frame["orig_size"].unsqueeze(dim=0))
+                        if self.instances_key:
+                            instances.append(frame["instances"].unsqueeze(dim=0))
+                    if imgs:
+                        # TODO: all preprocessing should be moved into InferenceModels to be exportable.
+                        imgs = torch.concatenate(imgs, dim=0)
+                        fidxs = torch.tensor(fidxs, dtype=torch.int32)
+                        vidxs = torch.tensor(vidxs, dtype=torch.int32)
+                        org_szs = torch.concatenate(org_szs, dim=0)
+                        eff_scales = torch.tensor(eff_scales, dtype=torch.float32)
+                        if self.instances_key:
+                            instances = torch.concatenate(instances, dim=0)
+                        ex = {
+                            "image": imgs,
+                            "frame_idx": fidxs,
+                            "video_idx": vidxs,
+                            "orig_size": org_szs,
+                            "eff_scale": eff_scales,
+                        }
+                        if self.instances_key:
+                            ex["instances"] = instances
+                        if self.preprocess:
+                            scale = self.preprocess_config["scale"]
+                            if scale != 1.0:
+                                if self.instances_key:
+                                    ex["image"], ex["instances"] = apply_resizer(
+                                        ex["image"], ex["instances"]
+                                    )
+                                else:
+                                    ex["image"] = resize_image(ex["image"], scale)
+                            ex["image"] = apply_pad_to_stride(
+                                ex["image"], self.max_stride
+                            )
+                        outputs_list = self.inference_model(ex)
+                        if outputs_list is not None:
+                            for output in outputs_list:
+                                output = self._convert_tensors_to_numpy(output)
+                                yield output
+
+                    # Advance progress
+                    num_frames = (
+                        len(ex["frame_idx"])
+                        if "frame_idx" in ex
+                        else self.batch_size
+                    )
+                    progress.update(task, advance=num_frames)
+
+                    # Manually refresh progress bar
+                    if time() - last_report > 0.25:
+                        progress.refresh()
+                        last_report = time()
 
         except KeyboardInterrupt:
             logger.info("Inference interrupted by user")
@@ -521,112 +520,6 @@ class Predictor(ABC):
 
         self.pipeline.join()
 
-    def _predict_generator_gui(
-        self, total_frames: int
-    ) -> Iterator[Dict[str, np.ndarray]]:
-        """Generator for GUI mode with JSON progress output.
-
-        Args:
-            total_frames: Total number of frames to process.
-
-        Yields:
-            Dictionaries containing inference results for each frame.
-        """
-        start_time = time()
-        frames_processed = 0
-        last_report = time()
-        done = False
-
-        while not done:
-            imgs, fidxs, vidxs, org_szs, instances, eff_scales, done = (
-                self._process_batch()
-            )
-
-            if imgs:
-                yield from self._run_inference_on_batch(
-                    imgs, fidxs, vidxs, org_szs, instances, eff_scales
-                )
-
-            # Update progress
-            num_frames = len(fidxs)
-            frames_processed += num_frames
-
-            # Emit JSON progress (throttled to ~4Hz)
-            if time() - last_report > 0.25:
-                elapsed = time() - start_time
-                rate = frames_processed / elapsed if elapsed > 0 else 0
-                remaining = total_frames - frames_processed
-                eta = remaining / rate if rate > 0 else 0
-
-                progress_data = {
-                    "n_processed": frames_processed,
-                    "n_total": total_frames,
-                    "rate": round(rate, 1),
-                    "eta": round(eta, 1),
-                }
-                print(json.dumps(progress_data), flush=True)
-                last_report = time()
-
-        # Final progress emit to ensure 100% is shown
-        elapsed = time() - start_time
-        progress_data = {
-            "n_processed": total_frames,
-            "n_total": total_frames,
-            "rate": round(frames_processed / elapsed, 1) if elapsed > 0 else 0,
-            "eta": 0,
-        }
-        print(json.dumps(progress_data), flush=True)
-
-    def _predict_generator_rich(
-        self, total_frames: int
-    ) -> Iterator[Dict[str, np.ndarray]]:
-        """Generator for normal mode with Rich progress bar.
-
-        Args:
-            total_frames: Total number of frames to process.
-
-        Yields:
-            Dictionaries containing inference results for each frame.
-        """
-        with Progress(
-            "{task.description}",
-            BarColumn(),
-            "[progress.percentage]{task.percentage:>3.0f}%",
-            MofNCompleteColumn(),
-            "ETA:",
-            TimeRemainingColumn(),
-            "Elapsed:",
-            TimeElapsedColumn(),
-            RateColumn(),
-            auto_refresh=False,
-            refresh_per_second=4,
-            speed_estimate_period=5,
-        ) as progress:
-            task = progress.add_task("Predicting...", total=total_frames)
-            last_report = time()
-            done = False
-
-            while not done:
-                imgs, fidxs, vidxs, org_szs, instances, eff_scales, done = (
-                    self._process_batch()
-                )
-
-                if imgs:
-                    yield from self._run_inference_on_batch(
-                        imgs, fidxs, vidxs, org_szs, instances, eff_scales
-                    )
-
-                # Advance progress
-                num_frames = len(fidxs)
-                progress.update(task, advance=num_frames)
-
-                # Manually refresh progress bar
-                if time() - last_report > 0.25:
-                    progress.refresh()
-                    last_report = time()
-
-        self.pipeline.join()
-
     def predict(
         self,
         make_labels: bool = True,
@@ -1214,7 +1107,7 @@ class TopDownPredictor(Predictor):
     def make_pipeline(
         self,
         inference_object: Union[str, Path, sio.Labels, sio.Video],
-        queue_maxsize: int = 32,
+        queue_maxsize: int = 8,
         frames: Optional[list] = None,
         only_labeled_frames: bool = False,
         only_suggested_frames: bool = False,
@@ -1228,7 +1121,7 @@ class TopDownPredictor(Predictor):
 
         Args:
             inference_object: (str) Path to `.slp` file or `.mp4` or sio.Labels or sio.Video to run inference on.
-            queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 32.
+            queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 8.
             frames: (list) List of frames indices. If `None`, all frames in the video are used. Default: None.
             only_labeled_frames: (bool) `True` if inference should be run only on user-labeled frames. Default: `False`.
             only_suggested_frames: (bool) `True` if inference should be run only on unlabeled suggested frames. Default: `False`.
@@ -1644,7 +1537,7 @@ class SingleInstancePredictor(Predictor):
     def make_pipeline(
         self,
        inference_object: Union[str, Path, sio.Labels, sio.Video],
-        queue_maxsize: int = 32,
+        queue_maxsize: int = 8,
         frames: Optional[list] = None,
         only_labeled_frames: bool = False,
         only_suggested_frames: bool = False,
@@ -1658,7 +1551,7 @@ class SingleInstancePredictor(Predictor):
 
         Args:
             inference_object: (str) Path to `.slp` file or `.mp4` or sio.Labels or sio.Video to run inference on.
-            queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 32.
+            queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 8.
             frames: List of frames indices. If `None`, all frames in the video are used. Default: None.
             only_labeled_frames: (bool) `True` if inference should be run only on user-labeled frames. Default: `False`.
             only_suggested_frames: (bool) `True` if inference should be run only on unlabeled suggested frames. Default: `False`.
@@ -2094,7 +1987,7 @@ class BottomUpPredictor(Predictor):
     def make_pipeline(
         self,
         inference_object: Union[str, Path, sio.Labels, sio.Video],
-        queue_maxsize: int = 32,
+        queue_maxsize: int = 8,
         frames: Optional[list] = None,
         only_labeled_frames: bool = False,
         only_suggested_frames: bool = False,
@@ -2108,7 +2001,7 @@ class BottomUpPredictor(Predictor):
 
         Args:
             inference_object: (str) Path to `.slp` file or `.mp4` or sio.Labels or sio.Video to run inference on.
-            queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 32.
+            queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 8.
             frames: List of frames indices. If `None`, all frames in the video are used. Default: None.
             only_labeled_frames: (bool) `True` if inference should be run only on user-labeled frames. Default: `False`.
             only_suggested_frames: (bool) `True` if inference should be run only on unlabeled suggested frames. Default: `False`.
@@ -2541,7 +2434,7 @@ class BottomUpMultiClassPredictor(Predictor):
     def make_pipeline(
         self,
         inference_object: Union[str, Path, sio.Labels, sio.Video],
-        queue_maxsize: int = 32,
+        queue_maxsize: int = 8,
         frames: Optional[list] = None,
         only_labeled_frames: bool = False,
         only_suggested_frames: bool = False,
@@ -2555,7 +2448,7 @@ class BottomUpMultiClassPredictor(Predictor):
 
         Args:
             inference_object: (str) Path to `.slp` file or `.mp4` or sio.Labels or sio.Video to run inference on.
-            queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 32.
+            queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 8.
             frames: List of frames indices. If `None`, all frames in the video are used. Default: None.
             only_labeled_frames: (bool) `True` if inference should be run only on user-labeled frames. Default: `False`.
             only_suggested_frames: (bool) `True` if inference should be run only on unlabeled suggested frames. Default: `False`.
@@ -3296,7 +3189,7 @@ class TopDownMultiClassPredictor(Predictor):
     def make_pipeline(
         self,
         inference_object: Union[str, Path, sio.Labels, sio.Video],
-        queue_maxsize: int = 32,
+        queue_maxsize: int = 8,
         frames: Optional[list] = None,
         only_labeled_frames: bool = False,
         only_suggested_frames: bool = False,
@@ -3310,7 +3203,7 @@ class TopDownMultiClassPredictor(Predictor):
 
         Args:
             inference_object: (str) Path to `.slp` file or `.mp4` or sio.Labels or sio.Video to run inference on.
-            queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 32.
+            queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 8.
             frames: (list) List of frames indices. If `None`, all frames in the video are used. Default: None.
             only_labeled_frames: (bool) `True` if inference should be run only on user-labeled frames. Default: `False`.
             only_suggested_frames: (bool) `True` if inference should be run only on unlabeled suggested frames. Default: `False`.
sleap_nn/legacy_models.py CHANGED
@@ -7,8 +7,9 @@ TensorFlow/Keras backend to PyTorch format compatible with sleap-nn.
 import h5py
 import numpy as np
 import torch
-from typing import Dict, Any, Optional
+from typing import Dict, Tuple, Any, Optional, List
 from pathlib import Path
+from omegaconf import OmegaConf
 import re
 from loguru import logger
 
@@ -180,61 +181,18 @@ def parse_keras_layer_name(layer_path: str) -> Dict[str, Any]:
     return info
 
 
-def filter_legacy_weights_by_component(
-    legacy_weights: Dict[str, np.ndarray], component: Optional[str]
-) -> Dict[str, np.ndarray]:
-    """Filter legacy weights based on component type.
-
-    Args:
-        legacy_weights: Dictionary of legacy weights from load_keras_weights()
-        component: Component type to filter for. One of:
-            - "backbone": Keep only encoder/decoder weights (exclude heads)
-            - "head": Keep only head layer weights
-            - None: No filtering (keep all weights)
-
-    Returns:
-        Filtered dictionary of legacy weights
-    """
-    if component is None:
-        return legacy_weights
-
-    filtered = {}
-    for path, weight in legacy_weights.items():
-        # Check if this is a head layer (contains "Head" in the path)
-        is_head_layer = "Head" in path
-
-        if component == "backbone" and not is_head_layer:
-            filtered[path] = weight
-        elif component == "head" and is_head_layer:
-            filtered[path] = weight
-
-    return filtered
-
-
 def map_legacy_to_pytorch_layers(
-    legacy_weights: Dict[str, np.ndarray],
-    pytorch_model: torch.nn.Module,
-    component: Optional[str] = None,
+    legacy_weights: Dict[str, np.ndarray], pytorch_model: torch.nn.Module
 ) -> Dict[str, str]:
     """Create mapping between legacy Keras layers and PyTorch model layers.
 
     Args:
         legacy_weights: Dictionary of legacy weights from load_keras_weights()
         pytorch_model: PyTorch model instance to map to
-        component: Optional component type for filtering weights before mapping.
-            One of "backbone", "head", or None (no filtering).
 
     Returns:
         Dictionary mapping legacy layer paths to PyTorch parameter names
     """
-    # Filter weights based on component type
-    filtered_weights = filter_legacy_weights_by_component(legacy_weights, component)
-
-    if component is not None:
-        logger.info(
-            f"Filtered legacy weights for {component}: "
-            f"{len(filtered_weights)}/{len(legacy_weights)} weights"
-        )
     mapping = {}
 
     # Get all PyTorch parameters with their shapes
@@ -243,7 +201,7 @@ def map_legacy_to_pytorch_layers(
         pytorch_params[name] = param.shape
 
     # For each legacy weight, find the corresponding PyTorch parameter
-    for legacy_path, weight in filtered_weights.items():
+    for legacy_path, weight in legacy_weights.items():
         # Extract the layer name from the legacy path
         # Legacy path format: "model_weights/stack0_enc0_conv0/stack0_enc0_conv0/kernel:0"
         clean_path = legacy_path.replace("model_weights/", "")
@@ -262,6 +220,8 @@ def map_legacy_to_pytorch_layers(
         # This handles cases where Keras uses suffixes like _0, _1, etc.
         if "Head" in layer_name:
             # Remove trailing _N where N is a number
+            import re
+
             layer_name_clean = re.sub(r"_\d+$", "", layer_name)
         else:
             layer_name_clean = layer_name
@@ -306,17 +266,12 @@ def map_legacy_to_pytorch_layers(
     if not mapping:
         logger.info(
             f"No mappings could be created between legacy weights and PyTorch model. "
-            f"Legacy weights: {len(filtered_weights)}, PyTorch parameters: {len(pytorch_params)}"
+            f"Legacy weights: {len(legacy_weights)}, PyTorch parameters: {len(pytorch_params)}"
         )
     else:
         logger.info(
-            f"Successfully mapped {len(mapping)}/{len(pytorch_params)} PyTorch parameters from legacy weights"
+            f"Successfully mapped {len(mapping)}/{len(legacy_weights)} legacy weights to PyTorch parameters"
         )
-        unmatched_count = len(filtered_weights) - len(mapping)
-        if unmatched_count > 0:
-            logger.warning(
-                f"({unmatched_count} legacy weights did not match any parameters in this model component)"
-            )
 
     return mapping
 
@@ -325,7 +280,6 @@ def load_legacy_model_weights(
     pytorch_model: torch.nn.Module,
     h5_path: str,
     mapping: Optional[Dict[str, str]] = None,
-    component: Optional[str] = None,
 ) -> None:
     """Load legacy Keras weights into a PyTorch model.
 
@@ -334,10 +288,6 @@
         h5_path: Path to the legacy .h5 model file
         mapping: Optional manual mapping of layer names. If None,
             will attempt automatic mapping.
-        component: Optional component type for filtering weights. One of:
-            - "backbone": Only load encoder/decoder weights (exclude heads)
-            - "head": Only load head layer weights
-            - None: Load all weights (default, for full model loading)
     """
     # Load legacy weights
     legacy_weights = load_keras_weights(h5_path)
@@ -345,9 +295,7 @@
     if mapping is None:
         # Attempt automatic mapping
         try:
-            mapping = map_legacy_to_pytorch_layers(
-                legacy_weights, pytorch_model, component=component
-            )
+            mapping = map_legacy_to_pytorch_layers(legacy_weights, pytorch_model)
         except Exception as e:
             logger.error(f"Failed to create weight mappings: {e}")
             return
@@ -469,9 +417,7 @@ def load_legacy_model_weights(
                 ).item()
                 diff = abs(keras_mean - torch_mean)
                 if diff > 1e-6:
-                    message = f"Weight verification failed for {pytorch_name} (linear): keras={keras_mean:.6f}, torch={torch_mean:.6f}, diff={diff:.6e}"
-                    logger.error(message)
-                    verification_errors.append(message)
+                    message = f"Weight verification failed for {pytorch_name} linear): keras={keras_mean:.6f}, torch={torch_mean:.6f}, diff={diff:.6e}"
             else:
                 # Bias : just compare all values
                 keras_mean = np.mean(original_weight)
@@ -480,7 +426,7 @@ def load_legacy_model_weights(
                 ).item()
                 diff = abs(keras_mean - torch_mean)
                 if diff > 1e-6:
-                    message = f"Weight verification failed for {pytorch_name} (bias): keras={keras_mean:.6f}, torch={torch_mean:.6f}, diff={diff:.6e}"
+                    message = f"Weight verification failed for {pytorch_name} bias): keras={keras_mean:.6f}, torch={torch_mean:.6f}, diff={diff:.6e}"
                     logger.error(message)
                     verification_errors.append(message)
 