matrice-analytics 0.1.60__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- matrice_analytics/__init__.py +28 -0
- matrice_analytics/boundary_drawing_internal/README.md +305 -0
- matrice_analytics/boundary_drawing_internal/__init__.py +45 -0
- matrice_analytics/boundary_drawing_internal/boundary_drawing_internal.py +1207 -0
- matrice_analytics/boundary_drawing_internal/boundary_drawing_tool.py +429 -0
- matrice_analytics/boundary_drawing_internal/boundary_tool_template.html +1036 -0
- matrice_analytics/boundary_drawing_internal/data/.gitignore +12 -0
- matrice_analytics/boundary_drawing_internal/example_usage.py +206 -0
- matrice_analytics/boundary_drawing_internal/usage/README.md +110 -0
- matrice_analytics/boundary_drawing_internal/usage/boundary_drawer_launcher.py +102 -0
- matrice_analytics/boundary_drawing_internal/usage/simple_boundary_launcher.py +107 -0
- matrice_analytics/post_processing/README.md +455 -0
- matrice_analytics/post_processing/__init__.py +732 -0
- matrice_analytics/post_processing/advanced_tracker/README.md +650 -0
- matrice_analytics/post_processing/advanced_tracker/__init__.py +17 -0
- matrice_analytics/post_processing/advanced_tracker/base.py +99 -0
- matrice_analytics/post_processing/advanced_tracker/config.py +77 -0
- matrice_analytics/post_processing/advanced_tracker/kalman_filter.py +370 -0
- matrice_analytics/post_processing/advanced_tracker/matching.py +195 -0
- matrice_analytics/post_processing/advanced_tracker/strack.py +230 -0
- matrice_analytics/post_processing/advanced_tracker/tracker.py +367 -0
- matrice_analytics/post_processing/config.py +146 -0
- matrice_analytics/post_processing/core/__init__.py +63 -0
- matrice_analytics/post_processing/core/base.py +704 -0
- matrice_analytics/post_processing/core/config.py +3291 -0
- matrice_analytics/post_processing/core/config_utils.py +925 -0
- matrice_analytics/post_processing/face_reg/__init__.py +43 -0
- matrice_analytics/post_processing/face_reg/compare_similarity.py +556 -0
- matrice_analytics/post_processing/face_reg/embedding_manager.py +950 -0
- matrice_analytics/post_processing/face_reg/face_recognition.py +2234 -0
- matrice_analytics/post_processing/face_reg/face_recognition_client.py +606 -0
- matrice_analytics/post_processing/face_reg/people_activity_logging.py +321 -0
- matrice_analytics/post_processing/ocr/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/easyocr_extractor.py +250 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/__init__.py +9 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/__init__.py +4 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/cli.py +33 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/dataset_stats.py +139 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/export.py +398 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/train.py +447 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/utils.py +129 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/valid.py +93 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/validate_dataset.py +240 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/visualize_augmentation.py +176 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/visualize_predictions.py +96 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/__init__.py +3 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/process.py +246 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/types.py +60 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/utils.py +87 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/__init__.py +3 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/config.py +82 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/hub.py +141 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/plate_recognizer.py +323 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/py.typed +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/augmentation.py +101 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/dataset.py +97 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/config.py +114 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/layers.py +553 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/loss.py +55 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/metric.py +86 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/model_builders.py +95 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/model_schema.py +395 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/backend_utils.py +38 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/utils.py +214 -0
- matrice_analytics/post_processing/ocr/postprocessing.py +270 -0
- matrice_analytics/post_processing/ocr/preprocessing.py +52 -0
- matrice_analytics/post_processing/post_processor.py +1175 -0
- matrice_analytics/post_processing/test_cases/__init__.py +1 -0
- matrice_analytics/post_processing/test_cases/run_tests.py +143 -0
- matrice_analytics/post_processing/test_cases/test_advanced_customer_service.py +841 -0
- matrice_analytics/post_processing/test_cases/test_basic_counting_tracking.py +523 -0
- matrice_analytics/post_processing/test_cases/test_comprehensive.py +531 -0
- matrice_analytics/post_processing/test_cases/test_config.py +852 -0
- matrice_analytics/post_processing/test_cases/test_customer_service.py +585 -0
- matrice_analytics/post_processing/test_cases/test_data_generators.py +583 -0
- matrice_analytics/post_processing/test_cases/test_people_counting.py +510 -0
- matrice_analytics/post_processing/test_cases/test_processor.py +524 -0
- matrice_analytics/post_processing/test_cases/test_usecases.py +165 -0
- matrice_analytics/post_processing/test_cases/test_utilities.py +356 -0
- matrice_analytics/post_processing/test_cases/test_utils.py +743 -0
- matrice_analytics/post_processing/usecases/Histopathological_Cancer_Detection_img.py +604 -0
- matrice_analytics/post_processing/usecases/__init__.py +267 -0
- matrice_analytics/post_processing/usecases/abandoned_object_detection.py +797 -0
- matrice_analytics/post_processing/usecases/advanced_customer_service.py +1601 -0
- matrice_analytics/post_processing/usecases/age_detection.py +842 -0
- matrice_analytics/post_processing/usecases/age_gender_detection.py +1085 -0
- matrice_analytics/post_processing/usecases/anti_spoofing_detection.py +656 -0
- matrice_analytics/post_processing/usecases/assembly_line_detection.py +841 -0
- matrice_analytics/post_processing/usecases/banana_defect_detection.py +624 -0
- matrice_analytics/post_processing/usecases/basic_counting_tracking.py +667 -0
- matrice_analytics/post_processing/usecases/blood_cancer_detection_img.py +881 -0
- matrice_analytics/post_processing/usecases/car_damage_detection.py +834 -0
- matrice_analytics/post_processing/usecases/car_part_segmentation.py +946 -0
- matrice_analytics/post_processing/usecases/car_service.py +1601 -0
- matrice_analytics/post_processing/usecases/cardiomegaly_classification.py +864 -0
- matrice_analytics/post_processing/usecases/cell_microscopy_segmentation.py +897 -0
- matrice_analytics/post_processing/usecases/chicken_pose_detection.py +648 -0
- matrice_analytics/post_processing/usecases/child_monitoring.py +814 -0
- matrice_analytics/post_processing/usecases/color/clip.py +660 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/merges.txt +48895 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/preprocessor_config.json +28 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/special_tokens_map.json +30 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/tokenizer.json +245079 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/tokenizer_config.json +32 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/vocab.json +1 -0
- matrice_analytics/post_processing/usecases/color/color_map_utils.py +70 -0
- matrice_analytics/post_processing/usecases/color/color_mapper.py +468 -0
- matrice_analytics/post_processing/usecases/color_detection.py +1936 -0
- matrice_analytics/post_processing/usecases/color_map_utils.py +70 -0
- matrice_analytics/post_processing/usecases/concrete_crack_detection.py +827 -0
- matrice_analytics/post_processing/usecases/crop_weed_detection.py +781 -0
- matrice_analytics/post_processing/usecases/customer_service.py +1008 -0
- matrice_analytics/post_processing/usecases/defect_detection_products.py +936 -0
- matrice_analytics/post_processing/usecases/distracted_driver_detection.py +822 -0
- matrice_analytics/post_processing/usecases/drone_traffic_monitoring.py +585 -0
- matrice_analytics/post_processing/usecases/drowsy_driver_detection.py +829 -0
- matrice_analytics/post_processing/usecases/dwell_detection.py +829 -0
- matrice_analytics/post_processing/usecases/emergency_vehicle_detection.py +827 -0
- matrice_analytics/post_processing/usecases/face_emotion.py +813 -0
- matrice_analytics/post_processing/usecases/face_recognition.py +827 -0
- matrice_analytics/post_processing/usecases/fashion_detection.py +835 -0
- matrice_analytics/post_processing/usecases/field_mapping.py +902 -0
- matrice_analytics/post_processing/usecases/fire_detection.py +1146 -0
- matrice_analytics/post_processing/usecases/flare_analysis.py +836 -0
- matrice_analytics/post_processing/usecases/flower_segmentation.py +1006 -0
- matrice_analytics/post_processing/usecases/gas_leak_detection.py +837 -0
- matrice_analytics/post_processing/usecases/gender_detection.py +832 -0
- matrice_analytics/post_processing/usecases/human_activity_recognition.py +871 -0
- matrice_analytics/post_processing/usecases/intrusion_detection.py +1672 -0
- matrice_analytics/post_processing/usecases/leaf.py +821 -0
- matrice_analytics/post_processing/usecases/leaf_disease.py +840 -0
- matrice_analytics/post_processing/usecases/leak_detection.py +837 -0
- matrice_analytics/post_processing/usecases/license_plate_detection.py +1188 -0
- matrice_analytics/post_processing/usecases/license_plate_monitoring.py +1781 -0
- matrice_analytics/post_processing/usecases/litter_monitoring.py +717 -0
- matrice_analytics/post_processing/usecases/mask_detection.py +869 -0
- matrice_analytics/post_processing/usecases/natural_disaster.py +907 -0
- matrice_analytics/post_processing/usecases/parking.py +787 -0
- matrice_analytics/post_processing/usecases/parking_space_detection.py +822 -0
- matrice_analytics/post_processing/usecases/pcb_defect_detection.py +888 -0
- matrice_analytics/post_processing/usecases/pedestrian_detection.py +808 -0
- matrice_analytics/post_processing/usecases/people_counting.py +706 -0
- matrice_analytics/post_processing/usecases/people_counting_bckp.py +1683 -0
- matrice_analytics/post_processing/usecases/people_tracking.py +1842 -0
- matrice_analytics/post_processing/usecases/pipeline_detection.py +605 -0
- matrice_analytics/post_processing/usecases/plaque_segmentation_img.py +874 -0
- matrice_analytics/post_processing/usecases/pothole_segmentation.py +915 -0
- matrice_analytics/post_processing/usecases/ppe_compliance.py +645 -0
- matrice_analytics/post_processing/usecases/price_tag_detection.py +822 -0
- matrice_analytics/post_processing/usecases/proximity_detection.py +1901 -0
- matrice_analytics/post_processing/usecases/road_lane_detection.py +623 -0
- matrice_analytics/post_processing/usecases/road_traffic_density.py +832 -0
- matrice_analytics/post_processing/usecases/road_view_segmentation.py +915 -0
- matrice_analytics/post_processing/usecases/shelf_inventory_detection.py +583 -0
- matrice_analytics/post_processing/usecases/shoplifting_detection.py +822 -0
- matrice_analytics/post_processing/usecases/shopping_cart_analysis.py +899 -0
- matrice_analytics/post_processing/usecases/skin_cancer_classification_img.py +864 -0
- matrice_analytics/post_processing/usecases/smoker_detection.py +833 -0
- matrice_analytics/post_processing/usecases/solar_panel.py +810 -0
- matrice_analytics/post_processing/usecases/suspicious_activity_detection.py +1030 -0
- matrice_analytics/post_processing/usecases/template_usecase.py +380 -0
- matrice_analytics/post_processing/usecases/theft_detection.py +648 -0
- matrice_analytics/post_processing/usecases/traffic_sign_monitoring.py +724 -0
- matrice_analytics/post_processing/usecases/underground_pipeline_defect_detection.py +775 -0
- matrice_analytics/post_processing/usecases/underwater_pollution_detection.py +842 -0
- matrice_analytics/post_processing/usecases/vehicle_monitoring.py +1029 -0
- matrice_analytics/post_processing/usecases/warehouse_object_segmentation.py +899 -0
- matrice_analytics/post_processing/usecases/waterbody_segmentation.py +923 -0
- matrice_analytics/post_processing/usecases/weapon_detection.py +771 -0
- matrice_analytics/post_processing/usecases/weld_defect_detection.py +615 -0
- matrice_analytics/post_processing/usecases/wildlife_monitoring.py +898 -0
- matrice_analytics/post_processing/usecases/windmill_maintenance.py +834 -0
- matrice_analytics/post_processing/usecases/wound_segmentation.py +856 -0
- matrice_analytics/post_processing/utils/__init__.py +150 -0
- matrice_analytics/post_processing/utils/advanced_counting_utils.py +400 -0
- matrice_analytics/post_processing/utils/advanced_helper_utils.py +317 -0
- matrice_analytics/post_processing/utils/advanced_tracking_utils.py +461 -0
- matrice_analytics/post_processing/utils/alerting_utils.py +213 -0
- matrice_analytics/post_processing/utils/category_mapping_utils.py +94 -0
- matrice_analytics/post_processing/utils/color_utils.py +592 -0
- matrice_analytics/post_processing/utils/counting_utils.py +182 -0
- matrice_analytics/post_processing/utils/filter_utils.py +261 -0
- matrice_analytics/post_processing/utils/format_utils.py +293 -0
- matrice_analytics/post_processing/utils/geometry_utils.py +300 -0
- matrice_analytics/post_processing/utils/smoothing_utils.py +358 -0
- matrice_analytics/post_processing/utils/tracking_utils.py +234 -0
- matrice_analytics/py.typed +0 -0
- matrice_analytics-0.1.60.dist-info/METADATA +481 -0
- matrice_analytics-0.1.60.dist-info/RECORD +196 -0
- matrice_analytics-0.1.60.dist-info/WHEEL +5 -0
- matrice_analytics-0.1.60.dist-info/licenses/LICENSE.txt +21 -0
- matrice_analytics-0.1.60.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Model builder functions for supported architectures.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from collections.abc import Sequence
|
|
6
|
+
|
|
7
|
+
import keras
|
|
8
|
+
import numpy as np
|
|
9
|
+
from keras import layers
|
|
10
|
+
|
|
11
|
+
from fast_plate_ocr.train.model.config import PlateOCRConfig
|
|
12
|
+
from fast_plate_ocr.train.model.layers import (
|
|
13
|
+
PatchExtractor,
|
|
14
|
+
PositionEmbedding,
|
|
15
|
+
TokenReducer,
|
|
16
|
+
TransformerBlock,
|
|
17
|
+
VocabularyProjection,
|
|
18
|
+
)
|
|
19
|
+
from fast_plate_ocr.train.model.model_schema import AnyModelConfig, CCTModelConfig, LayerConfig
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _build_stem_from_config(specs: Sequence[LayerConfig]) -> keras.Sequential:
    """
    Assemble the convolutional stem from a sequence of layer specs.

    Every spec is turned into a Keras layer via its ``to_keras_layer()`` method, keeping the
    order given in ``specs``, and the layers are wrapped in a ``keras.Sequential`` that is
    named ``"conv_stem"``.
    """
    stem_layers = [layer_spec.to_keras_layer() for layer_spec in specs]
    return keras.Sequential(stem_layers, name="conv_stem")
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _build_cct_model(
    cfg: CCTModelConfig,
    input_shape: tuple[int, int, int],
    max_plate_slots: int,
    vocabulary_size: int,
) -> keras.Model:
    """
    Build the ``"cct"`` OCR architecture as a Keras functional model.

    Pipeline: rescaling -> convolutional stem -> patch extraction -> optional patch MLP ->
    optional positional embedding -> ``transformer_encoder.layers`` transformer blocks ->
    token reduction to ``max_plate_slots`` tokens -> vocabulary projection.

    Args:
        cfg: Validated CCT model configuration.
        input_shape: Input image shape as ``(height, width, channels)``.
        max_plate_slots: Number of tokens kept by the ``TokenReducer``.
        vocabulary_size: Output vocabulary size passed to ``VocabularyProjection``.

    Returns:
        An uncompiled ``keras.Model`` named ``"CCT_OCR"`` whose output is the vocabulary
        logits (presumably shaped ``(batch, max_plate_slots, vocabulary_size)`` — confirm
        against ``VocabularyProjection``).

    NOTE: layers are instantiated in a fixed order. Keras derives default names for unnamed
    layers from creation order, so reordering statements here may rename layers and break
    loading of previously saved weights.
    """
    # 1. Input
    inputs = layers.Input(shape=input_shape)

    # 2. Rescale & conv stem
    data_rescale = cfg.rescaling.to_keras_layer()
    x = _build_stem_from_config(cfg.tokenizer.blocks)(data_rescale(inputs))

    # 3. Patch extraction: (B, H, W, C) -> (B, num_patches, C*patch_size**2)
    x = PatchExtractor(patch_size=cfg.tokenizer.patch_size)(x)

    # 4. Optional patch MLP, applied to the extracted patch tokens
    if cfg.tokenizer.patch_mlp is not None:
        x = cfg.tokenizer.patch_mlp.to_keras_layer()(x)

    # 5. Positional embeddings. The layer's output is added to x explicitly; presumably
    # PositionEmbedding returns only the embedding (confirm in layers.py).
    if cfg.tokenizer.positional_emb:
        seq_len = keras.ops.shape(x)[1]
        x = x + PositionEmbedding(sequence_length=seq_len, name="pos_emb")(x)

    # 6. N x TransformerBlocks. Drop-path (stochastic depth) rates increase linearly from 0.0
    # for the first block to `stochastic_depth` for the last one.
    dpr = list(
        np.linspace(0.0, cfg.transformer_encoder.stochastic_depth, cfg.transformer_encoder.layers)
    )
    for i, rate in enumerate(dpr, 1):
        x = TransformerBlock(
            projection_dim=cfg.transformer_encoder.projection_dim,
            num_heads=cfg.transformer_encoder.heads,
            mlp_units=cfg.transformer_encoder.units,
            attention_dropout=cfg.transformer_encoder.attention_dropout,
            mlp_dropout=cfg.transformer_encoder.mlp_dropout,
            drop_path_rate=rate,
            norm_type=cfg.transformer_encoder.normalization,
            activation=cfg.transformer_encoder.activation,
            name=f"transformer_block_{i}",
        )(x)

    # 7. Reduce to a fixed number of tokens (one per plate slot), then project to vocab
    x = TokenReducer(
        num_tokens=max_plate_slots,
        projection_dim=cfg.transformer_encoder.projection_dim,
        num_heads=cfg.transformer_encoder.token_reducer_heads,
    )(x)

    logits = VocabularyProjection(
        vocabulary_size=vocabulary_size,
        dropout_rate=cfg.transformer_encoder.head_mlp_dropout,
        name="vocab_projection",
    )(x)

    return keras.Model(inputs, logits, name="CCT_OCR")
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def build_model(model_cfg: AnyModelConfig, plate_cfg: PlateOCRConfig) -> keras.Model:
    """
    Create the Keras OCR model described by ``model_cfg`` for the plates described by ``plate_cfg``.

    Args:
        model_cfg: Validated architecture config; its ``model`` field selects the builder.
        plate_cfg: Plate config supplying image height/width/channels, the number of plate
            slots and the vocabulary size.

    Returns:
        The constructed ``keras.Model``.

    Raises:
        ValueError: If ``model_cfg.model`` names an unsupported architecture.
    """
    if model_cfg.model != "cct":
        raise ValueError(f"Unsupported model type: {model_cfg.model!r}")

    # The input tensor is laid out as (height, width, channels).
    image_shape = (plate_cfg.img_height, plate_cfg.img_width, plate_cfg.num_channels)
    return _build_cct_model(
        cfg=model_cfg,
        input_shape=image_shape,
        max_plate_slots=plate_cfg.max_plate_slots,
        vocabulary_size=plate_cfg.vocabulary_size,
    )
