matrice-analytics 0.1.60__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196) hide show
  1. matrice_analytics/__init__.py +28 -0
  2. matrice_analytics/boundary_drawing_internal/README.md +305 -0
  3. matrice_analytics/boundary_drawing_internal/__init__.py +45 -0
  4. matrice_analytics/boundary_drawing_internal/boundary_drawing_internal.py +1207 -0
  5. matrice_analytics/boundary_drawing_internal/boundary_drawing_tool.py +429 -0
  6. matrice_analytics/boundary_drawing_internal/boundary_tool_template.html +1036 -0
  7. matrice_analytics/boundary_drawing_internal/data/.gitignore +12 -0
  8. matrice_analytics/boundary_drawing_internal/example_usage.py +206 -0
  9. matrice_analytics/boundary_drawing_internal/usage/README.md +110 -0
  10. matrice_analytics/boundary_drawing_internal/usage/boundary_drawer_launcher.py +102 -0
  11. matrice_analytics/boundary_drawing_internal/usage/simple_boundary_launcher.py +107 -0
  12. matrice_analytics/post_processing/README.md +455 -0
  13. matrice_analytics/post_processing/__init__.py +732 -0
  14. matrice_analytics/post_processing/advanced_tracker/README.md +650 -0
  15. matrice_analytics/post_processing/advanced_tracker/__init__.py +17 -0
  16. matrice_analytics/post_processing/advanced_tracker/base.py +99 -0
  17. matrice_analytics/post_processing/advanced_tracker/config.py +77 -0
  18. matrice_analytics/post_processing/advanced_tracker/kalman_filter.py +370 -0
  19. matrice_analytics/post_processing/advanced_tracker/matching.py +195 -0
  20. matrice_analytics/post_processing/advanced_tracker/strack.py +230 -0
  21. matrice_analytics/post_processing/advanced_tracker/tracker.py +367 -0
  22. matrice_analytics/post_processing/config.py +146 -0
  23. matrice_analytics/post_processing/core/__init__.py +63 -0
  24. matrice_analytics/post_processing/core/base.py +704 -0
  25. matrice_analytics/post_processing/core/config.py +3291 -0
  26. matrice_analytics/post_processing/core/config_utils.py +925 -0
  27. matrice_analytics/post_processing/face_reg/__init__.py +43 -0
  28. matrice_analytics/post_processing/face_reg/compare_similarity.py +556 -0
  29. matrice_analytics/post_processing/face_reg/embedding_manager.py +950 -0
  30. matrice_analytics/post_processing/face_reg/face_recognition.py +2234 -0
  31. matrice_analytics/post_processing/face_reg/face_recognition_client.py +606 -0
  32. matrice_analytics/post_processing/face_reg/people_activity_logging.py +321 -0
  33. matrice_analytics/post_processing/ocr/__init__.py +0 -0
  34. matrice_analytics/post_processing/ocr/easyocr_extractor.py +250 -0
  35. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/__init__.py +9 -0
  36. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/__init__.py +4 -0
  37. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/cli.py +33 -0
  38. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/dataset_stats.py +139 -0
  39. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/export.py +398 -0
  40. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/train.py +447 -0
  41. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/utils.py +129 -0
  42. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/valid.py +93 -0
  43. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/validate_dataset.py +240 -0
  44. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/visualize_augmentation.py +176 -0
  45. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/visualize_predictions.py +96 -0
  46. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/__init__.py +3 -0
  47. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/process.py +246 -0
  48. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/types.py +60 -0
  49. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/utils.py +87 -0
  50. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/__init__.py +3 -0
  51. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/config.py +82 -0
  52. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/hub.py +141 -0
  53. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/plate_recognizer.py +323 -0
  54. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/py.typed +0 -0
  55. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/__init__.py +0 -0
  56. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/__init__.py +0 -0
  57. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/augmentation.py +101 -0
  58. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/dataset.py +97 -0
  59. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/__init__.py +0 -0
  60. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/config.py +114 -0
  61. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/layers.py +553 -0
  62. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/loss.py +55 -0
  63. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/metric.py +86 -0
  64. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/model_builders.py +95 -0
  65. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/model_schema.py +395 -0
  66. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/__init__.py +0 -0
  67. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/backend_utils.py +38 -0
  68. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/utils.py +214 -0
  69. matrice_analytics/post_processing/ocr/postprocessing.py +270 -0
  70. matrice_analytics/post_processing/ocr/preprocessing.py +52 -0
  71. matrice_analytics/post_processing/post_processor.py +1175 -0
  72. matrice_analytics/post_processing/test_cases/__init__.py +1 -0
  73. matrice_analytics/post_processing/test_cases/run_tests.py +143 -0
  74. matrice_analytics/post_processing/test_cases/test_advanced_customer_service.py +841 -0
  75. matrice_analytics/post_processing/test_cases/test_basic_counting_tracking.py +523 -0
  76. matrice_analytics/post_processing/test_cases/test_comprehensive.py +531 -0
  77. matrice_analytics/post_processing/test_cases/test_config.py +852 -0
  78. matrice_analytics/post_processing/test_cases/test_customer_service.py +585 -0
  79. matrice_analytics/post_processing/test_cases/test_data_generators.py +583 -0
  80. matrice_analytics/post_processing/test_cases/test_people_counting.py +510 -0
  81. matrice_analytics/post_processing/test_cases/test_processor.py +524 -0
  82. matrice_analytics/post_processing/test_cases/test_usecases.py +165 -0
  83. matrice_analytics/post_processing/test_cases/test_utilities.py +356 -0
  84. matrice_analytics/post_processing/test_cases/test_utils.py +743 -0
  85. matrice_analytics/post_processing/usecases/Histopathological_Cancer_Detection_img.py +604 -0
  86. matrice_analytics/post_processing/usecases/__init__.py +267 -0
  87. matrice_analytics/post_processing/usecases/abandoned_object_detection.py +797 -0
  88. matrice_analytics/post_processing/usecases/advanced_customer_service.py +1601 -0
  89. matrice_analytics/post_processing/usecases/age_detection.py +842 -0
  90. matrice_analytics/post_processing/usecases/age_gender_detection.py +1085 -0
  91. matrice_analytics/post_processing/usecases/anti_spoofing_detection.py +656 -0
  92. matrice_analytics/post_processing/usecases/assembly_line_detection.py +841 -0
  93. matrice_analytics/post_processing/usecases/banana_defect_detection.py +624 -0
  94. matrice_analytics/post_processing/usecases/basic_counting_tracking.py +667 -0
  95. matrice_analytics/post_processing/usecases/blood_cancer_detection_img.py +881 -0
  96. matrice_analytics/post_processing/usecases/car_damage_detection.py +834 -0
  97. matrice_analytics/post_processing/usecases/car_part_segmentation.py +946 -0
  98. matrice_analytics/post_processing/usecases/car_service.py +1601 -0
  99. matrice_analytics/post_processing/usecases/cardiomegaly_classification.py +864 -0
  100. matrice_analytics/post_processing/usecases/cell_microscopy_segmentation.py +897 -0
  101. matrice_analytics/post_processing/usecases/chicken_pose_detection.py +648 -0
  102. matrice_analytics/post_processing/usecases/child_monitoring.py +814 -0
  103. matrice_analytics/post_processing/usecases/color/clip.py +660 -0
  104. matrice_analytics/post_processing/usecases/color/clip_processor/merges.txt +48895 -0
  105. matrice_analytics/post_processing/usecases/color/clip_processor/preprocessor_config.json +28 -0
  106. matrice_analytics/post_processing/usecases/color/clip_processor/special_tokens_map.json +30 -0
  107. matrice_analytics/post_processing/usecases/color/clip_processor/tokenizer.json +245079 -0
  108. matrice_analytics/post_processing/usecases/color/clip_processor/tokenizer_config.json +32 -0
  109. matrice_analytics/post_processing/usecases/color/clip_processor/vocab.json +1 -0
  110. matrice_analytics/post_processing/usecases/color/color_map_utils.py +70 -0
  111. matrice_analytics/post_processing/usecases/color/color_mapper.py +468 -0
  112. matrice_analytics/post_processing/usecases/color_detection.py +1936 -0
  113. matrice_analytics/post_processing/usecases/color_map_utils.py +70 -0
  114. matrice_analytics/post_processing/usecases/concrete_crack_detection.py +827 -0
  115. matrice_analytics/post_processing/usecases/crop_weed_detection.py +781 -0
  116. matrice_analytics/post_processing/usecases/customer_service.py +1008 -0
  117. matrice_analytics/post_processing/usecases/defect_detection_products.py +936 -0
  118. matrice_analytics/post_processing/usecases/distracted_driver_detection.py +822 -0
  119. matrice_analytics/post_processing/usecases/drone_traffic_monitoring.py +585 -0
  120. matrice_analytics/post_processing/usecases/drowsy_driver_detection.py +829 -0
  121. matrice_analytics/post_processing/usecases/dwell_detection.py +829 -0
  122. matrice_analytics/post_processing/usecases/emergency_vehicle_detection.py +827 -0
  123. matrice_analytics/post_processing/usecases/face_emotion.py +813 -0
  124. matrice_analytics/post_processing/usecases/face_recognition.py +827 -0
  125. matrice_analytics/post_processing/usecases/fashion_detection.py +835 -0
  126. matrice_analytics/post_processing/usecases/field_mapping.py +902 -0
  127. matrice_analytics/post_processing/usecases/fire_detection.py +1146 -0
  128. matrice_analytics/post_processing/usecases/flare_analysis.py +836 -0
  129. matrice_analytics/post_processing/usecases/flower_segmentation.py +1006 -0
  130. matrice_analytics/post_processing/usecases/gas_leak_detection.py +837 -0
  131. matrice_analytics/post_processing/usecases/gender_detection.py +832 -0
  132. matrice_analytics/post_processing/usecases/human_activity_recognition.py +871 -0
  133. matrice_analytics/post_processing/usecases/intrusion_detection.py +1672 -0
  134. matrice_analytics/post_processing/usecases/leaf.py +821 -0
  135. matrice_analytics/post_processing/usecases/leaf_disease.py +840 -0
  136. matrice_analytics/post_processing/usecases/leak_detection.py +837 -0
  137. matrice_analytics/post_processing/usecases/license_plate_detection.py +1188 -0
  138. matrice_analytics/post_processing/usecases/license_plate_monitoring.py +1781 -0
  139. matrice_analytics/post_processing/usecases/litter_monitoring.py +717 -0
  140. matrice_analytics/post_processing/usecases/mask_detection.py +869 -0
  141. matrice_analytics/post_processing/usecases/natural_disaster.py +907 -0
  142. matrice_analytics/post_processing/usecases/parking.py +787 -0
  143. matrice_analytics/post_processing/usecases/parking_space_detection.py +822 -0
  144. matrice_analytics/post_processing/usecases/pcb_defect_detection.py +888 -0
  145. matrice_analytics/post_processing/usecases/pedestrian_detection.py +808 -0
  146. matrice_analytics/post_processing/usecases/people_counting.py +706 -0
  147. matrice_analytics/post_processing/usecases/people_counting_bckp.py +1683 -0
  148. matrice_analytics/post_processing/usecases/people_tracking.py +1842 -0
  149. matrice_analytics/post_processing/usecases/pipeline_detection.py +605 -0
  150. matrice_analytics/post_processing/usecases/plaque_segmentation_img.py +874 -0
  151. matrice_analytics/post_processing/usecases/pothole_segmentation.py +915 -0
  152. matrice_analytics/post_processing/usecases/ppe_compliance.py +645 -0
  153. matrice_analytics/post_processing/usecases/price_tag_detection.py +822 -0
  154. matrice_analytics/post_processing/usecases/proximity_detection.py +1901 -0
  155. matrice_analytics/post_processing/usecases/road_lane_detection.py +623 -0
  156. matrice_analytics/post_processing/usecases/road_traffic_density.py +832 -0
  157. matrice_analytics/post_processing/usecases/road_view_segmentation.py +915 -0
  158. matrice_analytics/post_processing/usecases/shelf_inventory_detection.py +583 -0
  159. matrice_analytics/post_processing/usecases/shoplifting_detection.py +822 -0
  160. matrice_analytics/post_processing/usecases/shopping_cart_analysis.py +899 -0
  161. matrice_analytics/post_processing/usecases/skin_cancer_classification_img.py +864 -0
  162. matrice_analytics/post_processing/usecases/smoker_detection.py +833 -0
  163. matrice_analytics/post_processing/usecases/solar_panel.py +810 -0
  164. matrice_analytics/post_processing/usecases/suspicious_activity_detection.py +1030 -0
  165. matrice_analytics/post_processing/usecases/template_usecase.py +380 -0
  166. matrice_analytics/post_processing/usecases/theft_detection.py +648 -0
  167. matrice_analytics/post_processing/usecases/traffic_sign_monitoring.py +724 -0
  168. matrice_analytics/post_processing/usecases/underground_pipeline_defect_detection.py +775 -0
  169. matrice_analytics/post_processing/usecases/underwater_pollution_detection.py +842 -0
  170. matrice_analytics/post_processing/usecases/vehicle_monitoring.py +1029 -0
  171. matrice_analytics/post_processing/usecases/warehouse_object_segmentation.py +899 -0
  172. matrice_analytics/post_processing/usecases/waterbody_segmentation.py +923 -0
  173. matrice_analytics/post_processing/usecases/weapon_detection.py +771 -0
  174. matrice_analytics/post_processing/usecases/weld_defect_detection.py +615 -0
  175. matrice_analytics/post_processing/usecases/wildlife_monitoring.py +898 -0
  176. matrice_analytics/post_processing/usecases/windmill_maintenance.py +834 -0
  177. matrice_analytics/post_processing/usecases/wound_segmentation.py +856 -0
  178. matrice_analytics/post_processing/utils/__init__.py +150 -0
  179. matrice_analytics/post_processing/utils/advanced_counting_utils.py +400 -0
  180. matrice_analytics/post_processing/utils/advanced_helper_utils.py +317 -0
  181. matrice_analytics/post_processing/utils/advanced_tracking_utils.py +461 -0
  182. matrice_analytics/post_processing/utils/alerting_utils.py +213 -0
  183. matrice_analytics/post_processing/utils/category_mapping_utils.py +94 -0
  184. matrice_analytics/post_processing/utils/color_utils.py +592 -0
  185. matrice_analytics/post_processing/utils/counting_utils.py +182 -0
  186. matrice_analytics/post_processing/utils/filter_utils.py +261 -0
  187. matrice_analytics/post_processing/utils/format_utils.py +293 -0
  188. matrice_analytics/post_processing/utils/geometry_utils.py +300 -0
  189. matrice_analytics/post_processing/utils/smoothing_utils.py +358 -0
  190. matrice_analytics/post_processing/utils/tracking_utils.py +234 -0
  191. matrice_analytics/py.typed +0 -0
  192. matrice_analytics-0.1.60.dist-info/METADATA +481 -0
  193. matrice_analytics-0.1.60.dist-info/RECORD +196 -0
  194. matrice_analytics-0.1.60.dist-info/WHEEL +5 -0
  195. matrice_analytics-0.1.60.dist-info/licenses/LICENSE.txt +21 -0
  196. matrice_analytics-0.1.60.dist-info/top_level.txt +1 -0
