eotdl 2024.10.1__py3-none-any.whl → 2024.10.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
eotdl/__init__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "2024.10.01"
1
+ __version__ = "2024.10.07"
@@ -0,0 +1 @@
1
+ from .models import ModelWrapper
@@ -0,0 +1,158 @@
1
+ # Q1+ model wrapper
2
+ # only works with some models, extend as we include more models in EOTDL and improve MLM extension
3
+
4
+ import os
5
+ from pathlib import Path
6
+ from tqdm import tqdm
7
+ import numpy as np
8
+
9
+ from ..models.retrieve import retrieve_model
10
+ from ..curation.stac import STACDataFrame
11
+ from ..repos import FilesAPIRepo, ModelsAPIRepo
12
+ from ..auth import with_auth
13
+
14
class ModelWrapper:
    """Download a Q1+ EOTDL model and run ONNX inference with it.

    Only models whose STAC metadata contains exactly one ``Feature`` item with
    ``mlm:framework == "ONNX"`` are supported (checked in :meth:`setup`).

    Args:
        model_name: name of the model in EOTDL.
        version: model version id; ``None`` selects the highest ``version_id``.
        path: base download directory; ``None`` uses ``$EOTDL_DOWNLOAD_PATH``
            or ``~/.cache/eotdl/models``.
        force: re-download even if the version directory already exists.
        assets: whether to download the model assets (including the ONNX file).
        verbose: print progress messages.
    """

    def __init__(self, model_name, version=None, path=None, force=False, assets=True, verbose=True):
        self.model_name = model_name
        self.version = version
        self.path = path
        self.force = force
        self.assets = assets
        self.verbose = verbose
        self.ready = False
        # ONNX Runtime session, created lazily by predict() and reused across calls.
        self._ort_session = None
        self.setup()

    def setup(self):
        """Download (or load cached) metadata and resolve the ONNX model path.

        Raises:
            AssertionError: the metadata does not contain exactly one Feature
                item, or the model framework is not ONNX.
        """
        download_path, gdf = self.download()
        self.download_path = download_path
        self.gdf = gdf
        # get model name from stac metadata
        item = gdf[gdf["type"] == "Feature"]
        # Explicit raises instead of `assert` so validation survives `python -O`;
        # AssertionError is kept so existing callers catching it still work.
        if item.shape[0] != 1:
            raise AssertionError(
                "Only one item is supported in stac metadata, found " + str(item.shape[0])
            )
        self.props = item.iloc[0].properties
        framework = self.props["mlm:framework"]
        if framework != "ONNX":
            raise AssertionError("Only ONNX models are supported, found " + str(framework))
        self.model_path = download_path + "/assets/" + self.props["mlm:name"]
        # The model file may have changed (e.g. re-download); drop any cached session.
        self._ort_session = None
        self.ready = True

    def predict(self, x):
        """Run the model on ``x`` and return task-formatted outputs.

        The ONNX Runtime session is created on the first call and reused.
        """
        if not self.ready:
            self.setup()
        if getattr(self, "_ort_session", None) is None:
            self._ort_session = self.get_onnx_session(self.model_path)
        ort_session = self._ort_session
        # preprocess input
        x = self.process_inputs(x)
        # execute model (feeds only the first declared model input)
        input_name = ort_session.get_inputs()[0].name
        ort_outs = ort_session.run(None, {input_name: x})
        output_names = [node.name for node in ort_session.get_outputs()]
        # format and return outputs
        return self.return_outputs(ort_outs, output_names)

    @with_auth
    def download(self, user=None):
        """Download the model's STAC metadata (and optionally assets).

        Args:
            user: injected by ``with_auth`` and forwarded to the API repos.

        Returns:
            ``(download_path, gdf)``. ``gdf`` is read from the local catalog
            when a cached download is reused.

        Raises:
            Exception: the model is Q0 or the STAC download reports an error.
            AssertionError: the requested version does not exist.
        """
        model = retrieve_model(self.model_name)
        if model["quality"] == 0:
            raise Exception("Only Q1+ models are supported")
        version_ids = [v["version_id"] for v in model["versions"]]
        if self.version is None:
            self.version = max(version_ids)
        elif self.version not in version_ids:
            raise AssertionError(f"Version {self.version} not found")
        if self.path is None:
            base_path = os.getenv(
                "EOTDL_DOWNLOAD_PATH", str(Path.home()) + "/.cache/eotdl/models"
            )
        else:
            base_path = self.path
        download_path = base_path + "/" + self.model_name + "/v" + str(self.version)
        # Reuse an existing download unless force=True.
        if os.path.exists(download_path) and not self.force:
            gdf = STACDataFrame.from_stac_file(download_path + f"/{self.model_name}/catalog.json")
            return download_path, gdf
        if self.verbose:
            print("Downloading STAC metadata...")
        gdf, error = ModelsAPIRepo().download_stac(model["id"], user)
        if error:
            raise Exception(error)
        df = STACDataFrame(gdf)
        df.to_stac(download_path)
        if self.assets:
            if self.verbose:
                print("Downloading assets...")
            files_repo = FilesAPIRepo()
            df = df.dropna(subset=["assets"])
            for _, row in tqdm(df.iterrows(), total=len(df)):
                for asset in row["assets"].values():
                    href = asset["href"]
                    _, filename = href.split("/download/")
                    # NOTE: assets sharing a filename overwrite each other.
                    files_repo.download_file_url(
                        href, filename, f"{download_path}/assets", user
                    )
        elif self.verbose:
            print("To download assets, set assets=True.")
        if self.verbose:
            print("Done")
        return download_path, gdf

    def process_inputs(self, x):
        """Cast ``x`` to the declared dtype and validate its shape.

        If ``x`` has exactly one dimension fewer than the declared input
        shape, a leading batch dimension is added. The declared dtype is
        preserved (it was previously forced to float32).

        Raises:
            Exception: the number of dimensions cannot be reconciled.
            AssertionError: a fixed dimension does not match.
        """
        spec = self.props["mlm:input"]["input"]
        x = x.astype(spec["data_type"])
        input_shape = spec["shape"]
        ndims = len(input_shape)
        if x.ndim != ndims:
            if x.ndim == ndims - 1:
                x = np.expand_dims(x, axis=0)
            else:
                raise Exception("Input shape not valid", input_shape, x.ndim)
        for i, dim in enumerate(input_shape):
            # -1 in the spec means "any size"
            if dim != -1 and dim != x.shape[i]:
                raise AssertionError(
                    f"Input dimension not valid: The model expects {input_shape} but input has {x.shape} (-1 means any dimension)."
                )
        # TODO: should apply normalization if defined in metadata
        return x

    def return_outputs(self, ort_outputs, output_names):
        """Format raw ONNX outputs according to ``mlm:output.tasks``.

        Classification returns a dict ``{"model": name, <output>: list, ...}``.
        Segmentation returns the first element of the first output.
        """
        tasks = self.props["mlm:output"]["tasks"]
        if tasks == ["classification"]:
            return {
                "model": self.model_name,
                **{name: out.tolist() for name, out in zip(output_names, ort_outputs)},
            }
        if tasks == ["segmentation"]:
            return ort_outputs[0][0]
        raise Exception("Output task not supported:", tasks)

    def get_onnx_session(self, model):
        """Create an ONNX Runtime session for the model file at ``model``.

        CUDA is preferred when this onnxruntime build offers it; otherwise the
        CPU provider is used.

        Raises:
            ImportError: onnxruntime is not installed.
            RuntimeError: the model could not be loaded.
        """
        try:
            import onnxruntime as ort
        except ImportError as err:
            raise ImportError(
                "onnxruntime is not installed. Please install it with `pip install onnxruntime`"
            ) from err
        # GPU requires `pip install onnxruntime-gpu`; only request providers this
        # build actually has, to avoid spurious warnings on CPU-only installs.
        available = ort.get_available_providers()
        providers = [
            p for p in ("CUDAExecutionProvider", "CPUExecutionProvider") if p in available
        ] or available
        try:
            session = ort.InferenceSession(model, providers=providers)
        except Exception as e:
            raise RuntimeError(f"Error loading ONNX model: {str(e)}") from e
        return session
