eoml 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eoml/__init__.py +74 -0
- eoml/automation/__init__.py +7 -0
- eoml/automation/configuration.py +105 -0
- eoml/automation/dag.py +233 -0
- eoml/automation/experience.py +618 -0
- eoml/automation/tasks.py +825 -0
- eoml/bin/__init__.py +6 -0
- eoml/bin/clean_checkpoint.py +146 -0
- eoml/bin/land_cover_mapping_toml.py +435 -0
- eoml/bin/mosaic_images.py +137 -0
- eoml/data/__init__.py +7 -0
- eoml/data/basic_geo_data.py +214 -0
- eoml/data/dataset_utils.py +98 -0
- eoml/data/persistence/__init__.py +7 -0
- eoml/data/persistence/generic.py +253 -0
- eoml/data/persistence/lmdb.py +379 -0
- eoml/data/persistence/serializer.py +82 -0
- eoml/raster/__init__.py +7 -0
- eoml/raster/band.py +141 -0
- eoml/raster/dataset/__init__.py +6 -0
- eoml/raster/dataset/extractor.py +604 -0
- eoml/raster/raster_reader.py +602 -0
- eoml/raster/raster_utils.py +116 -0
- eoml/torch/__init__.py +7 -0
- eoml/torch/cnn/__init__.py +7 -0
- eoml/torch/cnn/augmentation.py +150 -0
- eoml/torch/cnn/dataset_evaluator.py +68 -0
- eoml/torch/cnn/db_dataset.py +605 -0
- eoml/torch/cnn/map_dataset.py +579 -0
- eoml/torch/cnn/map_dataset_const_mem.py +135 -0
- eoml/torch/cnn/outputs_transformer.py +130 -0
- eoml/torch/cnn/torch_utils.py +404 -0
- eoml/torch/cnn/training_dataset.py +241 -0
- eoml/torch/cnn/windows_dataset.py +120 -0
- eoml/torch/dataset/__init__.py +6 -0
- eoml/torch/dataset/shade_dataset_tester.py +46 -0
- eoml/torch/dataset/shade_tree_dataset_creators.py +537 -0
- eoml/torch/model_low_use.py +507 -0
- eoml/torch/models.py +282 -0
- eoml/torch/resnet.py +437 -0
- eoml/torch/sample_statistic.py +260 -0
- eoml/torch/trainer.py +782 -0
- eoml/torch/trainer_v2.py +253 -0
- eoml-0.9.0.dist-info/METADATA +93 -0
- eoml-0.9.0.dist-info/RECORD +47 -0
- eoml-0.9.0.dist-info/WHEEL +4 -0
- eoml-0.9.0.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,618 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Experience information module for machine learning experiments.
|
|
3
|
+
|
|
4
|
+
This module defines data structures for storing experiment information including
|
|
5
|
+
raster readers, mappers, transformers, and spatial bounds for ML experiments.
|
|
6
|
+
"""
|
|
7
|
+
import logging
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Any, Optional, Dict, List, Union, Literal
|
|
10
|
+
import tomli
|
|
11
|
+
from pydantic import BaseModel, Field, field_validator, model_validator, FilePath
|
|
12
|
+
|
|
13
|
+
from eoml.automation.configuration import SystemConfigModel
|
|
14
|
+
from eoml.raster.raster_utils import read_gdal_stats, SigmaNormalizer
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
from eoml.raster.raster_reader import RasterReader, MultiRasterReader, append_raster_reader
|
|
19
|
+
from eoml.raster.band import Band
|
|
20
|
+
from eoml.torch.cnn.db_dataset import Mapper
|
|
21
|
+
from eoml.torch.cnn.outputs_transformer import OutputTransformer
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class MapperCategoryConfig(BaseModel):
    """Declarative description of one output category of the label mapper.

    A category has a name, a non-empty list of input labels that are folded
    into it, and an optional explicit output value.
    """

    # Required: name of the output category.
    name: str = Field(..., description="Name of the output category")

    # Required: at least one input label (integer or string) per category.
    labels: List[Union[int, str]] = Field(
        ...,
        min_length=1,
        description="List of input labels that map to this category",
    )

    # Optional explicit output value; left as None when not configured.
    map_value: Optional[int] = Field(
        default=None,
        description="Output value for this category. If None, uses category index",
    )
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class MapperConfig(BaseModel):
    """Settings from which the label ``Mapper`` is constructed.

    Holds the mapper-wide options plus the ordered list of output categories.
    """

    # Value used for invalid/missing labels.
    no_target: int = Field(
        default=-1,
        description="Value to use for invalid/missing labels",
    )

    # Selects one-hot vector outputs instead of scalar outputs.
    vectorize: bool = Field(
        default=False,
        description="Whether to use one-hot vector outputs instead of scalar",
    )

    # Optional label-name -> integer lookup, forwarded to the Mapper as is.
    label_dictionary: Optional[Dict[str, int]] = Field(
        default=None,
        description="Optional mapping from label names to integer values",
    )

    # Required: at least one output category.
    categories: List[MapperCategoryConfig] = Field(
        ...,
        min_length=1,
        description="List of output categories",
    )

    def build_mapper(self) -> Mapper:
        """Create a ``Mapper`` and register every configured category on it.

        Categories are added in the order they appear in ``categories``.

        Returns:
            Mapper: The populated mapper instance.
        """
        built = Mapper(
            no_target=self.no_target,
            vectorize=self.vectorize,
            label_dictionary=self.label_dictionary,
        )

        for entry in self.categories:
            built.add_category(name=entry.name, labels=entry.labels, map_value=entry.map_value)

        return built
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class RasterReaderConfig(BaseModel):
    """Settings for a reader over a single raster file.

    The ``type`` field is fixed to ``"single"`` so that this model can be
    told apart from the multi-reader configuration.
    """

    # Tag identifying this configuration variant.
    type: Literal["single"] = Field(default="single", description="Type of raster reader")

    # Required: the raster file (FilePath requires the file to exist).
    path: FilePath = Field(..., description="Path to the raster file")

    # Optional band selection; when None the bands are taken from the file.
    bands: Optional[List[int]] = Field(
        default=None,
        description="List of band indices to use (1-indexed). If None, uses all bands",
    )

    # Optional statistics file; when given, a normalizer is attached to the reader.
    stats_path: Optional[FilePath] = Field(
        default=None,
        description="Path to statistics file for normalization",
    )

    # Forwarded unchanged to RasterReader.
    interpolation: Optional[str] = Field(
        default=None,
        description="Interpolation method for resampling",
    )

    # Forwarded unchanged to RasterReader.
    read_profile: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Rasterio read profile configuration",
    )

    # Forwarded unchanged to RasterReader.
    sharing: bool = Field(default=False, description="Enable file sharing mode")

    def build_reader(self) -> RasterReader:
        """Instantiate the ``RasterReader`` described by this configuration.

        Returns:
            RasterReader: Reader over ``path`` restricted to the selected
            bands, with a ``SigmaNormalizer`` transformer when ``stats_path``
            is set and no transformer otherwise.
        """
        # Explicit band list if configured, otherwise derive it from the file.
        band = Band.from_file(self.path) if self.bands is None else Band(self.bands)

        transformer = None
        if self.stats_path:
            stats = read_gdal_stats(self.stats_path)
            selected = band.selected
            # Columns 0 and 1 of the stats table are passed as the first two
            # normalizer arguments (presumably mean and standard deviation --
            # TODO confirm against read_gdal_stats / SigmaNormalizer).
            transformer = SigmaNormalizer(stats[selected, 0], stats[selected, 1], 3, True, 0)

        return RasterReader(
            path=self.path,
            bands_list=band,
            transformer=transformer,
            interpolation=self.interpolation,
            read_profile=self.read_profile,
            sharing=self.sharing,
        )
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class MultiRasterReaderConfig(BaseModel):
    """Settings for a reader that combines several single-raster readers.

    The ``type`` field is fixed to ``"multi"`` so that this model can be told
    apart from the single-reader configuration.
    """

    # Tag identifying this configuration variant.
    type: Literal["multi"] = Field(default="multi", description="Type of raster reader")

    # Required: at least one single-raster configuration.
    readers: List[RasterReaderConfig] = Field(
        ...,
        min_length=1,
        description="List of raster reader configurations",
    )

    # Position in `readers` of the reader used as spatial reference.
    reference_index: int = Field(
        default=0,
        ge=0,
        description="Index of the reader to use as spatial reference",
    )

    # Forwarded unchanged to append_raster_reader.
    read_profile: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Rasterio read profile configuration",
    )

    # Forwarded unchanged to append_raster_reader.
    sharing: bool = Field(default=False, description="Enable file sharing mode")

    @field_validator('reference_index')
    @classmethod
    def validate_reference_index(cls, v, info):
        """Reject a ``reference_index`` that points past the end of ``readers``.

        ``readers`` is only present in ``info.data`` when it validated
        successfully; when it is missing or empty the check is skipped.

        Raises:
            ValueError: If ``v`` is not a valid index into ``readers``.
        """
        known_readers = info.data.get('readers', [])
        if not known_readers:
            return v
        if v >= len(known_readers):
            raise ValueError(f"reference_index {v} is out of bounds for {len(known_readers)} readers")
        return v

    def build_reader(self) -> MultiRasterReader:
        """Build every configured reader and combine them into one reader.

        Returns:
            MultiRasterReader: Result of ``append_raster_reader`` over the
            individually built readers.
        """
        built = [reader_cfg.build_reader() for reader_cfg in self.readers]

        return append_raster_reader(
            built,
            reference_index=self.reference_index,
            read_profile=self.read_profile,
            sharing=self.sharing,
        )
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
class BoundariesConfig(BaseModel):
    """Optional spatial extent and mask files used by an experiment."""

    # Optional bounding box; exactly four floats when present.
    map_bounds: Optional[List[float]] = Field(
        default=None,
        min_length=4,
        max_length=4,
        description="Spatial bounds for mapping [minx, miny, maxx, maxy]",
    )

    # Optional mask file (FilePath requires the file to exist).
    map_mask: Optional[FilePath] = Field(
        default=None,
        description="Path to mask defining valid mapping areas",
    )

    # Optional mask file (FilePath requires the file to exist).
    sample_mask: Optional[FilePath] = Field(
        default=None,
        description="Path to mask for filtering training/validation samples",
    )

    @field_validator('map_bounds')
    @classmethod
    def validate_bounds(cls, v):
        """Check that ``map_bounds``, when given, holds exactly four values.

        Only the length is checked here; the ordering of the coordinates is
        not verified.

        Raises:
            ValueError: If a non-None value does not have four entries.
        """
        if v is not None and len(v) != 4:
            raise ValueError("map_bounds must contain exactly 4 values [minx, miny, maxx, maxy]")
        return v
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
class ExperimentConfig(BaseModel):
    """Validated experiment parameters.

    Covers the sample source, window sizes, training settings, device
    selection and random-seed handling.

    Seed handling: ``random_seed`` is the master seed. When it is omitted or
    ``None`` a timestamp-based seed is generated. ``python_seed``,
    ``numpy_seed`` and ``torch_seed`` fall back to ``random_seed``,
    ``(random_seed + 1) % 2**31`` and ``(random_seed + 2) % 2**31``.
    """

    # NOTE(review): the description says "without extension" while the type
    # (FilePath) requires a path to an existing file -- confirm the intent.
    gps_file: FilePath = Field(
        ...,
        description="Name of the geopackage file (without extension)"
    )

    extract_size: int = Field(
        47,
        description="Size of extracted windows from raster data",
        gt=0
    )

    # Must be declared after extract_size: validate_size reads it from info.data.
    size: int = Field(
        31,
        description="Size of input windows for the neural network",
        gt=0
    )

    class_label: str = Field(
        "Class",
        description="Name of the class label column in the geopackage"
    )

    model_name: str = Field(
        "Resnet20",
        description="Name of the neural network model to use"
    )

    batch_mult: float = Field(
        0.25,
        description="Batch size multiplier for training",
        gt=0
    )

    batch_mult_map: float = Field(
        0.5,
        description="Batch size multiplier for mapping",
        gt=0
    )

    epoch: int = Field(
        1,
        description="Number of training epochs",
        gt=0
    )

    # validate_default=True: pydantic skips field validators for default
    # values unless asked, so without it the gps_file-based fallback in
    # set_default_map_tag_name would never be applied to an omitted value.
    map_tag_name: Optional[str] = Field(
        None,
        description="Tag name for the mapping output",
        validate_default=True
    )

    nfold: int = Field(
        5,
        description="Number of folds for k-fold cross-validation",
        gt=0
    )

    device: Union[str, List[int]] = Field(
        "auto",
        description="Device selection: 'auto', 'cpu', 'cuda', 'gpu', or list of CUDA device IDs [0, 1, 2]"
    )

    # validate_default=True: an omitted seed must go through
    # validate_random_seed so that it is replaced by a generated value;
    # otherwise set_default_seeds would compute ``None + 1``.
    random_seed: Optional[int] = Field(
        None,
        description="Master random seed for reproducibility. If None, a random seed will be generated. "
                    "This sets the default for all other seeds if they are not specified.",
        ge=0,
        validate_default=True
    )

    python_seed: Optional[int] = Field(
        None,
        description="Seed for Python's random module. If None, uses random_seed.",
        ge=0
    )

    numpy_seed: Optional[int] = Field(
        None,
        description="Seed for NumPy's random number generator. If None, uses random_seed.",
        ge=0
    )

    torch_seed: Optional[int] = Field(
        None,
        description="Seed for PyTorch's random number generator. If None, uses random_seed.",
        ge=0
    )

    torch_deterministic: bool = Field(
        False,
        description="Enable deterministic behavior in PyTorch (may reduce performance). "
                    "When True, sets torch.use_deterministic_algorithms(True) and "
                    "configures cuDNN for deterministic behavior."
    )

    @field_validator('size')
    @classmethod
    def validate_size(cls, v, info):
        """Ensure size <= extract_size.

        The check is skipped when ``extract_size`` itself failed validation
        (it is then absent from ``info.data``).

        Raises:
            ValueError: If ``v`` exceeds ``extract_size``.
        """
        extract_size = info.data.get('extract_size')
        if extract_size is not None and v > extract_size:
            raise ValueError(f"size ({v}) must be <= extract_size ({extract_size})")
        return v

    @field_validator('map_tag_name')
    @classmethod
    def set_default_map_tag_name(cls, v, info):
        """Derive a default ``map_tag_name`` from ``gps_file`` when none is given.

        An explicit value is returned unchanged. Otherwise the tag is
        ``CH_2022_<stem of gps_file>``; the stem is used so that the tag
        stays a bare name rather than embedding a full path and extension.
        """
        if v is not None:
            return v
        # gps_file is missing from info.data only when it failed validation.
        gps_file = info.data.get('gps_file', 'CH_39_all')
        return f"CH_2022_{Path(gps_file).stem}"

    @field_validator('device')
    @classmethod
    def validate_device(cls, v):
        """Validate and normalize the device setting.

        Strings are lower-cased; ``'automatic'`` becomes ``'auto'`` and
        ``'gpu'`` becomes ``'cuda'``. Lists must be non-empty and contain
        only non-negative integers.

        Raises:
            ValueError: For an unknown string, an invalid or empty list, or
                any other type.
        """
        if isinstance(v, str):
            v_lower = v.lower()
            if v_lower not in ('auto', 'automatic', 'cpu', 'cuda', 'gpu'):
                raise ValueError(
                    f"Invalid device string '{v}'. Must be one of: 'auto', 'automatic', 'cpu', 'cuda', 'gpu'"
                )
            # Collapse aliases onto the canonical names.
            aliases = {'automatic': 'auto', 'gpu': 'cuda'}
            return aliases.get(v_lower, v_lower)

        if isinstance(v, list):
            if not all(isinstance(x, int) and x >= 0 for x in v):
                raise ValueError(
                    f"Device list must contain only non-negative integers, got: {v}"
                )
            if not v:
                raise ValueError("Device list cannot be empty")
            return v

        raise ValueError(
            f"Device must be a string or list of integers, got: {type(v).__name__}"
        )

    @field_validator('random_seed')
    @classmethod
    def validate_random_seed(cls, v):
        """Return ``v``, or a timestamp-based seed in ``[0, 2**31)`` when ``v`` is None."""
        if v is None:
            import time
            return int(time.time() * 1000) % (2**31)  # Use timestamp-based seed
        return v

    @model_validator(mode='after')
    def set_default_seeds(self):
        """Fill in unspecified per-library seeds from the master seed.

        Distinct offsets are used so the generators do not share a seed:

        - python_seed = random_seed
        - numpy_seed  = (random_seed + 1) % 2**31
        - torch_seed  = (random_seed + 2) % 2**31

        Seeds given explicitly are left untouched.
        """
        if self.python_seed is None:
            self.python_seed = self.random_seed
        if self.numpy_seed is None:
            self.numpy_seed = (self.random_seed + 1) % (2**31)
        if self.torch_seed is None:
            self.torch_seed = (self.random_seed + 2) % (2**31)
        return self

    def get_device(self) -> str:
        """Resolve the configured device to a PyTorch device string.

        Returns:
            str: ``'cuda:<id>'`` for a device list (first entry is used),
            ``'cuda'`` or ``'cpu'`` otherwise. Falls back to ``'cpu'`` with
            a warning whenever CUDA was requested but is not available.
        """
        import torch

        cuda_ok = torch.cuda.is_available()

        if isinstance(self.device, list):
            # Use first device in list as primary
            if cuda_ok:
                return f"cuda:{self.device[0]}"
            logger.warning("CUDA not available, falling back to CPU")
            return "cpu"

        device_str = self.device.lower()

        if device_str == 'auto':
            return 'cuda' if cuda_ok else 'cpu'
        if device_str in ('cuda', 'gpu'):
            if cuda_ok:
                return 'cuda'
            logger.warning("CUDA not available, falling back to CPU")
            return 'cpu'
        return 'cpu'

    def get_map_mode(self) -> int:
        """Return the mapping mode for the resolved device.

        Returns:
            int: 1 when ``get_device()`` yields a CUDA device, 0 otherwise.
        """
        return 1 if self.get_device().startswith('cuda') else 0

    def initialize_seeds(self, verbose: bool = True) -> Dict[str, int]:
        """Seed Python, NumPy and PyTorch with the configured seeds.

        When ``torch_deterministic`` is set, deterministic algorithms are
        enabled in PyTorch, cuDNN is set to deterministic/non-benchmark mode
        and ``CUBLAS_WORKSPACE_CONFIG`` is set in the process environment.

        Args:
            verbose: If True, log the seeds that were applied.

        Returns:
            Dict[str, int]: The master and per-library seeds plus a
            ``'deterministic'`` flag.
        """
        import os
        import random
        import numpy as np
        import torch

        seed_info = {
            'master_seed': self.random_seed,
            'python_seed': self.python_seed,
            'numpy_seed': self.numpy_seed,
            'torch_seed': self.torch_seed,
        }

        random.seed(self.python_seed)
        np.random.seed(self.numpy_seed)
        torch.manual_seed(self.torch_seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(self.torch_seed)

        if self.torch_deterministic:
            # Set the cuBLAS workspace configuration first: PyTorch's
            # reproducibility notes require it for deterministic cuBLAS.
            os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
            torch.use_deterministic_algorithms(True)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
            seed_info['deterministic'] = True
        else:
            seed_info['deterministic'] = False

        if verbose:
            logger.info("Random seeds initialized:")
            logger.info(" - Master seed: %s", self.random_seed)

            # A seed equal to the value set_default_seeds would derive is
            # reported as "derived", anything else as "custom".
            report = (
                ("Python", self.python_seed, self.random_seed),
                ("NumPy", self.numpy_seed, (self.random_seed + 1) % (2**31)),
                ("PyTorch", self.torch_seed, (self.random_seed + 2) % (2**31)),
            )
            for label, actual, expected in report:
                origin = "derived" if actual == expected else "custom"
                logger.info(" - %s: %s (%s)", label, actual, origin)

            if self.torch_deterministic:
                logger.info(" - Deterministic mode: ENABLED (may reduce performance)")
            else:
                logger.info(" - Deterministic mode: Disabled")

        return seed_info
|
|
523
|
+
|
|
524
|
+
|
|
525
|
+
class ExperienceInfo(BaseModel):
    """
    Configuration plus runtime components of a machine learning experiment.

    The model is populated from configuration data (typically a TOML file,
    see ``from_toml``). Once validation has succeeded, the reader and mapper
    configurations are built into runtime objects, and the ``raster_reader``
    and ``mapper`` attributes are replaced by those built objects.

    Attributes:
        system_config: System configuration (paths, IO profiles, device/mapping settings).
        experiment: Experiment parameters (training settings, model config, etc.).
        raster_reader: Reader configuration on input; the built reader after validation.
        mapper: Mapper configuration on input; the built mapper after validation.
        boundaries: Spatial boundaries and masks.
        nn_output_transformer: Output transformer obtained from the built mapper (property).
    """

    # ---- configuration fields -------------------------------------------

    system_config: SystemConfigModel = Field(
        default=None,
        description="System configuration (paths, IO profiles, device/mapping settings)",
    )

    # NOTE(review): ExperimentConfig declares a required gps_file, so the
    # default factory cannot succeed without arguments -- in practice this
    # section has to be supplied.
    experiment: ExperimentConfig = Field(
        default_factory=ExperimentConfig,
        description="Experiment parameters",
    )

    # The 'type' tag ("single" / "multi") selects the configuration class.
    raster_reader: Union[RasterReaderConfig, MultiRasterReaderConfig] = Field(
        ...,
        discriminator='type',
        description="Raster reader configuration",
    )

    mapper: MapperConfig = Field(..., description="Label mapper configuration")

    boundaries: BoundariesConfig = Field(
        default_factory=BoundariesConfig,
        description="Spatial boundaries and masks",
    )

    # ---- runtime objects, filled in by build_runtime_objects ------------

    _built_raster_reader: Optional[Any] = None
    _built_mapper: Optional[Any] = None
    _built_nn_output_transformer: Optional[Any] = None

    @model_validator(mode='after')
    def build_runtime_objects(self):
        """Turn the validated reader and mapper configurations into runtime objects.

        The built objects are kept in the private attributes and additionally
        written over the ``raster_reader`` and ``mapper`` fields.
        """
        self._built_raster_reader = self.raster_reader.build_reader()
        self._built_mapper = self.mapper.build_mapper()
        self._built_nn_output_transformer = self._built_mapper.map_output_transformer()

        # object.__setattr__ is used to write directly to the instance,
        # replacing the config objects with the runtime ones for backward
        # compatibility.
        for field_name, runtime_obj in (
            ('raster_reader', self._built_raster_reader),
            ('mapper', self._built_mapper),
        ):
            object.__setattr__(self, field_name, runtime_obj)

        return self

    @property
    def nn_output_transformer(self) -> Any:
        """Output transformer produced by the built mapper."""
        return self._built_nn_output_transformer

    @classmethod
    def from_toml(cls, toml_path: str) -> "ExperienceInfo":
        """Create an ``ExperienceInfo`` from a TOML configuration file.

        The file is parsed with ``tomli`` and its top-level tables are passed
        as keyword arguments to the model, which validates them and builds
        the runtime objects.

        Args:
            toml_path: Path to the TOML configuration file.

        Returns:
            ExperienceInfo: Fully configured and validated experiment information object.

        Raises:
            ValidationError: If the TOML configuration is invalid.
            FileNotFoundError: If the TOML file doesn't exist.
        """
        with open(toml_path, 'rb') as toml_file:
            settings = tomli.load(toml_file)

        return cls(**settings)
|