@@ -0,0 +1,95 @@
1
+ """
2
+ Model builder functions for supported architectures.
3
+ """
4
+
5
+ from collections.abc import Sequence
6
+
7
+ import keras
8
+ import numpy as np
9
+ from keras import layers
10
+
11
+ from fast_plate_ocr.train.model.config import PlateOCRConfig
12
+ from fast_plate_ocr.train.model.layers import (
13
+ PatchExtractor,
14
+ PositionEmbedding,
15
+ TokenReducer,
16
+ TransformerBlock,
17
+ VocabularyProjection,
18
+ )
19
+ from fast_plate_ocr.train.model.model_schema import AnyModelConfig, CCTModelConfig, LayerConfig
20
+
21
+
22
def _build_stem_from_config(specs: Sequence[LayerConfig]) -> keras.Sequential:
    """Assemble the convolutional stem as a ``keras.Sequential`` named ``conv_stem``.

    Every entry of ``specs`` is converted, in order, through its
    ``to_keras_layer()`` method.
    """
    stem_layers = []
    for layer_spec in specs:
        stem_layers.append(layer_spec.to_keras_layer())
    return keras.Sequential(stem_layers, name="conv_stem")
24
+
25
+
26
def _build_cct_model(
    cfg: CCTModelConfig,
    input_shape: tuple[int, int, int],
    max_plate_slots: int,
    vocabulary_size: int,
) -> keras.Model:
    """Wire up the ``CCT_OCR`` Keras model described by ``cfg``.

    Args:
        cfg: Validated CCT configuration (rescaling, tokenizer, transformer encoder).
        input_shape: Shape handed to ``layers.Input``.
        max_plate_slots: Number of tokens the ``TokenReducer`` reduces to.
        vocabulary_size: Size passed to the ``VocabularyProjection`` head.

    Returns:
        A ``keras.Model`` mapping the input tensor to the projection logits.
    """
    tokenizer_cfg = cfg.tokenizer
    encoder_cfg = cfg.transformer_encoder

    # Input -> rescaling -> convolutional stem.
    inputs = layers.Input(shape=input_shape)
    rescale_layer = cfg.rescaling.to_keras_layer()
    stem = _build_stem_from_config(tokenizer_cfg.blocks)
    x = stem(rescale_layer(inputs))

    # Patch extraction: (B, H, W, C) -> (B, num_patches, C*patch_size**2)
    patch_extractor = PatchExtractor(patch_size=tokenizer_cfg.patch_size)
    x = patch_extractor(x)

    # Optional MLP applied to the extracted patches.
    if tokenizer_cfg.patch_mlp is not None:
        patch_mlp = tokenizer_cfg.patch_mlp.to_keras_layer()
        x = patch_mlp(x)

    # Optional positional embeddings, added to the token sequence.
    if tokenizer_cfg.positional_emb:
        seq_len = keras.ops.shape(x)[1]
        positional = PositionEmbedding(sequence_length=seq_len, name="pos_emb")(x)
        x = x + positional

    # Stack of transformer blocks; the drop-path rate grows linearly from 0.0
    # up to ``stochastic_depth`` across the blocks.
    drop_path_rates = list(
        np.linspace(0.0, encoder_cfg.stochastic_depth, encoder_cfg.layers)
    )
    for block_index, drop_rate in enumerate(drop_path_rates, start=1):
        block = TransformerBlock(
            projection_dim=encoder_cfg.projection_dim,
            num_heads=encoder_cfg.heads,
            mlp_units=encoder_cfg.units,
            attention_dropout=encoder_cfg.attention_dropout,
            mlp_dropout=encoder_cfg.mlp_dropout,
            drop_path_rate=drop_rate,
            norm_type=encoder_cfg.normalization,
            activation=encoder_cfg.activation,
            name=f"transformer_block_{block_index}",
        )
        x = block(x)

    # Reduce to a fixed number of tokens, then project each token to the vocabulary.
    token_reducer = TokenReducer(
        num_tokens=max_plate_slots,
        projection_dim=encoder_cfg.projection_dim,
        num_heads=encoder_cfg.token_reducer_heads,
    )
    x = token_reducer(x)

    vocab_head = VocabularyProjection(
        vocabulary_size=vocabulary_size,
        dropout_rate=encoder_cfg.head_mlp_dropout,
        name="vocab_projection",
    )
    logits = vocab_head(x)

    return keras.Model(inputs, logits, name="CCT_OCR")