@@ -0,0 +1,35 @@
1
+ from ..curation.stac import STACDataFrame
2
+
3
def download_model(model_name, dst_path, version, force=False, download=True):
    """Download a Q2+ model's STAC metadata and, optionally, its assets.

    Args:
        model_name: name of the model in EOTDL.
        dst_path: base directory; files go to ``{dst_path}/{model_name}/v{version}``.
        version: model version id; ``None`` means version 1.
        force: re-download even if the version directory already exists.
        download: if False, fetch the metadata and return it without writing
            anything to disk.

    Returns:
        ``(download_path, df)`` where ``df`` is a ``STACDataFrame``.

    Raises:
        Exception: the model quality is below 2 or the STAC download errors.
        AssertionError: the requested version does not exist.
    """
    # This module only imports STACDataFrame at top level; these names were
    # previously unresolved and raised NameError at call time.
    import os

    from ..auth import with_auth
    from ..models.retrieve import retrieve_model
    from ..repos import FilesAPIRepo, ModelsAPIRepo

    version = 1 if version is None else version
    download_path = f"{dst_path}/{model_name}/v{version}"
    # Reuse an existing download unless force=True.
    if os.path.exists(download_path) and not force:
        df = STACDataFrame.from_stac_file(f"{download_path}/{model_name}/catalog.json")
        return download_path, df

    @with_auth
    def _fetch(user=None):
        # `user` is supplied by with_auth, as for ModelWrapper.download.
        # retrieve_model returns the model record directly (same usage as
        # ModelWrapper.download), not a (model, error) tuple.
        model = retrieve_model(model_name)
        if model["quality"] < 2:
            raise Exception("Only Q2+ models are supported")
        if version not in [v["version_id"] for v in model["versions"]]:
            raise AssertionError(f"Version {version} not found")
        # NOTE(review): download_stac takes no version argument (same call as
        # ModelWrapper.download). The previous undefined helper was called as
        # retrieve_model_stac(model_id, version). Confirm the API returns the
        # requested version's metadata.
        gdf, error = ModelsAPIRepo().download_stac(model["id"], user)
        if error:
            raise Exception(error)
        df = STACDataFrame(gdf)
        if not download:
            return download_path, df
        os.makedirs(download_path, exist_ok=True)
        df.to_stac(download_path)
        df = df.dropna(subset=["assets"])
        files_repo = FilesAPIRepo()
        for _, row in df.iterrows():
            for asset in row["assets"].values():
                href = asset["href"]
                _, filename = href.split("/download/")
                files_repo.download_file_url(href, filename, f"{download_path}/assets", user)
        return download_path, df

    return _fetch()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: eotdl
