eoml 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. eoml/__init__.py +74 -0
  2. eoml/automation/__init__.py +7 -0
  3. eoml/automation/configuration.py +105 -0
  4. eoml/automation/dag.py +233 -0
  5. eoml/automation/experience.py +618 -0
  6. eoml/automation/tasks.py +825 -0
  7. eoml/bin/__init__.py +6 -0
  8. eoml/bin/clean_checkpoint.py +146 -0
  9. eoml/bin/land_cover_mapping_toml.py +435 -0
  10. eoml/bin/mosaic_images.py +137 -0
  11. eoml/data/__init__.py +7 -0
  12. eoml/data/basic_geo_data.py +214 -0
  13. eoml/data/dataset_utils.py +98 -0
  14. eoml/data/persistence/__init__.py +7 -0
  15. eoml/data/persistence/generic.py +253 -0
  16. eoml/data/persistence/lmdb.py +379 -0
  17. eoml/data/persistence/serializer.py +82 -0
  18. eoml/raster/__init__.py +7 -0
  19. eoml/raster/band.py +141 -0
  20. eoml/raster/dataset/__init__.py +6 -0
  21. eoml/raster/dataset/extractor.py +604 -0
  22. eoml/raster/raster_reader.py +602 -0
  23. eoml/raster/raster_utils.py +116 -0
  24. eoml/torch/__init__.py +7 -0
  25. eoml/torch/cnn/__init__.py +7 -0
  26. eoml/torch/cnn/augmentation.py +150 -0
  27. eoml/torch/cnn/dataset_evaluator.py +68 -0
  28. eoml/torch/cnn/db_dataset.py +605 -0
  29. eoml/torch/cnn/map_dataset.py +579 -0
  30. eoml/torch/cnn/map_dataset_const_mem.py +135 -0
  31. eoml/torch/cnn/outputs_transformer.py +130 -0
  32. eoml/torch/cnn/torch_utils.py +404 -0
  33. eoml/torch/cnn/training_dataset.py +241 -0
  34. eoml/torch/cnn/windows_dataset.py +120 -0
  35. eoml/torch/dataset/__init__.py +6 -0
  36. eoml/torch/dataset/shade_dataset_tester.py +46 -0
  37. eoml/torch/dataset/shade_tree_dataset_creators.py +537 -0
  38. eoml/torch/model_low_use.py +507 -0
  39. eoml/torch/models.py +282 -0
  40. eoml/torch/resnet.py +437 -0
  41. eoml/torch/sample_statistic.py +260 -0
  42. eoml/torch/trainer.py +782 -0
  43. eoml/torch/trainer_v2.py +253 -0
  44. eoml-0.9.0.dist-info/METADATA +93 -0
  45. eoml-0.9.0.dist-info/RECORD +47 -0
  46. eoml-0.9.0.dist-info/WHEEL +4 -0
  47. eoml-0.9.0.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,253 @@
1
+ from abc import ABC
2
+
3
+ from eoml import torch
4
+ from torchmetrics import F1Score
5
+
6
+
7
class Score(ABC):
    """Minimal interface for a metric tracked during validation.

    Every hook is a no-op at this level and returns ``None``; subclasses
    override the ones they need. ``best`` is a read-only property that the
    base class always reports as ``0``.
    """

    def __init__(self):
        """No state is kept by the base class."""

    def __call__(self):
        """Compute the metric; the base implementation returns ``None``."""

    def direction(self):
        """Optimisation direction of the metric; ``None`` at this level."""

    def is_last_best(self):
        """Whether the latest value is the best so far; ``None`` at this level."""

    @property
    def best(self):
        """Best value recorded so far; always ``0`` for the base class."""
        return 0