82
+
83
+
84
def build_model(model_cfg: AnyModelConfig, plate_cfg: PlateOCRConfig) -> keras.Model:
    """
    Construct the Keras OCR model selected by ``model_cfg.model``.

    The input shape, number of plate slots and vocabulary size are read from
    ``plate_cfg``. Only the ``"cct"`` architecture is handled; any other value
    raises ``ValueError``.
    """
    # Guard clause: reject anything that is not the single supported architecture.
    if model_cfg.model != "cct":
        raise ValueError(f"Unsupported model type: {model_cfg.model!r}")

    image_shape = (plate_cfg.img_height, plate_cfg.img_width, plate_cfg.num_channels)
    return _build_cct_model(
        cfg=model_cfg,
        input_shape=image_shape,
        max_plate_slots=plate_cfg.max_plate_slots,
        vocabulary_size=plate_cfg.vocabulary_size,
    )
@@ -0,0 +1,395 @@
1
+ """
2
+ Schema definitions for validating supported model architectures and layer blocks.
3
+ """
4
+
5
+ from pathlib import Path
6
+ from typing import Annotated, Literal, Optional, TypeAlias, Union
7
+
8
+ import keras
9
+ import yaml
10
+ from keras.src.layers import RMSNormalization
11
+ from pydantic import BaseModel, Field, PositiveFloat, PositiveInt, model_validator
12
+
13
+ from fast_plate_ocr.core.types import PathLike
14
+ from fast_plate_ocr.train.model.layers import (
15
+ MLP,
16
+ CoordConv2D,
17
+ DyT,
18
+ MaxBlurPooling2D,
19
+ SqueezeExcite,
20
+ )
21
+
22
# Shared field types used by the layer schemas below.
UnitFloat: TypeAlias = Annotated[float, Field(ge=0.0, le=1.0)]
"""A float that must be in range of [0, 1]."""
PaddingTypeStr: TypeAlias = Literal["valid", "same"]
"""Padding modes supported by Keras convolution and pooling layers."""
# Plain (non-discriminated) union: a discriminator only applies to unions whose
# members are models/dataclasses carrying the discriminator field, which an
# ``int`` and a ``tuple`` are not.
PositiveIntTuple: TypeAlias = Union[PositiveInt, tuple[PositiveInt, PositiveInt]]
"""A single positive integer or a tuple of two positive integers, usually used for sizes/strides."""
NormalizationStr: TypeAlias = Literal["layer_norm", "rms_norm", "dyt"]
"""Available normalization layers."""
30
+
31
# Closed set of activation names accepted by the ``activation`` fields of the
# layer schemas in this module.
ActivationStr: TypeAlias = Literal[
    "celu",
    "elu",
    "exponential",
    "gelu",
    "glu",
    "hard_shrink",
    "hard_sigmoid",
    "hard_silu",
    "hard_tanh",
    "leaky_relu",
    "linear",
    "log_sigmoid",
    "log_softmax",
    "mish",
    "relu",
    "relu6",
    "selu",
    "sigmoid",
    "silu",
    "soft_shrink",
    "softmax",
    "softplus",
    "softsign",
    "sparse_plus",
    "sparsemax",
    "squareplus",
    "tanh",
    "tanh_shrink",
    "threshold",
]
"""Supported Keras activation functions."""
63
+
64
# Closed set of initializer names accepted by the ``*_initializer`` fields of
# the convolution schemas in this module.
WeightInitializationStr: TypeAlias = Literal[
    "glorot_normal",
    "glorot_uniform",
    "he_normal",
    "he_uniform",
    "lecun_normal",
    "lecun_uniform",
    "ones",
    "random_normal",
    "random_uniform",
    "truncated_normal",
    "variance_scaling",
    "zeros",
]
"""Keras weight initialization strategies."""
79
+
80
+
81
class _Rescaling(BaseModel):
    """Schema for the input rescaling step, built as ``keras.layers.Rescaling``."""

    scale: float = 1.0 / 255
    offset: float = 0.0

    def to_keras_layer(self):
        """Create the ``Rescaling`` layer from the configured scale and offset."""
        return keras.layers.Rescaling(scale=self.scale, offset=self.offset)
