eoml 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eoml/__init__.py +74 -0
- eoml/automation/__init__.py +7 -0
- eoml/automation/configuration.py +105 -0
- eoml/automation/dag.py +233 -0
- eoml/automation/experience.py +618 -0
- eoml/automation/tasks.py +825 -0
- eoml/bin/__init__.py +6 -0
- eoml/bin/clean_checkpoint.py +146 -0
- eoml/bin/land_cover_mapping_toml.py +435 -0
- eoml/bin/mosaic_images.py +137 -0
- eoml/data/__init__.py +7 -0
- eoml/data/basic_geo_data.py +214 -0
- eoml/data/dataset_utils.py +98 -0
- eoml/data/persistence/__init__.py +7 -0
- eoml/data/persistence/generic.py +253 -0
- eoml/data/persistence/lmdb.py +379 -0
- eoml/data/persistence/serializer.py +82 -0
- eoml/raster/__init__.py +7 -0
- eoml/raster/band.py +141 -0
- eoml/raster/dataset/__init__.py +6 -0
- eoml/raster/dataset/extractor.py +604 -0
- eoml/raster/raster_reader.py +602 -0
- eoml/raster/raster_utils.py +116 -0
- eoml/torch/__init__.py +7 -0
- eoml/torch/cnn/__init__.py +7 -0
- eoml/torch/cnn/augmentation.py +150 -0
- eoml/torch/cnn/dataset_evaluator.py +68 -0
- eoml/torch/cnn/db_dataset.py +605 -0
- eoml/torch/cnn/map_dataset.py +579 -0
- eoml/torch/cnn/map_dataset_const_mem.py +135 -0
- eoml/torch/cnn/outputs_transformer.py +130 -0
- eoml/torch/cnn/torch_utils.py +404 -0
- eoml/torch/cnn/training_dataset.py +241 -0
- eoml/torch/cnn/windows_dataset.py +120 -0
- eoml/torch/dataset/__init__.py +6 -0
- eoml/torch/dataset/shade_dataset_tester.py +46 -0
- eoml/torch/dataset/shade_tree_dataset_creators.py +537 -0
- eoml/torch/model_low_use.py +507 -0
- eoml/torch/models.py +282 -0
- eoml/torch/resnet.py +437 -0
- eoml/torch/sample_statistic.py +260 -0
- eoml/torch/trainer.py +782 -0
- eoml/torch/trainer_v2.py +253 -0
- eoml-0.9.0.dist-info/METADATA +93 -0
- eoml-0.9.0.dist-info/RECORD +47 -0
- eoml-0.9.0.dist-info/WHEEL +4 -0
- eoml-0.9.0.dist-info/entry_points.txt +3 -0
eoml/bin/__init__.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Checkpoint cleanup utility for EOML.
|
|
3
|
+
|
|
4
|
+
This command-line utility helps manage disk space by cleaning up old checkpoint
|
|
5
|
+
files from model training. It keeps only the N most recent files in each
|
|
6
|
+
subdirectory, removing older checkpoints to save space.
|
|
7
|
+
|
|
8
|
+
Usage:
|
|
9
|
+
python clean_checkpoint.py <root_directory>
|
|
10
|
+
|
|
11
|
+
The script will:
|
|
12
|
+
1. Scan each immediate subdirectory of the root directory for files
|
|
13
|
+
2. Identify files sorted by modification time
|
|
14
|
+
3. Keep the 4 most recent files in each directory
|
|
15
|
+
4. Remove older files after user confirmation
|
|
16
|
+
5. Report disk space saved
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
import os
|
|
20
|
+
import typer
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
from datetime import datetime
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def get_dir_size(path):
    """
    Calculate total size of directory in bytes.

    Recursively sums the sizes of all regular files in a directory and its
    subdirectories. Symbolic links are not followed: a symlink to a file does
    not contribute its target's size, and a symlink to a directory is not
    descended into. This prevents infinite recursion on symlink cycles and
    avoids counting data that lives outside ``path`` (removing a symlink frees
    none of its target's space).

    Entries that disappear between listing and ``stat`` (e.g. a checkpoint
    rotated away by a running training job) are skipped.

    Args:
        path (str or Path): Path to the directory to measure.

    Returns:
        int: Total size in bytes.

    Examples:
        >>> size = get_dir_size("/path/to/checkpoints")
        >>> print(f"Directory size: {size / (1024**3):.2f} GB")
    """
    total = 0
    with os.scandir(path) as it:
        for entry in it:
            try:
                if entry.is_file(follow_symlinks=False):
                    total += entry.stat(follow_symlinks=False).st_size
                elif entry.is_dir(follow_symlinks=False):
                    total += get_dir_size(entry.path)
            except FileNotFoundError:
                # Removed concurrently after scandir listed it; nothing left to count.
                continue
    return total
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def get_files_by_time(directory):
    """
    Return the files of a directory, newest modification time first.

    Only entries for which ``DirEntry.is_file()`` is true are kept;
    subdirectories are ignored and the scan is not recursive. Files sharing
    the same mtime keep their ``os.scandir`` order (the sort is stable).

    Args:
        directory (str or Path): Path to the directory to scan.

    Returns:
        list: ``os.DirEntry`` objects ordered by ``st_mtime`` descending.

    Examples:
        >>> files = get_files_by_time("/path/to/checkpoints")
        >>> most_recent = files[0]  # Most recently modified file
    """
    with os.scandir(directory) as entries:
        regular_files = [entry for entry in entries if entry.is_file()]

    def _mtime(dir_entry):
        return dir_entry.stat().st_mtime

    regular_files.sort(key=_mtime, reverse=True)
    return regular_files
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
app = typer.Typer()