|
|
@@ -0,0 +1,395 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Schema definitions for validating supported model architectures and layer blocks.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Annotated, Literal, Optional, TypeAlias, Union
|
|
7
|
+
|
|
8
|
+
import keras
|
|
9
|
+
import yaml
|
|
10
|
+
from keras.src.layers import RMSNormalization
|
|
11
|
+
from pydantic import BaseModel, Field, PositiveFloat, PositiveInt, model_validator
|
|
12
|
+
|
|
13
|
+
from fast_plate_ocr.core.types import PathLike
|
|
14
|
+
from fast_plate_ocr.train.model.layers import (
|
|
15
|
+
MLP,
|
|
16
|
+
CoordConv2D,
|
|
17
|
+
DyT,
|
|
18
|
+
MaxBlurPooling2D,
|
|
19
|
+
SqueezeExcite,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
UnitFloat: TypeAlias = Annotated[float, Field(ge=0.0, le=1.0)]
"""A float that must be in range of [0, 1]."""
PaddingTypeStr: TypeAlias = Literal["valid", "same"]
"""Padding modes supported by Keras convolution and pooling layers."""
# A plain (non-discriminated) union. Pydantic only accepts ``discriminator`` on unions of
# models/dataclasses; attaching one to ``int``/``tuple`` variants makes schema generation fail
# for every model that uses this alias. Ordinary union validation picks the right variant.
PositiveIntTuple: TypeAlias = Union[PositiveInt, tuple[PositiveInt, PositiveInt]]
"""A single positive integer or a tuple of two positive integers, usually used for sizes/strides."""
NormalizationStr: TypeAlias = Literal["layer_norm", "rms_norm", "dyt"]
"""Available normalization layers."""

ActivationStr: TypeAlias = Literal[
    "celu",
    "elu",
    "exponential",
    "gelu",
    "glu",
    "hard_shrink",
    "hard_sigmoid",
    "hard_silu",
    "hard_tanh",
    "leaky_relu",
    "linear",
    "log_sigmoid",
    "log_softmax",
    "mish",
    "relu",
    "relu6",
    "selu",
    "sigmoid",
    "silu",
    "soft_shrink",
    "softmax",
    "softplus",
    "softsign",
    "sparse_plus",
    "sparsemax",
    "squareplus",
    "tanh",
    "tanh_shrink",
    "threshold",
]
"""Supported Keras activation functions."""

WeightInitializationStr: TypeAlias = Literal[
    "glorot_normal",
    "glorot_uniform",
    "he_normal",
    "he_uniform",
    "lecun_normal",
    "lecun_uniform",
    "ones",
    "random_normal",
    "random_uniform",
    "truncated_normal",
    "variance_scaling",
    "zeros",
]
"""Keras weight initialization strategies."""
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class _Rescaling(BaseModel):
    """Spec for the input ``keras.layers.Rescaling`` step (``inputs * scale + offset``)."""

    # Multiplicative factor; the default maps 0..255 pixel values into [0, 1].
    scale: float = 1.0 / 255
    # Additive shift applied after scaling.
    offset: float = 0.0

    def to_keras_layer(self):
        """Return a ``keras.layers.Rescaling`` configured with this spec's scale and offset."""
        return keras.layers.Rescaling(scale=self.scale, offset=self.offset)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class _Activation(BaseModel):
    """Spec for a standalone ``keras.layers.Activation`` layer."""

    layer: Literal["Activation"]
    # Name of the Keras activation function to apply.
    activation: ActivationStr

    def to_keras_layer(self) -> keras.layers.Layer:
        """Return the ``keras.layers.Activation`` described by this spec."""
        return keras.layers.Activation(activation=self.activation)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class _Conv2DBase(BaseModel):
    """
    Fields shared by the ``Conv2D`` and ``CoordConv2D`` layer specs.

    Subclasses forward these fields as keyword arguments via ``model_dump``, so the field
    names must match the constructor arguments of the target Keras layer.
    """

    # Number of output channels.
    filters: PositiveInt
    # Convolution window: a single int or an (h, w) pair.
    kernel_size: PositiveIntTuple
    # Convolution step: a single int or an (h, w) pair.
    strides: PositiveIntTuple = 1
    padding: PaddingTypeStr = "same"
    activation: ActivationStr = "relu"
    use_bias: bool = True
    kernel_initializer: WeightInitializationStr = "he_normal"
    bias_initializer: WeightInitializationStr = "zeros"
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class _Conv2D(_Conv2DBase):
    """Spec for a plain ``keras.layers.Conv2D`` layer."""

    layer: Literal["Conv2D"]

    def to_keras_layer(self) -> keras.layers.Layer:
        """Return a ``keras.layers.Conv2D`` built from every field except the ``layer`` tag."""
        conv_kwargs = self.model_dump()
        # The discriminator tag is schema-only and not a Conv2D argument.
        conv_kwargs.pop("layer")
        return keras.layers.Conv2D(**conv_kwargs)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class _CoordConv2D(_Conv2DBase):
    """Spec for the project-defined ``CoordConv2D`` layer."""

    layer: Literal["CoordConv2D"]
    # Forwarded to CoordConv2D as ``with_r``; presumably adds a radial coordinate channel (confirm in layers.py).
    with_r: bool = False

    def to_keras_layer(self) -> keras.layers.Layer:
        """Return a ``CoordConv2D`` built from the shared conv fields plus ``with_r``."""
        conv_kwargs = self.model_dump(exclude={"layer"})
        use_radius = conv_kwargs.pop("with_r")
        return CoordConv2D(with_r=use_radius, **conv_kwargs)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class _DepthwiseConv2D(BaseModel):
    """Spec for ``keras.layers.DepthwiseConv2D``; field names match its keyword arguments."""

    layer: Literal["DepthwiseConv2D"]
    kernel_size: PositiveIntTuple
    strides: PositiveIntTuple = 1
    padding: PaddingTypeStr = "same"
    depth_multiplier: PositiveInt = 1
    activation: ActivationStr = "relu"
    use_bias: bool = True
    depthwise_initializer: WeightInitializationStr = "he_normal"
    bias_initializer: WeightInitializationStr = "zeros"

    def to_keras_layer(self) -> keras.layers.Layer:
        """Return the ``keras.layers.DepthwiseConv2D`` described by all non-tag fields."""
        layer_kwargs = self.model_dump(exclude={"layer"})
        return keras.layers.DepthwiseConv2D(**layer_kwargs)
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class _SeparableConv2D(BaseModel):
    """Spec for ``keras.layers.SeparableConv2D``; field names match its keyword arguments."""

    layer: Literal["SeparableConv2D"]
    filters: PositiveInt
    kernel_size: PositiveIntTuple
    strides: PositiveIntTuple = 1
    padding: PaddingTypeStr = "same"
    depth_multiplier: PositiveInt = 1
    activation: ActivationStr = "relu"
    use_bias: bool = True
    depthwise_initializer: WeightInitializationStr = "he_normal"
    pointwise_initializer: WeightInitializationStr = "glorot_uniform"
    bias_initializer: WeightInitializationStr = "zeros"

    def to_keras_layer(self) -> keras.layers.Layer:
        """Return the ``keras.layers.SeparableConv2D`` described by all non-tag fields."""
        layer_kwargs = self.model_dump(exclude={"layer"})
        return keras.layers.SeparableConv2D(**layer_kwargs)
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
class _MLP(BaseModel):
    """
    Spec for the project-defined ``MLP`` layer.

    ``hidden_units`` presumably lists the width of each dense layer (confirm in layers.py).
    """

    layer: Literal["MLP"]
    hidden_units: list[PositiveInt]
    dropout_rate: UnitFloat = 0.1
    activation: ActivationStr = "gelu"
    use_bias: bool = True

    def to_keras_layer(self) -> keras.layers.Layer:
        """Return an ``MLP`` layer configured from this spec."""
        mlp_kwargs = {
            "hidden_units": self.hidden_units,
            "dropout_rate": self.dropout_rate,
            "activation": self.activation,
            "use_bias": self.use_bias,
        }
        return MLP(**mlp_kwargs)
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
class _MaxBlurPooling2D(BaseModel):
    """Spec for the project-defined ``MaxBlurPooling2D`` layer."""

    layer: Literal["MaxBlurPooling2D"]
    pool_size: PositiveInt = 2
    filter_size: PositiveInt = 3
    padding: PaddingTypeStr = "same"

    def to_keras_layer(self) -> keras.layers.Layer:
        """Return a ``MaxBlurPooling2D`` built from ``pool_size``, ``filter_size`` and ``padding``."""
        pooling_kwargs = self.model_dump(exclude={"layer"})
        return MaxBlurPooling2D(**pooling_kwargs)
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
class _MaxPooling2D(BaseModel):
    """Spec for ``keras.layers.MaxPooling2D``."""

    layer: Literal["MaxPooling2D"]
    # Pooling window: a single int or an (h, w) pair.
    pool_size: PositiveIntTuple = 2
    # Step between windows. Keras accepts an int, an (h, w) pair, or None (which falls back
    # to pool_size). Typed like pool_size so anisotropic strides can be configured;
    # plain ints remain valid.
    strides: Optional[PositiveIntTuple] = None
    padding: PaddingTypeStr = "valid"

    def to_keras_layer(self) -> keras.layers.Layer:
        """Return the ``keras.layers.MaxPooling2D`` described by this spec."""
        return keras.layers.MaxPooling2D(
            pool_size=self.pool_size,
            strides=self.strides,
            padding=self.padding,
        )
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
class _AveragePooling2D(BaseModel):
    """Spec for ``keras.layers.AveragePooling2D``."""

    layer: Literal["AveragePooling2D"]
    # Pooling window: a single int or an (h, w) pair.
    pool_size: PositiveIntTuple = 2
    # Step between windows. Keras accepts an int, an (h, w) pair, or None (which falls back
    # to pool_size). Typed like pool_size so anisotropic strides can be configured;
    # plain ints remain valid.
    strides: Optional[PositiveIntTuple] = None
    padding: PaddingTypeStr = "valid"

    def to_keras_layer(self) -> keras.layers.Layer:
        """Return the ``keras.layers.AveragePooling2D`` described by this spec."""
        return keras.layers.AveragePooling2D(
            pool_size=self.pool_size,
            strides=self.strides,
            padding=self.padding,
        )
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
class _ZeroPadding2D(BaseModel):
    """Spec for ``keras.layers.ZeroPadding2D``."""

    layer: Literal["ZeroPadding2D"]
    # Padding amount: a single int or an (h, w) pair.
    padding: PositiveIntTuple = 1

    def to_keras_layer(self) -> keras.layers.Layer:
        """Return a ``keras.layers.ZeroPadding2D`` with this spec's padding."""
        pad_kwargs = self.model_dump(exclude={"layer"})
        return keras.layers.ZeroPadding2D(**pad_kwargs)
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
class _SqueezeExcite(BaseModel):
    """Spec for the project-defined ``SqueezeExcite`` layer."""

    layer: Literal["SqueezeExcite"]
    # Forwarded unchanged as SqueezeExcite(ratio=...); its exact meaning is defined in layers.py.
    ratio: PositiveFloat = 1.0

    def to_keras_layer(self) -> keras.layers.Layer:
        """Return a ``SqueezeExcite`` layer configured with ``ratio``."""
        se_kwargs = self.model_dump(exclude={"layer"})
        return SqueezeExcite(**se_kwargs)
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
class _BatchNormalization(BaseModel):
    """Spec for ``keras.layers.BatchNormalization``; field names match its keyword arguments."""

    layer: Literal["BatchNormalization"]
    momentum: PositiveFloat = 0.99
    epsilon: PositiveFloat = 1e-3
    center: bool = True
    scale: bool = True

    def to_keras_layer(self) -> keras.layers.Layer:
        """Return the ``keras.layers.BatchNormalization`` described by all non-tag fields."""
        norm_kwargs = self.model_dump(exclude={"layer"})
        return keras.layers.BatchNormalization(**norm_kwargs)
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
class _Dropout(BaseModel):
    """Spec for ``keras.layers.Dropout``."""

    layer: Literal["Dropout"]
    # Fraction of inputs to drop. Keras only accepts rates in [0, 1]; the upper bound is
    # enforced here so a bad config fails at validation time rather than when the model
    # is built. The lower bound stays strictly positive, as before.
    rate: Annotated[float, Field(gt=0.0, le=1.0)]

    def to_keras_layer(self) -> keras.layers.Layer:
        """Return a ``keras.layers.Dropout`` with this spec's rate."""
        return keras.layers.Dropout(rate=self.rate)
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
class _SpatialDropout2D(BaseModel):
    """Spec for ``keras.layers.SpatialDropout2D``."""

    layer: Literal["SpatialDropout2D"]
    # Fraction of feature maps to drop. Keras only accepts rates in [0, 1]; the upper bound
    # is enforced here so a bad config fails at validation time rather than when the model
    # is built. The lower bound stays strictly positive, as before.
    rate: Annotated[float, Field(gt=0.0, le=1.0)]

    def to_keras_layer(self) -> keras.layers.Layer:
        """Return a ``keras.layers.SpatialDropout2D`` with this spec's rate."""
        return keras.layers.SpatialDropout2D(rate=self.rate)
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
class _GaussianNoise(BaseModel):
    """Spec for ``keras.layers.GaussianNoise``."""

    layer: Literal["GaussianNoise"]
    # Standard deviation of the additive noise.
    stddev: PositiveFloat
    # Optional random seed; None leaves seeding to Keras.
    seed: Optional[int] = None

    def to_keras_layer(self) -> keras.layers.Layer:
        """Return a ``keras.layers.GaussianNoise`` built from ``stddev`` and ``seed``."""
        noise_kwargs = self.model_dump(exclude={"layer"})
        return keras.layers.GaussianNoise(**noise_kwargs)
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
class _LayerNorm(BaseModel):
    """Spec for ``keras.layers.LayerNormalization``."""

    layer: Literal["LayerNorm"]
    epsilon: PositiveFloat = 1e-3

    def to_keras_layer(self) -> keras.layers.Layer:
        """Return a ``keras.layers.LayerNormalization`` with this spec's epsilon."""
        norm_kwargs = self.model_dump(exclude={"layer"})
        return keras.layers.LayerNormalization(**norm_kwargs)
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
class _RMSNorm(BaseModel):
    """Spec for Keras' ``RMSNormalization`` layer."""

    layer: Literal["RMSNorm"]
    epsilon: PositiveFloat = 1e-6

    def to_keras_layer(self) -> keras.layers.Layer:
        """Return an ``RMSNormalization`` layer with this spec's epsilon."""
        norm_kwargs = self.model_dump(exclude={"layer"})
        return RMSNormalization(**norm_kwargs)
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
class _DyT(BaseModel):
    """Spec for the project-defined ``DyT`` normalization layer."""

    layer: Literal["DyT"]
    # Initial value forwarded as DyT(alpha_init_value=...).
    alpha_init_value: PositiveFloat = 0.5

    def to_keras_layer(self) -> keras.layers.Layer:
        """Return a ``DyT`` layer configured with ``alpha_init_value``."""
        dyt_kwargs = self.model_dump(exclude={"layer"})
        return DyT(**dyt_kwargs)
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
# Union of every layer spec that may appear in a tokenizer's ``blocks`` list. Pydantic selects
# the concrete spec from the value of its ``layer`` field. ``Union[...]`` is used to match the
# typing style of the other aliases in this module.
LayerConfig = Annotated[
    Union[
        _Activation,
        _Conv2D,
        _CoordConv2D,
        _DepthwiseConv2D,
        _SeparableConv2D,
        _MLP,
        _MaxBlurPooling2D,
        _MaxPooling2D,
        _AveragePooling2D,
        _ZeroPadding2D,
        _SqueezeExcite,
        _BatchNormalization,
        _Dropout,
        _SpatialDropout2D,
        _GaussianNoise,
        _LayerNorm,
        _RMSNorm,
        _DyT,
    ],
    Field(discriminator="layer"),
]
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
class _CCTTokenizerConfig(BaseModel):
    """Configuration of the convolutional tokenizer that turns an input image into tokens."""

    # Layer specs of the convolutional stem, applied in the listed order.
    blocks: list[LayerConfig]
    # Patch size passed to PatchExtractor after the stem: a single int or an (h, w) pair.
    patch_size: PositiveIntTuple = 1
    # Optional MLP applied to the extracted patch tokens.
    patch_mlp: Optional[_MLP] = None
    # Whether a PositionEmbedding is added to the patch tokens.
    positional_emb: bool = True
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
class _CCTTransformerEncoderConfig(BaseModel):
    """Hyper-parameters of the CCT transformer encoder, token reducer and output head."""

    # Number of stacked TransformerBlocks.
    layers: PositiveInt
    # Attention heads per TransformerBlock.
    heads: PositiveInt
    # Projection dimension passed to every TransformerBlock and to the TokenReducer.
    projection_dim: PositiveInt
    # MLP units of each TransformerBlock; the last entry must equal ``projection_dim``.
    units: list[PositiveInt]
    activation: ActivationStr = "gelu"
    # Maximum drop-path rate; per-block rates ramp linearly from 0 up to this value.
    stochastic_depth: UnitFloat = 0.1
    attention_dropout: UnitFloat = 0.1
    mlp_dropout: UnitFloat = 0.1
    # Dropout rate of the vocabulary projection head.
    head_mlp_dropout: UnitFloat = 0.2
    # Attention heads used by the TokenReducer.
    token_reducer_heads: PositiveInt = 2
    normalization: NormalizationStr = "layer_norm"

    @model_validator(mode="after")
    def _consistency_checks(self):
        """
        Validate constraints that span several fields.

        Raises:
            ValueError: If ``units`` is empty or its last entry differs from
                ``projection_dim`` (pydantic reports it as a ``ValidationError``).
        """
        # Check emptiness first: indexing an empty list would raise a bare IndexError,
        # which pydantic does not convert into a ValidationError.
        if not self.units:
            raise ValueError("'units' must contain at least one value.")
        if self.units[-1] != self.projection_dim:
            raise ValueError(
                "'units[-1]' must equal 'projection_dim' "
                f"(got {self.units[-1]} vs {self.projection_dim})."
            )
        return self
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
class CCTModelConfig(BaseModel):
    """Top-level configuration of the ``"cct"`` OCR model architecture."""

    # Discriminator value identifying this architecture.
    model: Literal["cct"] = "cct"
    # Input rescaling applied before the convolutional stem.
    rescaling: _Rescaling
    # Convolutional stem, patch extraction and positional-embedding settings.
    tokenizer: _CCTTokenizerConfig
    # Transformer encoder, token reducer and output-head settings.
    transformer_encoder: _CCTTransformerEncoderConfig
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
# Discriminated on each config's ``model`` field. Only CCTModelConfig exists today; further
# architecture configs would be combined with it in a ``Union[...]`` inside the Annotated.
AnyModelConfig: TypeAlias = Annotated[CCTModelConfig, Field(discriminator="model")]
"""Supported model-architecture. New model configs should be added here."""
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
def load_model_config_from_yaml(yaml_path: PathLike) -> AnyModelConfig:
    """
    Loads, parses, and validates a YAML file defining a model architecture.

    Args:
        yaml_path: Path to the YAML file.

    Returns:
        AnyModelConfig: Parsed and validated model configuration.

    Raises:
        FileNotFoundError: If the YAML file does not exist.
        ValueError: If the YAML document is empty or its top level is not a mapping. Schema
            violations also surface as a ``ValueError`` subclass (pydantic ``ValidationError``).
    """
    config_path = Path(yaml_path)
    if not config_path.is_file():
        raise FileNotFoundError(f"Model config '{yaml_path}' doesn't exist.")
    with open(config_path, encoding="utf-8") as f_in:
        data = yaml.safe_load(f_in)
    # safe_load returns None for an empty document and may return a list or scalar. Unpacking
    # either with ** would fail with an opaque TypeError, so reject it with a clear message.
    if not isinstance(data, dict):
        raise ValueError(
            f"Model config '{yaml_path}' must contain a YAML mapping at the top level, "
            f"got {type(data).__name__}."
        )
    return AnyModelConfig(**data)
|
|
File without changes
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utils for Keras supported backends.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
from typing import Literal, TypeAlias
|
|
7
|
+
|
|
8
|
+
# Backend names accepted by set_keras_backend / reload_keras_backend. The value is written
# verbatim to the KERAS_BACKEND environment variable or passed to keras.config.set_backend.
Framework: TypeAlias = Literal["jax", "tensorflow", "torch"]
"""Supported backend frameworks for Keras."""
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def set_jax_backend() -> None:
    """Select JAX as the Keras backend (shortcut for ``set_keras_backend("jax")``)."""
    set_keras_backend(framework="jax")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def set_tensorflow_backend() -> None:
    """Select TensorFlow as the Keras backend (shortcut for ``set_keras_backend("tensorflow")``)."""
    set_keras_backend(framework="tensorflow")
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def set_pytorch_backend() -> None:
    """Select PyTorch as the Keras backend (shortcut for ``set_keras_backend("torch")``)."""
    set_keras_backend(framework="torch")
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def set_keras_backend(framework: Framework) -> None:
    """
    Select the Keras backend by exporting ``KERAS_BACKEND`` into this process's environment.

    NOTE(review): Keras reads this variable when it is first imported, so this presumably
    only takes effect before ``import keras``. For an already-imported Keras, use
    ``reload_keras_backend``.
    """
    os.environ.update({"KERAS_BACKEND": framework})
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def reload_keras_backend(framework: Framework) -> None:
    """
    Switch an already-imported Keras to the given backend framework.

    Delegates to ``keras.config.set_backend``. NOTE(review): per the Keras docs this reloads
    Keras' modules, so references to Keras objects created before the call may be stale
    afterwards; confirm against the installed Keras version.
    """
    # Imported inside the function, presumably so that importing this module does not import
    # Keras. The lint suppressions are scoped to this line only: a standalone
    # "ruff: noqa: <code>" comment would exempt the entire file from that rule.
    import keras  # pylint: disable=import-outside-toplevel  # noqa: PLC0415

    keras.config.set_backend(framework)
|