87
+
88
+
89
class _Activation(BaseModel):
    """Schema for a standalone activation layer (``layer: "Activation"``)."""

    layer: Literal["Activation"]
    activation: ActivationStr

    def to_keras_layer(self) -> keras.layers.Layer:
        """Create a ``keras.layers.Activation`` for the configured activation name."""
        activation_name = self.activation
        return keras.layers.Activation(activation_name)
95
+
96
+
97
# Fields shared by the convolution schemas that subclass this model
# (``_Conv2D`` and ``_CoordConv2D``); those subclasses dump these fields and
# forward them as keyword arguments to the layer constructor.
class _Conv2DBase(BaseModel):
    filters: PositiveInt
    # Single int or a pair of ints.
    kernel_size: PositiveIntTuple
    strides: PositiveIntTuple = 1
    padding: PaddingTypeStr = "same"
    activation: ActivationStr = "relu"
    use_bias: bool = True
    kernel_initializer: WeightInitializationStr = "he_normal"
    bias_initializer: WeightInitializationStr = "zeros"
106
+
107
+
108
class _Conv2D(_Conv2DBase):
    """Schema for a ``keras.layers.Conv2D`` layer (``layer: "Conv2D"``)."""

    layer: Literal["Conv2D"]

    def to_keras_layer(self) -> keras.layers.Layer:
        """Create the ``Conv2D`` layer from every field except the ``layer`` tag."""
        conv_kwargs = self.model_dump()
        # ``layer`` only selects the schema; it is not a constructor argument.
        conv_kwargs.pop("layer")
        return keras.layers.Conv2D(**conv_kwargs)