@app.command()
def cleanup_folders(
    root_dir: Path = typer.Argument(
        ...,
        exists=True,
        dir_okay=True,
        file_okay=False,
        help="Root directory to clean up"
    ),
    keep: int = 4,
):
    """
    Clean up folders by keeping only the most recent files.

    This command scans each immediate subdirectory of root_dir and keeps only
    the ``keep`` most recently modified files in it (4 by default), removing
    older files to save disk space. Files located directly in root_dir, and
    files in deeper nested subdirectories, are never selected for removal.

    Args:
        root_dir (Path): Root directory containing subdirectories with checkpoint files.
        keep (int): Number of most recent files to keep in each subdirectory.
            Exposed on the command line as ``--keep``; must be >= 0.

    The function will:
        - Show a list of files to be removed
        - Ask for user confirmation before deleting
        - Report the amount of disk space saved
        - Display the number of files removed, and report any file that
          could not be removed (the remaining files are still processed)

    Raises:
        typer.BadParameter: If ``keep`` is negative.

    Examples:
        To clean up a checkpoints directory:
            $ python clean_checkpoint.py /path/to/checkpoints
        To keep only the 2 most recent files per subdirectory:
            $ python clean_checkpoint.py /path/to/checkpoints --keep 2
    """
    if keep < 0:
        # A negative slice start would select the wrong files instead of failing loudly.
        raise typer.BadParameter("--keep must be >= 0")

    root_path = Path(root_dir)

    # Collect everything older than the `keep` newest files of each subdirectory.
    files_to_remove = []
    for folder in root_path.glob('*/'):
        # On Python < 3.11 a trailing '/' in the pattern does not restrict matches to directories.
        if folder.is_dir():
            files = get_files_by_time(folder)
            files_to_remove.extend(f.path for f in files[keep:])

    if not files_to_remove:
        typer.echo("No files need to be removed.")
        return

    # Show files to be removed
    typer.echo(f"The following {len(files_to_remove)} files will be removed:")
    for file in files_to_remove:
        typer.echo(f" {file}")

    # Ask for confirmation
    if not typer.confirm('Do you want to proceed?'):
        typer.echo("Operation cancelled.")
        return

    # Measure only once we know we are deleting, right before the deletions.
    initial_size = get_dir_size(root_path)

    removed = 0
    failed = []
    for file in files_to_remove:
        try:
            os.remove(file)
        except OSError as err:
            # One locked/vanished file must not abort the rest of the cleanup.
            failed.append(file)
            typer.echo(f"Could not remove {file}: {err}", err=True)
        else:
            removed += 1

    final_size = get_dir_size(root_path)
    saved = initial_size - final_size

    typer.echo(f"\nSpace saved: {saved / (1024 * 1024):.2f} MB")
    typer.echo(f"Number of files removed: {removed}")
    if failed:
        typer.echo(f"Number of files that could not be removed: {len(failed)}", err=True)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