24
+
25
class F1MultiClass(Score):
    """Multiclass F1 score computed with ``torchmetrics.F1Score``.

    Micro vs. macro averaging: https://stephenallwright.com/micro-vs-macro-f1-score/

    :param num_class: number of classes, forwarded to torchmetrics as ``num_classes``.
    :param average: torchmetrics averaging mode (default ``"macro"``).
    :param device: device the metric module is moved to with ``.to(device)``.
    """

    def __init__(self, num_class, average="macro", device="cpu"):
        super().__init__()
        # NOTE(review): nothing in this class updates _best, and +inf is an odd
        # starting value for a higher-is-better metric such as F1 — confirm intent.
        self._best = float("inf")

        self.score = F1Score(task="multiclass", average=average, num_classes=num_class).to(device)

    def __call__(self, output, target):
        """Evaluate the metric on one batch and return the value torchmetrics produces.

        :param output: model outputs for the batch.
        :param target: ground-truth labels for the batch.
        :return: the metric value for this batch.
        """
        # The value must be returned: callers such as a validation loop sum it.
        return self.score(output, target)

    def direction(self):
        """Not implemented; returns ``None``."""

    def is_last_best(self):
        """Not implemented; returns ``None``."""

    @property
    def best(self):
        """Best value recorded so far, exposed as a property like ``Score.best``."""
        return self._best
44
+
45
class F1_Score(ABC):
    """Placeholder F1 scorer: both hooks are no-ops that return ``None``."""

    def __init__(self):
        """No state is kept."""

    def __call__(self):
        """Not implemented yet; returns ``None``."""