114
+
115
+
116
class _CoordConv2D(_Conv2DBase):
    """Schema for the project's ``CoordConv2D`` layer (``layer: "CoordConv2D"``)."""

    layer: Literal["CoordConv2D"]
    with_r: bool = False

    def to_keras_layer(self) -> keras.layers.Layer:
        """Create the ``CoordConv2D`` layer, passing ``with_r`` plus the shared conv fields."""
        conv_kwargs = self.model_dump(exclude={"layer"})
        use_radius = conv_kwargs.pop("with_r")
        return CoordConv2D(with_r=use_radius, **conv_kwargs)
123
+
124
+
125
class _DepthwiseConv2D(BaseModel):
    """Schema for a ``keras.layers.DepthwiseConv2D`` layer."""

    layer: Literal["DepthwiseConv2D"]
    kernel_size: PositiveIntTuple
    strides: PositiveIntTuple = 1
    padding: PaddingTypeStr = "same"
    depth_multiplier: PositiveInt = 1
    activation: ActivationStr = "relu"
    use_bias: bool = True
    depthwise_initializer: WeightInitializationStr = "he_normal"
    bias_initializer: WeightInitializationStr = "zeros"

    def to_keras_layer(self) -> keras.layers.Layer:
        """Create the ``DepthwiseConv2D`` layer from every field except the ``layer`` tag."""
        # Every remaining field name matches a constructor keyword one-to-one.
        layer_kwargs = self.model_dump(exclude={"layer"})
        return keras.layers.DepthwiseConv2D(**layer_kwargs)