3
- Version: 2024.10.1
3
+ Version: 2024.10.7
4
4
  Summary: Earth Observation Training Data Lab
5
5
  License: MIT
6
6
  Author: EarthPulse
@@ -1,4 +1,4 @@
1
- eotdl/__init__.py,sha256=GYhbp_qDI4r7Y9hYO76t0DwK_JCtUw2UK-DbHzV4XKk,27
1
+ eotdl/__init__.py,sha256=9dJlE6aniPM75bpVKz06wufJdjPqaOg_opOHL9R4swI,27
2
2
  eotdl/access/__init__.py,sha256=jbyjD7BRGJURlTNmtcbBBhw3Xk4EiZvkqmEykM-bJ1k,231
3
3
  eotdl/access/airbus/__init__.py,sha256=G_kkRS9eFjXbQ-aehmTLXeAxh7zpAxz_rgB7J_w0NRg,107
4
4
  eotdl/access/airbus/client.py,sha256=zjfgB_NTsCCIszoQesYkyLJgheKg-eTh28vbleXYxfw,12018
@@ -73,7 +73,10 @@ eotdl/tools/paths.py,sha256=yWhOtVxX4NxrDrrBX2fuye5N1mAqrxXFy_eA7dffd84,1152
73
73
  eotdl/tools/stac.py,sha256=ovXdrPm4Sn9AAJmrP88WnxDmq2Ut-xPoscjphxz3Iyo,5763
74
74
  eotdl/tools/time_utils.py,sha256=qJ3-rk1I7ne722SLfAP6-59kahQ0vLQqIf9VpOi0Kpg,4691
75
75
  eotdl/tools/tools.py,sha256=Tl4_v2ejkQo_zyZek8oofJwoYcdVosdOwW1C0lvWaNM,6354
76
- eotdl-2024.10.1.dist-info/METADATA,sha256=FSAul-2xBkcUyN-T0tk8YX0Kp_ZzlQ4AAk0e9IxYmKk,4144
77
- eotdl-2024.10.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
78
- eotdl-2024.10.1.dist-info/entry_points.txt,sha256=s6sfxUfRrSX2IP2UbrzTFTvRCtLgw3_OKcHlOKf_5F8,39
79
- eotdl-2024.10.1.dist-info/RECORD,,
76
+ eotdl/wrappers/__init__.py,sha256=IY3DK_5LMbc5bIQFleQA9kzFbPhWuTLesJ8dwfvpkdA,32
77
+ eotdl/wrappers/models.py,sha256=kNO4pYw9KKKmElE7bZWWHGs7FIThNUXj8XciKh_3rNw,6432
78
+ eotdl/wrappers/utils.py,sha256=BoG1Gtt1jXM_62BD6amrrw9jE25ki-zyY468LVgK0gM,1379
79
+ eotdl-2024.10.7.dist-info/METADATA,sha256=17YF30YXxtAE8jCX9-zKEyyrDPTKxjkYFbCj0qf4Gk0,4144
80
+ eotdl-2024.10.7.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
81
+ eotdl-2024.10.7.dist-info/entry_points.txt,sha256=s6sfxUfRrSX2IP2UbrzTFTvRCtLgw3_OKcHlOKf_5F8,39
82
+ eotdl-2024.10.7.dist-info/RECORD,,