52
+
53
class Trainer:
    """Train a model, checkpoint improving epochs and export a TorchScript copy.

    TODO: do a more aggressive version.

    After every training epoch the model is evaluated on the validation
    loader. A state-dict checkpoint is written whenever the mean validation
    loss reaches a new minimum, or the mean validation score improves in the
    direction given by ``score_direction``. When training ends, the most
    recently written checkpoint is reloaded, traced with ``torch.jit.trace``,
    frozen and saved next to the checkpoints.

    NOTE: PyTorch, tqdm, SummaryWriter, datetime, os and math are imported
    inside the methods. The module-level name ``torch`` is bound by
    ``from eoml import torch`` (the project sub-package, not PyTorch), and the
    other names are not imported at module level.

    :param optimizer: optimizer updating ``model``'s parameters.
    :param model: module to train; ``train(mode)``, ``state_dict`` and
        ``load_state_dict`` are called on it.
    :param loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar
        tensor (``.backward()`` and ``.item()`` are called on it).
    :param grad_f: optional callable ``grad_f(model)`` applied after
        ``backward()`` and before ``optimizer.step()`` (e.g. gradient clipping).
    :param score_function: callable ``score_function(outputs, labels)``
        evaluated on CPU copies of each validation batch; per-batch results
        are averaged. Required.
    :param score_name: label used in TensorBoard tags, progress-bar fields and
        checkpoint file names.
    :param score_direction: ``1`` if a higher score is better, ``-1`` if a
        lower score is better.
    :raises ValueError: if ``score_function`` is ``None``.
    """

    def __init__(self, optimizer, model, loss_fn, grad_f=None, score_function=None,
                 score_name="f1", score_direction=1):
        # There is no usable default scorer (no ``f1`` callable exists in this
        # module), so the caller must supply one.
        if score_function is None:
            raise ValueError("score_function must be provided")

        self.optimizer = optimizer
        self.model = model
        self.loss_fn = loss_fn
        self.grad_f = grad_f

        self.score_direction = score_direction

        # Created per run by train(); _epoch only logs when it is set.
        self.writer = None

        self.score_function = score_function
        self.score_name = score_name

    def _epoch(self, loader, epoch_index, report_frequency, device="cpu"):
        """Run one training epoch over ``loader``.

        :param loader: iterable of ``(inputs, labels)`` batches; must support ``len``.
        :param epoch_index: zero-based epoch number, used for the TensorBoard step.
        :param report_frequency: number of batches between loss reports
            (values below 1 are treated as 1).
        :param device: device to move tensors to. None for do nothing
        :return: mean loss over the last completed report window of the epoch
            (``0.0`` if no window completed).
        """
        from tqdm import tqdm

        report_frequency = max(1, report_frequency)

        # Make sure gradient tracking is on, and do a pass over the data
        self.model.train(True)

        running_loss = 0.
        last_loss = 0.
        with tqdm(total=len(loader), desc="Batch") as pbar:
            for i, (inputs, labels) in enumerate(loader):
                if device is not None:
                    inputs = inputs.to(device, non_blocking=True)
                    labels = labels.to(device, non_blocking=True)

                # Zero the gradients for every batch.
                self.optimizer.zero_grad()

                outputs = self.model(inputs)
                loss = self.loss_fn(outputs, labels)
                loss.backward()

                # Optional gradient post-processing (e.g. clipping).
                if self.grad_f is not None:
                    self.grad_f(self.model)

                self.optimizer.step()

                running_loss += loss.item()
                # Advance one step per batch so the bar also counts the tail
                # batches after the last full report window.
                pbar.update(1)

                if i % report_frequency == report_frequency - 1:
                    # Compute the window mean before displaying it, so the
                    # postfix shows the current window rather than the previous one.
                    last_loss = running_loss / report_frequency
                    pbar.set_postfix({'Batch ': i + 1,
                                      'Last loss': last_loss,
                                      }, refresh=False)
                    if self.writer is not None:
                        tb_x = epoch_index * len(loader) + i + 1
                        self.writer.add_scalar('Loss/train', last_loss, tb_x)
                    running_loss = 0.

        return last_loss

    def _validate(self, validation_loader, device):
        """Evaluate the model on ``validation_loader`` without tracking gradients.

        :param validation_loader: iterable of ``(inputs, labels)`` batches.
        :param device: device to move tensors to; ``None`` leaves them as is.
        :return: ``(mean loss, mean score)`` averaged over batches.
        :raises ValueError: if the loader yields no batches.
        """
        import torch

        self.model.train(False)

        running_vloss = 0.0
        running_score = 0.0
        n_batches = 0

        # No autograd graph is needed for evaluation; this saves memory and time.
        with torch.no_grad():
            for vinputs, vlabels in validation_loader:
                if device is not None:
                    vinputs = vinputs.to(device, non_blocking=True)
                    vlabels = vlabels.to(device, non_blocking=True)

                voutputs = self.model(vinputs)
                running_vloss += self.loss_fn(voutputs, vlabels).item()
                running_score += self.score_function(voutputs.cpu(), vlabels.cpu())
                n_batches += 1

        if n_batches == 0:
            raise ValueError("validation_loader yielded no batches")

        return running_vloss / n_batches, running_score / n_batches

    def train(self, epochs, training_loader, validation_loader, report_per_epoch=10,
              writer_base_path="runs", model_base_path=".", model_tag="model", device="cpu"):
        """Train for ``epochs`` epochs, checkpoint improvements and export TorchScript.

        :param epochs: number of training epochs.
        :param training_loader: iterable of ``(inputs, labels)`` batches with ``len``.
        :param validation_loader: iterable of ``(inputs, labels)`` batches; its
            first batch is also the example input used for tracing.
        :param report_per_epoch: approximate number of loss reports per epoch.
        :param writer_base_path: parent directory of the TensorBoard run.
        :param model_base_path: parent directory of the run's checkpoint folder.
        :param model_tag: prefix for the run, folder and file names.
        :param device: device to move tensors to; ``None`` leaves them as is.
        :return: ``(base_dir, model_path, model_name)``, where ``model_path`` is
            the saved TorchScript file.
        :raises RuntimeError: if no checkpoint was saved (e.g. ``epochs == 0``
            or every validation loss and score was NaN).
        """
        import math
        import os
        from datetime import datetime

        import torch
        from torch.utils.tensorboard import SummaryWriter
        from tqdm import tqdm

        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        model_name = f"{model_tag}_{timestamp}"
        base_dir = f"{model_base_path}/{model_name}"
        # makedirs also creates a missing model_base_path; it still fails if
        # this run's folder already exists.
        os.makedirs(base_dir)

        n_batch = len(training_loader)
        # At least 1 so the modulo test in _epoch is always defined.
        report_frequency = max(1, math.ceil(n_batch / report_per_epoch))

        # +/-inf sentinels make the first finite value always count as an
        # improvement; fixed sentinels (1e6 / 0) miss large losses and
        # negative scores.
        best_score = math.inf if self.score_direction == -1 else -math.inf
        best_score_epoch = 0
        best_vloss = math.inf
        best_vloss_epoch = 0
        best_epoch = 0
        best_metric = None
        model_path = None

        self.writer = SummaryWriter(f"{writer_base_path}/{model_tag}_{timestamp}")
        try:
            with tqdm(total=epochs, desc='Epoch') as pbar:
                for epoch in range(epochs):
                    avg_loss = self._epoch(training_loader, epoch, report_frequency, device)
                    avg_vloss, avg_score = self._validate(validation_loader, device)

                    # Log training vs. validation loss and the validation score.
                    self.writer.add_scalars('Training vs. Validation Loss',
                                            {'Training': avg_loss, 'Validation': avg_vloss},
                                            epoch + 1)
                    self.writer.add_scalars(f'Weighted avg {self.score_name}',
                                            {f'Weighted avg {self.score_name}': avg_score},
                                            epoch + 1)
                    self.writer.flush()

                    # Track the best performance, and save the model's state.
                    if avg_vloss < best_vloss:
                        best_vloss = avg_vloss
                        best_vloss_epoch = epoch + 1
                        best_epoch = best_vloss_epoch
                        best_metric = "loss"
                        model_path = f'{base_dir}/{best_metric}_{model_tag}_{timestamp}_{best_vloss_epoch}'
                        torch.save(self.model.state_dict(), model_path)

                    if self.score_direction * avg_score > self.score_direction * best_score:
                        best_score = avg_score
                        best_score_epoch = epoch + 1
                        best_epoch = best_score_epoch
                        best_metric = self.score_name
                        model_path = f'{base_dir}/{best_metric}_{model_tag}_{timestamp}_{best_score_epoch}'
                        torch.save(self.model.state_dict(), model_path)

                    pbar.set_postfix({f'best {self.score_name} epoch': best_score_epoch,
                                      f'best {self.score_name}': best_score,
                                      f'current {self.score_name}': avg_score,
                                      'best avg loss epoch': best_vloss_epoch,
                                      'best avg loss': best_vloss,
                                      'current avg loss': avg_vloss}, refresh=False)
                    pbar.update(1)
        finally:
            # Release the event-file handles even if training fails.
            self.writer.close()

        if model_path is None:
            raise RuntimeError(
                "No checkpoint was saved during training (no epoch ran or every "
                "validation loss/score was NaN); nothing to export.")

        # Reload the most recently saved checkpoint and switch off training mode.
        self.model.load_state_dict(torch.load(model_path))
        self.model.train(False)

        # Jit the model for inference, using a real validation batch as the
        # example input for tracing.
        vinputs, _ = next(iter(validation_loader))
        if device is not None:
            vinputs = vinputs.to(device)

        with torch.inference_mode():
            model_scripted = torch.jit.trace(self.model, example_inputs=vinputs)
            model_scripted = torch.jit.freeze(model_scripted)

        model_path = f'{base_dir}/jited_{best_metric}_{model_tag}_{timestamp}_{best_epoch}.pt'
        model_scripted.save(model_path)

        return base_dir, model_path, model_name