147
+
148
+
149
class _SeparableConv2D(BaseModel):
    """Schema for a ``keras.layers.SeparableConv2D`` layer."""

    layer: Literal["SeparableConv2D"]
    filters: PositiveInt
    kernel_size: PositiveIntTuple
    strides: PositiveIntTuple = 1
    padding: PaddingTypeStr = "same"
    depth_multiplier: PositiveInt = 1
    activation: ActivationStr = "relu"
    use_bias: bool = True
    depthwise_initializer: WeightInitializationStr = "he_normal"
    pointwise_initializer: WeightInitializationStr = "glorot_uniform"
    bias_initializer: WeightInitializationStr = "zeros"

    def to_keras_layer(self) -> keras.layers.Layer:
        """Create the ``SeparableConv2D`` layer from every field except the ``layer`` tag."""
        # Every remaining field name matches a constructor keyword one-to-one.
        layer_kwargs = self.model_dump(exclude={"layer"})
        return keras.layers.SeparableConv2D(**layer_kwargs)
175
+
176
+
177
class _MLP(BaseModel):
    """Schema for the project's ``MLP`` layer (``layer: "MLP"``)."""

    layer: Literal["MLP"]
    hidden_units: list[PositiveInt]
    dropout_rate: UnitFloat = 0.1
    activation: ActivationStr = "gelu"
    use_bias: bool = True

    def to_keras_layer(self) -> keras.layers.Layer:
        """Create the ``MLP`` layer from the configured units, dropout, activation and bias."""
        mlp_kwargs = {
            "hidden_units": self.hidden_units,
            "dropout_rate": self.dropout_rate,
            "activation": self.activation,
            "use_bias": self.use_bias,
        }
        return MLP(**mlp_kwargs)
191
+
192
+
193
class _MaxBlurPooling2D(BaseModel):
    """Schema for the project's ``MaxBlurPooling2D`` layer."""

    layer: Literal["MaxBlurPooling2D"]
    pool_size: PositiveInt = 2
    filter_size: PositiveInt = 3
    padding: PaddingTypeStr = "same"

    def to_keras_layer(self) -> keras.layers.Layer:
        """Create the ``MaxBlurPooling2D`` layer from pool size, filter size and padding."""
        pooling_kwargs = {
            "pool_size": self.pool_size,
            "filter_size": self.filter_size,
            "padding": self.padding,
        }
        return MaxBlurPooling2D(**pooling_kwargs)
203
+
204
+
205
class _MaxPooling2D(BaseModel):
    """Schema for a ``keras.layers.MaxPooling2D`` layer."""

    layer: Literal["MaxPooling2D"]
    pool_size: PositiveIntTuple = 2
    # Accepts a single int or a pair, mirroring ``pool_size``; ``None`` is
    # forwarded unchanged to Keras.
    strides: Optional[PositiveIntTuple] = None
    padding: PaddingTypeStr = "valid"

    def to_keras_layer(self) -> keras.layers.Layer:
        """Create the ``MaxPooling2D`` layer from pool size, strides and padding."""
        return keras.layers.MaxPooling2D(
            pool_size=self.pool_size,
            strides=self.strides,
            padding=self.padding,
        )
217
+
218
+
219
class _AveragePooling2D(BaseModel):
    """Schema for a ``keras.layers.AveragePooling2D`` layer."""

    layer: Literal["AveragePooling2D"]
    pool_size: PositiveIntTuple = 2
    # Accepts a single int or a pair, mirroring ``pool_size``; ``None`` is
    # forwarded unchanged to Keras.
    strides: Optional[PositiveIntTuple] = None
    padding: PaddingTypeStr = "valid"

    def to_keras_layer(self) -> keras.layers.Layer:
        """Create the ``AveragePooling2D`` layer from pool size, strides and padding."""
        return keras.layers.AveragePooling2D(
            pool_size=self.pool_size,
            strides=self.strides,
            padding=self.padding,
        )
231
+
232
+
233
class _ZeroPadding2D(BaseModel):
    """Schema for a ``keras.layers.ZeroPadding2D`` layer."""

    layer: Literal["ZeroPadding2D"]
    padding: PositiveIntTuple = 1

    def to_keras_layer(self) -> keras.layers.Layer:
        """Create the ``ZeroPadding2D`` layer with the configured padding."""
        pad_amount = self.padding
        return keras.layers.ZeroPadding2D(padding=pad_amount)