if __name__ == "__main__":
    # Executed as a script: hand control to the Typer CLI application.
    app()
|
|
@@ -0,0 +1,435 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Land cover mapping script using TOML configuration.
|
|
3
|
+
|
|
4
|
+
This script demonstrates how to run a complete land cover mapping workflow
|
|
5
|
+
using configuration loaded from a TOML file, leveraging the ExperienceInfo
|
|
6
|
+
configuration system.
|
|
7
|
+
|
|
8
|
+
Usage:
|
|
9
|
+
python land_cover_mapping_toml.py <path_to_config.toml>
|
|
10
|
+
|
|
11
|
+
Example:
|
|
12
|
+
python land_cover_mapping_toml.py ../example_experience_config.toml
|
|
13
|
+
"""
|
|
14
|
+
import logging
|
|
15
|
+
import os
|
|
16
|
+
import sys
|
|
17
|
+
import random
|
|
18
|
+
from datetime import datetime
|
|
19
|
+
from pathlib import Path
|
|
20
|
+
|
|
21
|
+
import torch
|
|
22
|
+
from rasterio.enums import Resampling
|
|
23
|
+
from torch import nn
|
|
24
|
+
from torch.optim import AdamW
|
|
25
|
+
|
|
26
|
+
from rasterop.tiled_op.operation.mapping import CountCategoryToBand, MaxCategory, MaxScore
|
|
27
|
+
|
|
28
|
+
# Root logging setup: INFO level, timestamped "time - logger - level - message" records.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
|
34
|
+
|
|
35
|
+
from eoml.automation.experience import ExperienceInfo
|
|
36
|
+
from eoml.automation.tasks import (
|
|
37
|
+
samples_split_setup,
|
|
38
|
+
samples_k_fold_setup,
|
|
39
|
+
extract_sample,
|
|
40
|
+
train_and_map, tiled_task,
|
|
41
|
+
)
|
|
42
|
+
from eoml.torch.cnn.augmentation import RandomTransform, CropTransform
|
|
43
|
+
|
|
44
|
+
# NOTE: CountCategoryToBand, MaxCategory and MaxScore are imported from
# rasterop.tiled_op.operation.mapping at the top of this file.
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def run_land_cover_mapping(config_path: str):
    """
    Run the complete land cover mapping workflow from a TOML configuration.

    Steps:
        1. Load the experience configuration (``ExperienceInfo.from_toml``),
           initialize the random seeds and resolve the compute device.
        2. Extract samples from the raster data into the sample database.
        3. Train the model (k-fold split by default) and generate maps with
           ``train_and_map``.
        4. Post-process the generated maps with tiled operations: a per-pixel
           category count across all maps, then the maximum category and the
           maximum score derived from that merged raster.

    Outputs go to a timestamped run directory under
    ``<data_dir>/land_cover/nn/``. The full parameter dictionary (``log.txt``)
    and a copy of the TOML file (``config.toml``) are saved there for
    reproducibility.

    Args:
        config_path: Path to the TOML configuration file.
    """
    import shutil

    # ----------------------------------------------------------------------------------------------------------------
    # Setup
    # ----------------------------------------------------------------------------------------------------------------

    # For GPU support in multiple threads (also needed for mapping).
    # set_start_method() raises RuntimeError once a start method is fixed for the process (e.g. when this
    # function runs twice), so only set it when it is not already 'spawn'.
    if torch.multiprocessing.get_start_method(allow_none=True) != 'spawn':
        torch.multiprocessing.set_start_method('spawn', force=True)

    logger.info(f"Loading configuration from: {config_path}")

    # ----------------------------------------------------------------------------------------------------------------
    # Load Configuration from TOML
    # ----------------------------------------------------------------------------------------------------------------

    # Load and validate the experience configuration
    experience = ExperienceInfo.from_toml(config_path)

    logger.info("Configuration loaded successfully!")
    logger.info(f" GPS file: {experience.experiment.gps_file}")
    logger.info(f" Model: {experience.experiment.model_name}")
    logger.info(f" Extract size: {experience.experiment.extract_size}")
    logger.info(f" Network size: {experience.experiment.size}")
    logger.info(f" Epochs: {experience.experiment.epoch}")
    logger.info(f" Batch multiplier: {experience.experiment.batch_mult}")
    logger.info(f" N-fold: {experience.experiment.nfold}")

    # Extract runtime objects from experience
    raster_reader = experience.raster_reader
    mapper_full = experience.mapper
    nn_output_transformer = experience.nn_output_transformer
    system_config = experience.system_config

    # Extract configuration values for convenient access
    map_bounds = experience.boundaries.map_bounds
    map_mask = experience.boundaries.map_mask
    sample_mask = experience.boundaries.sample_mask
    gps_file = experience.experiment.gps_file
    extract_size = experience.experiment.extract_size
    size = experience.experiment.size
    class_label = experience.experiment.class_label
    model_name = experience.experiment.model_name
    batch_mult = experience.experiment.batch_mult
    batch_mult_map = experience.experiment.batch_mult_map
    epoch = experience.experiment.epoch
    map_tag_name = experience.experiment.map_tag_name
    nfold = experience.experiment.nfold

    # ----------------------------------------------------------------------------------------------------------------
    # Random Seed Configuration
    # ----------------------------------------------------------------------------------------------------------------

    # Initialize all random seeds (Python, NumPy, PyTorch) and set deterministic mode if configured.
    # Called for its side effects; the returned info is reported by verbose=True.
    experience.experiment.initialize_seeds(verbose=True)

    # ----------------------------------------------------------------------------------------------------------------
    # Device Configuration
    # ----------------------------------------------------------------------------------------------------------------

    device = experience.experiment.get_device()
    map_mode = experience.experiment.get_map_mode()

    logger.info(f" Device: {device} (mode: {map_mode})")

    # Log additional device info for multi-GPU setup
    if isinstance(experience.experiment.device, list):
        logger.info(f" Available GPUs: {experience.experiment.device}")
        if torch.cuda.is_available():
            for gpu_id in experience.experiment.device:
                if gpu_id < torch.cuda.device_count():
                    logger.info(f" - GPU {gpu_id}: {torch.cuda.get_device_name(gpu_id)}")
                else:
                    logger.warning(f" - GPU {gpu_id}: Not available")

    # ----------------------------------------------------------------------------------------------------------------
    # File Path Management
    # ----------------------------------------------------------------------------------------------------------------

    gps_path = gps_file
    db_path = f"{system_config.data_dir}/land_cover/samples/{gps_file.stem}_lmdb_NaN_to_0_{extract_size}"

    # Training output paths; the timestamp makes every run write to a fresh directory.
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    run_name = f"ch-{timestamp}"
    run_stats_dir = f"{system_config.data_dir}/land_cover/nn_run_stats"
    model_base_path = f"{system_config.data_dir}/land_cover/nn/{run_name}"

    logger.info(f"Output directory: {model_base_path}")

    # ----------------------------------------------------------------------------------------------------------------
    # Sample Extraction Configuration
    # ----------------------------------------------------------------------------------------------------------------

    extractor_param = {
        "gps_path": gps_path,
        "raster_reader": raster_reader,
        "db_path": db_path,
        "windows_size": extract_size,
        "label_name": class_label,
        "mask_path": sample_mask,
        "force_write": False
    }

    # ----------------------------------------------------------------------------------------------------------------
    # Sample Split Configuration (K-Fold or Simple Split)
    # ----------------------------------------------------------------------------------------------------------------

    # K-fold cross-validation (recommended)
    sample_param_kfold = {
        "methode": samples_k_fold_setup,
        "param": {
            "db_path": db_path,
            "mapper": mapper_full,
            "n_fold": nfold
        }
    }

    # Simple train/validation split (alternative; swap it into `sample_param` below to use it)
    sample_param_split = {
        "methode": samples_split_setup,
        "param": {
            "db_path": db_path,
            "mapper": mapper_full,
            "split": [0.8, 0.2]
        }
    }

    # Use K-fold by default
    sample_param = sample_param_kfold

    # ----------------------------------------------------------------------------------------------------------------
    # Data Augmentation Configuration
    # ----------------------------------------------------------------------------------------------------------------

    augmentation_param = {
        "methode": "no_dep",
        "transform_train": RandomTransform(
            width=size,
            p_rot=0.90,
            p_flip=0.50,
            p_scale=0.4,
            p_shear=0.3,
            p_blur=0.3
        ),
        "transform_valid": CropTransform(size)
    }

    # ----------------------------------------------------------------------------------------------------------------
    # DataLoader Configuration
    # ----------------------------------------------------------------------------------------------------------------

    dataloader_parameter = {
        "batch_size": int(batch_mult * 1024),
        "num_worker": 5,
        "prefetch": 1,
        "device": device,
        "balance_sample": False,
        "persistent_workers": True
    }

    # ----------------------------------------------------------------------------------------------------------------
    # Neural Network Configuration
    # ----------------------------------------------------------------------------------------------------------------

    nn_parameter = {
        "in_size": size,
        "n_bands": raster_reader.n_band,
        "n_out": len(mapper_full)
    }

    model_parameter = {
        "model_name": model_name,
        "type": "normal",
        "path": None,
        "device": device,
        "nn_parameter": nn_parameter
    }

    # ----------------------------------------------------------------------------------------------------------------
    # Optimizer Configuration
    # ----------------------------------------------------------------------------------------------------------------

    optimizer_parameter = {
        "loss": nn.CrossEntropyLoss(),
        "optimizer": AdamW,
        "optimizer_parameter": {
            "lr": 1.5 * 0.018 * 1e-2,
            "weight_decay": 0.001 * 0.0020
        },
        "scheduler_mode": "cycle",
        "scheduler_parameter": {
            "max_lr": 0.0008
        }
    }

    # ----------------------------------------------------------------------------------------------------------------
    # Training Configuration
    # ----------------------------------------------------------------------------------------------------------------

    dataset_parameter = {
        "db_path": db_path,
        "mapper": mapper_full
    }

    train_nn_parameter = {
        "max_epochs": epoch,
        "run_stats_dir": run_stats_dir,
        "model_base_path": model_base_path,
        "model_tag": model_name,
        "grad_clip_value": 0.1,
        "device": device
    }

    train_parameter = {
        "sample_param": sample_param,
        "augmentation_param": augmentation_param,
        "dataset_parameter": dataset_parameter,
        "dataloader_parameter": dataloader_parameter,
        "model_parameter": model_parameter,
        "optimizer_parameter": optimizer_parameter,
        "train_nn_parameter": train_nn_parameter
    }

    # ----------------------------------------------------------------------------------------------------------------
    # Mapping Configuration
    # ----------------------------------------------------------------------------------------------------------------

    map_parameter = {
        "raster_reader": raster_reader,
        "windows_size": size,
        "batch_size": int(batch_mult_map * 1024),
        "map_tag": map_tag_name,
        "transformer": nn_output_transformer,
        "mask": map_mask,
        "bounds": map_bounds,
        "mode": map_mode,
        "num_worker": 7,
        "prefetch": 1
    }

    # Map modes:
    # 0 - Full CPU, no pinning
    # 1 - Pinned memory in loader, moved asynchronously to GPU (recommended for GPU)
    # 2 - Start CUDA in each thread, prepare samples directly on GPU
    #     (uses ~1GB per thread, requires torch.multiprocessing.set_start_method('spawn'))

    train_map_parameter = {**train_parameter, "map_parameter": map_parameter}

    # ----------------------------------------------------------------------------------------------------------------
    # Execute Workflow
    # ----------------------------------------------------------------------------------------------------------------

    logger.info("=" * 80)
    logger.info("STARTING LAND COVER MAPPING WORKFLOW")
    logger.info("=" * 80)

    # Create output directory
    os.makedirs(model_base_path, exist_ok=True)

    # Save configuration log (repr may contain non-ASCII text, so pin the encoding)
    with open(f"{model_base_path}/log.txt", "w", encoding="utf-8") as log:
        log.write(repr(train_map_parameter))

    # Copy TOML configuration to output directory for reference
    shutil.copy(config_path, f"{model_base_path}/config.toml")
    logger.info(f"Configuration saved to: {model_base_path}/config.toml")

    # Step 1: Extract samples from raster data
    logger.info("[1/4] Extracting samples from raster data...")
    extract_sample(**extractor_param)
    logger.info("✓ Sample extraction complete")

    # Step 2: Train model and generate maps
    logger.info("[2/4] Training model and generating maps...")
    maps = train_and_map(**train_map_parameter)
    logger.info(f"✓ Training complete, generated {len(maps)} maps")

    # ----------------------------------------------------------------------------------------------------------------
    # Post-Processing: Merge and Aggregate Maps
    # ----------------------------------------------------------------------------------------------------------------

    logger.info("[3/4] Post-processing maps...")

    raster_out_merge = f"{model_base_path}/01_{run_name}_merged.tif"
    raster_out_score = f"{model_base_path}/02_{run_name}_max_arg.tif"
    raster_out_score_max = f"{model_base_path}/02_{run_name}_max_score.tif"

    # Shared tiled_task options for every post-processing step
    default_op_param = {
        "bounds": map_bounds,
        "res": None,
        "resampling": Resampling.nearest,
        "target_aligned_pixels": False,
        "indexes": None,
        "src_kwds": None,
        "dst_kwds": None,
        "num_workers": 8
    }

    # Count categories across all maps
    category_count_op = CountCategoryToBand(max(mapper_full.map_values()), dtype="int16")
    category_count_param = {
        "maps": maps,
        "raster_out": raster_out_merge,
        "operation": category_count_op
    }
    category_count_param.update(default_op_param)

    # Find maximum category (mode) from the merged counts
    category_max_op = MaxCategory()
    category_max_param = {
        "maps": [raster_out_merge],
        "raster_out": raster_out_score,
        "operation": category_max_op
    }
    category_max_param.update(default_op_param)

    # Find maximum score (confidence) from the merged counts
    category_max_score_op = MaxScore()
    category_max_score_param = {
        "maps": [raster_out_merge],
        "raster_out": raster_out_score_max,
        "operation": category_max_score_op
    }
    category_max_score_param.update(default_op_param)

    # Execute tiled operations; the merge must finish first because the other two read its output.
    logger.info(f" - Merging {len(maps)} maps...")
    tiled_task(**category_count_param)
    logger.info(" - Computing maximum category...")
    tiled_task(**category_max_param)
    logger.info(" - Computing maximum score...")
    tiled_task(**category_max_score_param)

    logger.info("✓ Post-processing complete")

    # ----------------------------------------------------------------------------------------------------------------
    # Done
    # ----------------------------------------------------------------------------------------------------------------

    logger.info("[4/4] Workflow complete!")
    logger.info("=" * 80)
    logger.info("RESULTS")
    logger.info("=" * 80)
    logger.info(f"Output directory: {model_base_path}")
    logger.info(f"Maps generated: {len(maps)}")
    for i, map_path in enumerate(maps, 1):
        logger.info(f" [{i}] {map_path}")
    logger.info(f"Merged output: {raster_out_merge}")
    logger.info(f"Category map: {raster_out_score}")
    logger.info(f"Confidence map: {raster_out_score_max}")
    logger.info("=" * 80)
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
def main():
    """
    Command-line entry point.

    Reads the TOML configuration path from the first command-line argument
    and runs the land cover mapping workflow on it. Exits with status 1 when
    the argument is missing, when the path does not exist, or when the
    workflow raises an exception (which is logged with its traceback).
    """
    cli_args = sys.argv[1:]
    if not cli_args:
        logger.error("Usage: python land_cover_mapping_toml.py <path_to_config.toml>")
        logger.info("Example:")
        logger.info(" python land_cover_mapping_toml.py ../example_experience_config.toml")
        sys.exit(1)

    toml_path = cli_args[0]

    if os.path.exists(toml_path):
        try:
            run_land_cover_mapping(toml_path)
        except Exception as exc:
            logger.exception(f"Error during execution: {exc}")
            sys.exit(1)
    else:
        logger.error(f"Configuration file not found: {toml_path}")
        sys.exit(1)
|
|
432
|
+
|
|
433
|
+
|
|
434
|
+
if __name__ == "__main__":
    # Only run the workflow when executed as a script, not on import.
    main()
|