253
+
@@ -0,0 +1,93 @@
1
+ Metadata-Version: 2.4
2
+ Name: eoml
3
+ Version: 0.9.0
4
+ Summary: Library to manage GIS operations
5
+ Project-URL: Homepage, https://ciatgit.ciat.cgiar.org/Data_driven_sustainability_public/terra-i/eoml#
6
+ Project-URL: Documentation, https://ciatgit.ciat.cgiar.org/Data_driven_sustainability_public/terra-i/eoml#
7
+ Project-URL: Repository, https://ciatgit.ciat.cgiar.org/Data_driven_sustainability_public/terra-i/eoml#
8
+ Project-URL: Bug Tracker, https://ciatgit.ciat.cgiar.org/Data_driven_sustainability_public/terra-i/eoml/-/issues
9
+ Author-email: Thibaud Vantalon <t.vantalon@cgiar.org>
10
+ Maintainer-email: Thibaud Vantalon <t.vantalon@cgiar.org>
11
+ License: MIT License
12
+ Keywords: GIS,Rasterio
13
+ Classifier: Development Status :: 4 - Beta
14
+ Classifier: Programming Language :: Python
15
+ Requires-Python: >=3.12
16
+ Requires-Dist: fiona
17
+ Requires-Dist: geopandas
18
+ Requires-Dist: lmdb
19
+ Requires-Dist: msgpack
20
+ Requires-Dist: numpy
21
+ Requires-Dist: pydantic>=2.6
22
+ Requires-Dist: pyproj
23
+ Requires-Dist: rasterio
24
+ Requires-Dist: rasterstats
25
+ Requires-Dist: scikit-learn
26
+ Requires-Dist: shapely
27
+ Requires-Dist: tensorboard
28
+ Requires-Dist: toml
29
+ Requires-Dist: tomli
30
+ Requires-Dist: torchmetrics
31
+ Requires-Dist: tqdm
32
+ Requires-Dist: typer
33
+ Description-Content-Type: text/markdown
34
+
35
+ # EOML - Earth Observation Machine Learning
36
+
37
+ A Python library for managing GIS operations and machine learning workflows for remote sensing applications.
38
+
39
+ ## Overview
40
+
41
+ EOML provides a comprehensive toolkit for processing Earth observation data and building machine learning models for
42
+ satellite imagery analysis. The library integrates rasterio, PyTorch, and Google Earth Engine to streamline geospatial
43
+ machine learning workflows.
44
+
45
+ ## Features
46
+
47
+ - **PyTorch Integration**: Pre-built CNN architectures and training utilities for remote sensing
48
+
49
+
50
+ ## Installation
51
+
52
+ ### PyPI
53
+ pip install eoml
54
+
55
+ ### Development mode
56
+ Installation in development mode:
57
+ ```bash
58
+ pip install -e .
59
+ ```
60
+
61
+
62
+
63
+ ### Running Tests
64
+
65
+ ```bash
66
+ pytest tests/
67
+ ```
68
+
69
+ ## Contributing
70
+
71
+ Contributions are welcome! Please ensure code follows the project style and includes appropriate docstrings.
72
+
73
+ ## License
74
+ MIT License
75
+
76
+ ## Author
77
+
78
+ **Thibaud Vantalon**
79
+ Email: t.vantalon@cgiar.org
80
+ Organization: CGIAR
81
+
82
+ ## Citation
83
+
84
+ If you use this library in your research, please cite:
85
+
86
+ ```bibtex
87
+ @software{eoml,
88
+ author = {Vantalon, Thibaud},
89
+ title = {EOML: Earth Observation Machine Learning},
90
+ year = {2024},
91
+ url = {https://ciatgit.ciat.cgiar.org/Data_driven_sustainability_public/terra-i/eoml#}
92
+ }
93
+ ```
@@ -0,0 +1,47 @@
1
+ eoml/__init__.py,sha256=ZNIQkoMq2XYKy8R4nzCKMrJ2NJl3lsAbCZzTfh9RSHc,2448
2
+ eoml/automation/__init__.py,sha256=jBxHaR8zlPEow2QHmqIA3AWhZwmUPmEny90z3AXZBKA,241
3
+ eoml/automation/configuration.py,sha256=Ii7CQG5RNQDubH35rtKkvqgQ69_vZk13xRUJm8ttyfE,3474
4
+ eoml/automation/dag.py,sha256=_g-YP6nKW9MeYBtGMEURjKMKoi9l29LFw3X6Uzqc4qE,6469
5
+ eoml/automation/experience.py,sha256=TKi3jsZPKsWRpBPHuaaUIXG-G_-mZQ2mIjTe28ZRlGM,19822
6
+ eoml/automation/tasks.py,sha256=ALYROfvTIUPY7QmOEyfk4PpjluGx7rfUWR0YJ75ddxM,32283
7
+ eoml/bin/__init__.py,sha256=y0ukkpx3o4fGFWE289g14JF_UIA_Z3jMI9uGJBt58aU,159
8
+ eoml/bin/clean_checkpoint.py,sha256=PWgSsaY_9DNxy_knYpTrqVWLvtymh15G35l49k0CT-A,4070
9
+ eoml/bin/land_cover_mapping_toml.py,sha256=lIak0jLjZ8wtWRcdlO5ngU7LpcLkoeDgt4vHhS23Xjg,16471
10
+ eoml/bin/mosaic_images.py,sha256=yAKUwXlJY80heVLQ87if3aq8dh_yPn6Un6j1mwyP2v8,4432
11
+ eoml/data/__init__.py,sha256=2zBgj0doFdn03fsTlrozad2JrKKE_0wMBPORit5shM8,228
12
+ eoml/data/basic_geo_data.py,sha256=vaiQfnzxgR7UhXGnl2SarnDJw-aPXFlNBQ_zCs9QwWk,6617
13
+ eoml/data/dataset_utils.py,sha256=Mn5XojdgVGEcndSgNMU8NEgDMpVUl6BhAlXhugTLPA8,3209
14
+ eoml/data/persistence/__init__.py,sha256=mohqYLULxrXtU-75dR3slnZpXkCyMBXLa1V6HEewnkU,254
15
+ eoml/data/persistence/generic.py,sha256=661dQKcS7bnN_2WgmVdPa_3sgzCkKUkltHMYrN9UFU4,6665
16
+ eoml/data/persistence/lmdb.py,sha256=IjUUHxbwh8Q2yq6SfJH050wUmDnz7DxgCRNdKn1yiBE,14128
17
+ eoml/data/persistence/serializer.py,sha256=pWGuCnGNyqD1We5fEV-tN8u7xjiOWap9z47QS_AJfcs,2997
18
+ eoml/raster/__init__.py,sha256=AB9Y8A7gq-UvsYJ0lHLD18omQZ6wxYwHjxsK6x7sMpQ,248
19
+ eoml/raster/band.py,sha256=EFQCab6MyHtVdnvHBOFH4lXsrm4LRQToknCWmu0Ai6U,3905
20
+ eoml/raster/raster_reader.py,sha256=V3AnyTq_kYRINhQCOB-Zr7tlcMwSB8eS0H9xasg6rj8,20899
21
+ eoml/raster/raster_utils.py,sha256=LgXyJqg0PqVUzPfHfXv9mOMEFaryJv1qi2MkygM9CNI,3931
22
+ eoml/raster/dataset/__init__.py,sha256=Bezfq4avYoBWXjca6bWDDJiemhVY3dYcww4HmHo75rk,186
23
+ eoml/raster/dataset/extractor.py,sha256=B96_WRN5eVkSmeujpIqZrAKcLgH9wbF3nYC0hLUSZtI,22512
24
+ eoml/torch/__init__.py,sha256=82HZiOUYEy6fqbAGYLKrdptutql-Xc5M7ts0Z4HZNKg,237
25
+ eoml/torch/model_low_use.py,sha256=eGKv1iAZp_2509N2tfMF9q_1FRPEsKs-mLoSHnbEiAM,18416
26
+ eoml/torch/models.py,sha256=s9J5fjurunFon6WkeW78X1E8fCqNzPL5unGQJP1EP5A,10568
27
+ eoml/torch/resnet.py,sha256=fs9kZykkshKF4B5qnir96LUezKCSrET48zmO5HAem9Y,16332
28
+ eoml/torch/sample_statistic.py,sha256=pUcYkf6c7BFXhNq4z0UDlNyosvrxvWP8Ep-oUlrENwA,10489
29
+ eoml/torch/trainer.py,sha256=h46fQ9pekHi5CCenhEXGYS_HSjeseL2SNRKL9zaEMG0,28561
30
+ eoml/torch/trainer_v2.py,sha256=jTQ2FpU1jY-x1S0NNZqNPcKjFQGeVuhVM7IVdqwB9-8,8655
31
+ eoml/torch/cnn/__init__.py,sha256=aczbA04QW_pGDxOHoMqLU-vUNiibyBG82HMDjw5CFWQ,246
32
+ eoml/torch/cnn/augmentation.py,sha256=HkRNHRBeysjqulGPJHRnyYyhrPu6dkggsIRcZySkUxA,4779
33
+ eoml/torch/cnn/dataset_evaluator.py,sha256=MucZBc0iCbB2jrRXfbGnGfnqoYvu3JUlxxB8gKvlXJs,2100
34
+ eoml/torch/cnn/db_dataset.py,sha256=EttHGEOh8XSL-MKRkWLArPQUB6ZZJ0tGiNeyr454aJg,21170
35
+ eoml/torch/cnn/map_dataset.py,sha256=49xQjzR4D9qe5TgADGKKwfNbjlYdL5qM6T7zvGZ2TIE,21008
36
+ eoml/torch/cnn/map_dataset_const_mem.py,sha256=fKgA8X8Vd_3aeFLgnZhQdDVy5JKOgaE0X2qSIkt3fn4,4901
37
+ eoml/torch/cnn/outputs_transformer.py,sha256=0852oWf86Dlq_1942g-Hil9sSXj3vkqa_lmhTXmwsVU,3674
38
+ eoml/torch/cnn/torch_utils.py,sha256=ThKoiF812CZgfc3J7XY47-QoD2dmbmihOb5Un0_rPUc,12661
39
+ eoml/torch/cnn/training_dataset.py,sha256=ncLAIEKe1hXAh7TWZhgfVQtoCdG2c2wrXipU91Q9FrM,8607
40
+ eoml/torch/cnn/windows_dataset.py,sha256=--kWNGRBO86oYSY3JoHIEYFC4_vBTpfI7AdWAImqEuA,3758
41
+ eoml/torch/dataset/__init__.py,sha256=oFoDD6oeFyws8wvoOrHEBvIqJQTlTvYU9OPuiDOWB3o,178
42
+ eoml/torch/dataset/shade_dataset_tester.py,sha256=UZaOKl-twF_74r8I9rxFVchq7t2HQzlmm7_eHIRjDYw,1512
43
+ eoml/torch/dataset/shade_tree_dataset_creators.py,sha256=afqhJzwQcEytpdqflThb2tDWbIn6HiigJxIF0EfaqUw,21183
44
+ eoml-0.9.0.dist-info/METADATA,sha256=CGaytOPlGtq-X0Uhfm9O9LZwmNBz61bGqOyBaf2brpo,2472
45
+ eoml-0.9.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
46
+ eoml-0.9.0.dist-info/entry_points.txt,sha256=QmOKUZQNv8HGTC9J1Q-O_Nsdnhy8a0-6mOfHWzoYpEY,115
47
+ eoml-0.9.0.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.28.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
@@ -0,0 +1,3 @@
1
+ [console_scripts]
2
+ eoml_map = eoml.bin.land_cover_mapping_toml:main
3
+ eoml_mosaic_images = eoml.bin.mosaic_images:app