239
+
240
+
241
class _SqueezeExcite(BaseModel):
    """Schema for the project's ``SqueezeExcite`` layer."""

    layer: Literal["SqueezeExcite"]
    ratio: PositiveFloat = 1.0

    def to_keras_layer(self) -> keras.layers.Layer:
        """Create the ``SqueezeExcite`` layer with the configured ratio."""
        se_ratio = self.ratio
        return SqueezeExcite(ratio=se_ratio)
247
+
248
+
249
class _BatchNormalization(BaseModel):
    """Schema for a ``keras.layers.BatchNormalization`` layer."""

    layer: Literal["BatchNormalization"]
    momentum: PositiveFloat = 0.99
    epsilon: PositiveFloat = 1e-3
    center: bool = True
    scale: bool = True

    def to_keras_layer(self) -> keras.layers.Layer:
        """Create the ``BatchNormalization`` layer from every field except the ``layer`` tag."""
        # Every remaining field name matches a constructor keyword one-to-one.
        norm_kwargs = self.model_dump(exclude={"layer"})
        return keras.layers.BatchNormalization(**norm_kwargs)
263
+
264
+
265
class _Dropout(BaseModel):
    """Schema for a ``keras.layers.Dropout`` layer."""

    layer: Literal["Dropout"]
    # Constrained to [0, 1] so that an out-of-range rate is reported as a
    # config validation error rather than surfacing later at layer construction.
    rate: UnitFloat

    def to_keras_layer(self) -> keras.layers.Layer:
        """Create the ``Dropout`` layer with the configured rate."""
        return keras.layers.Dropout(rate=self.rate)
271
+
272
+
273
class _SpatialDropout2D(BaseModel):
    """Schema for a ``keras.layers.SpatialDropout2D`` layer."""

    layer: Literal["SpatialDropout2D"]
    # Constrained to [0, 1] so that an out-of-range rate is reported as a
    # config validation error rather than surfacing later at layer construction.
    rate: UnitFloat

    def to_keras_layer(self) -> keras.layers.Layer:
        """Create the ``SpatialDropout2D`` layer with the configured rate."""
        return keras.layers.SpatialDropout2D(rate=self.rate)
279
+
280
+
281
class _GaussianNoise(BaseModel):
    """Schema for a ``keras.layers.GaussianNoise`` layer."""

    layer: Literal["GaussianNoise"]
    stddev: PositiveFloat
    seed: Optional[int] = None

    def to_keras_layer(self) -> keras.layers.Layer:
        """Create the ``GaussianNoise`` layer from the configured stddev and seed."""
        noise_kwargs = {"stddev": self.stddev, "seed": self.seed}
        return keras.layers.GaussianNoise(**noise_kwargs)
288
+
289
+
290
class _LayerNorm(BaseModel):
    """Schema for a ``keras.layers.LayerNormalization`` layer (``layer: "LayerNorm"``)."""

    layer: Literal["LayerNorm"]
    epsilon: PositiveFloat = 1e-3

    def to_keras_layer(self) -> keras.layers.Layer:
        """Create the ``LayerNormalization`` layer with the configured epsilon."""
        eps = self.epsilon
        return keras.layers.LayerNormalization(epsilon=eps)
296
+
297
+
298
class _RMSNorm(BaseModel):
    """Schema for an ``RMSNormalization`` layer (``layer: "RMSNorm"``)."""

    layer: Literal["RMSNorm"]
    epsilon: PositiveFloat = 1e-6

    def to_keras_layer(self) -> keras.layers.Layer:
        """Create the ``RMSNormalization`` layer with the configured epsilon."""
        eps = self.epsilon
        return RMSNormalization(epsilon=eps)
304
+
305
+
306
class _DyT(BaseModel):
    """Schema for the project's ``DyT`` layer (``layer: "DyT"``)."""

    layer: Literal["DyT"]
    alpha_init_value: PositiveFloat = 0.5

    def to_keras_layer(self) -> keras.layers.Layer:
        """Create the ``DyT`` layer with the configured initial alpha value."""
        initial_alpha = self.alpha_init_value
        return DyT(alpha_init_value=initial_alpha)
312
+
313
+
314
# Tagged union of every supported layer schema. The ``layer`` field of each
# member is a distinct ``Literal`` and serves as the discriminator.
# Spelled with ``typing.Union`` to match the ``Union``/``Optional`` style used
# elsewhere in this module and to avoid evaluating ``|`` on classes at import time.
LayerConfig = Annotated[
    Union[
        _Activation,
        _Conv2D,
        _CoordConv2D,
        _DepthwiseConv2D,
        _SeparableConv2D,
        _MLP,
        _MaxBlurPooling2D,
        _MaxPooling2D,
        _AveragePooling2D,
        _ZeroPadding2D,
        _SqueezeExcite,
        _BatchNormalization,
        _Dropout,
        _SpatialDropout2D,
        _GaussianNoise,
        _LayerNorm,
        _RMSNorm,
        _DyT,
    ],
    Field(discriminator="layer"),
]
335
+
336
+
337
# Tokenizer section of the CCT model configuration.
class _CCTTokenizerConfig(BaseModel):
    # Ordered layer specs; each entry is one of the ``LayerConfig`` schemas.
    blocks: list[LayerConfig]
    # Single int or a pair of ints.
    patch_size: PositiveIntTuple = 1
    # Optional MLP spec; ``None`` means no patch MLP is configured.
    patch_mlp: Optional[_MLP] = None
    # Flag for positional embeddings (enabled by default).
    positional_emb: bool = True
342
+
343
+
344
class _CCTTransformerEncoderConfig(BaseModel):
    """Hyper-parameters of the CCT transformer encoder section of the model config."""

    layers: PositiveInt
    heads: PositiveInt
    projection_dim: PositiveInt
    # Must be non-empty and end with ``projection_dim`` (enforced by the validator below).
    units: list[PositiveInt]
    activation: ActivationStr = "gelu"
    stochastic_depth: UnitFloat = 0.1
    attention_dropout: UnitFloat = 0.1
    mlp_dropout: UnitFloat = 0.1
    head_mlp_dropout: UnitFloat = 0.2
    token_reducer_heads: PositiveInt = 2
    normalization: NormalizationStr = "layer_norm"

    @model_validator(mode="after")
    def _consistency_checks(self):
        """Check that ``units`` is non-empty and that its last entry equals ``projection_dim``.

        Raises:
            ValueError: If ``units`` is empty or ``units[-1] != projection_dim``.
        """
        # Guard first: indexing an empty list would raise IndexError, which is
        # not reported as a validation error.
        if not self.units:
            raise ValueError("'units' must contain at least one entry.")
        if self.units[-1] != self.projection_dim:
            raise ValueError(
                "'units[-1]' must equal 'projection_dim' "
                f"(got {self.units[-1]} vs {self.projection_dim})."
            )
        return self
365
+
366
+
367
# Top-level configuration of the "cct" architecture.
class CCTModelConfig(BaseModel):
    # Architecture tag; the only accepted value is "cct" (also the default).
    model: Literal["cct"] = "cct"
    rescaling: _Rescaling
    tokenizer: _CCTTokenizerConfig
    transformer_encoder: _CCTTransformerEncoderConfig
372
+
373
+
374
# Currently wraps a single architecture (``CCTModelConfig``), tagged by its ``model`` field.
AnyModelConfig = Annotated[CCTModelConfig, Field(discriminator="model")]
"""Supported model-architecture. New model configs should be added here."""
376
+
377
+
378
def load_model_config_from_yaml(yaml_path: PathLike) -> AnyModelConfig:
    """
    Loads, parses, and validates a YAML file defining a model architecture.

    Args:
        yaml_path: Path to the YAML file.

    Returns:
        AnyModelConfig: Parsed and validated model configuration.

    Raises:
        FileNotFoundError: If the YAML file does not exist.
        pydantic.ValidationError: If the YAML content does not match the schema
            (including an empty file or a non-mapping document).
    """
    # Local import so this function does not depend on changes to the module's
    # import block; pydantic is already a dependency of this module.
    from pydantic import TypeAdapter

    config_path = Path(yaml_path)
    if not config_path.is_file():
        raise FileNotFoundError(f"Model config '{yaml_path}' doesn't exist.")
    with config_path.open(encoding="utf-8") as f_in:
        data = yaml.safe_load(f_in)
    # ``AnyModelConfig`` is an ``Annotated`` alias, not a class: validate through
    # a TypeAdapter so the declared schema is applied and malformed input is
    # reported as a validation error instead of failing on ``**data``.
    return TypeAdapter(AnyModelConfig).validate_python(data)
@@ -0,0 +1,38 @@
1
+ """
2
+ Utils for Keras supported backends.
3
+ """
4
+
5
+ import os
6
+ from typing import Literal, TypeAlias
7
+
8
# Names accepted by the backend helpers in this module.
Framework: TypeAlias = Literal["jax", "tensorflow", "torch"]
"""Supported backend frameworks for Keras."""
10
+
11
+
12
def set_jax_backend() -> None:
    """Select ``"jax"`` as the Keras backend via ``set_keras_backend``."""
    set_keras_backend(framework="jax")
15
+
16
+
17
def set_tensorflow_backend() -> None:
    """Select ``"tensorflow"`` as the Keras backend via ``set_keras_backend``."""
    set_keras_backend(framework="tensorflow")
20
+
21
+
22
def set_pytorch_backend() -> None:
    """Select ``"torch"`` (PyTorch) as the Keras backend via ``set_keras_backend``."""
    set_keras_backend(framework="torch")
25
+
26
+
27
def set_keras_backend(framework: Framework) -> None:
    """Store ``framework`` in the ``KERAS_BACKEND`` environment variable.

    NOTE(review): presumably this must run before ``keras`` is first imported
    to take effect; confirm against the Keras documentation.
    """
    os.environ.update({"KERAS_BACKEND": framework})
30
+
31
+
32
def reload_keras_backend(framework: Framework) -> None:
    """Switch the Keras backend to ``framework`` through ``keras.config.set_backend``."""
    # Keras is imported inside the function rather than at module top level.
    # ruff: noqa: PLC0415
    # pylint: disable=import-outside-toplevel
    import keras

    switch_backend = keras.config.set_backend
    switch_backend(framework)