rslearn 0.0.14__py3-none-any.whl → 0.0.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. rslearn/config/__init__.py +2 -10
  2. rslearn/config/dataset.py +414 -420
  3. rslearn/data_sources/__init__.py +8 -31
  4. rslearn/data_sources/aws_landsat.py +13 -24
  5. rslearn/data_sources/aws_open_data.py +21 -46
  6. rslearn/data_sources/aws_sentinel1.py +3 -14
  7. rslearn/data_sources/climate_data_store.py +21 -40
  8. rslearn/data_sources/copernicus.py +30 -91
  9. rslearn/data_sources/data_source.py +26 -0
  10. rslearn/data_sources/earthdaily.py +13 -38
  11. rslearn/data_sources/earthdata_srtm.py +14 -32
  12. rslearn/data_sources/eurocrops.py +5 -9
  13. rslearn/data_sources/gcp_public_data.py +46 -43
  14. rslearn/data_sources/google_earth_engine.py +31 -44
  15. rslearn/data_sources/local_files.py +91 -100
  16. rslearn/data_sources/openstreetmap.py +21 -51
  17. rslearn/data_sources/planet.py +12 -30
  18. rslearn/data_sources/planet_basemap.py +4 -25
  19. rslearn/data_sources/planetary_computer.py +58 -141
  20. rslearn/data_sources/usda_cdl.py +15 -26
  21. rslearn/data_sources/usgs_landsat.py +4 -29
  22. rslearn/data_sources/utils.py +9 -0
  23. rslearn/data_sources/worldcereal.py +47 -54
  24. rslearn/data_sources/worldcover.py +16 -14
  25. rslearn/data_sources/worldpop.py +15 -18
  26. rslearn/data_sources/xyz_tiles.py +11 -30
  27. rslearn/dataset/dataset.py +6 -6
  28. rslearn/dataset/manage.py +28 -26
  29. rslearn/dataset/materialize.py +9 -45
  30. rslearn/lightning_cli.py +370 -1
  31. rslearn/main.py +3 -3
  32. rslearn/models/clay/clay.py +14 -1
  33. rslearn/models/concatenate_features.py +93 -0
  34. rslearn/models/croma.py +26 -3
  35. rslearn/models/satlaspretrain.py +18 -4
  36. rslearn/models/terramind.py +19 -0
  37. rslearn/tile_stores/__init__.py +0 -11
  38. rslearn/train/dataset.py +4 -12
  39. rslearn/train/prediction_writer.py +16 -32
  40. rslearn/train/tasks/classification.py +2 -1
  41. rslearn/utils/fsspec.py +20 -0
  42. rslearn/utils/jsonargparse.py +79 -0
  43. rslearn/utils/raster_format.py +1 -41
  44. rslearn/utils/vector_format.py +1 -38
  45. {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/METADATA +1 -1
  46. {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/RECORD +51 -52
  47. rslearn/data_sources/geotiff.py +0 -1
  48. rslearn/data_sources/raster_source.py +0 -23
  49. {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/WHEEL +0 -0
  50. {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/entry_points.txt +0 -0
  51. {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/licenses/LICENSE +0 -0
  52. {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/licenses/NOTICE +0 -0
  53. {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/top_level.txt +0 -0
@@ -17,9 +17,9 @@ from earthdaily import EDSClient, EDSConfig
17
17
  from rasterio.enums import Resampling
18
18
  from upath import UPath
19
19
 
20
- from rslearn.config import LayerConfig, QueryConfig, RasterLayerConfig
20
+ from rslearn.config import LayerConfig, QueryConfig
21
21
  from rslearn.const import WGS84_PROJECTION
22
- from rslearn.data_sources import DataSource, Item
22
+ from rslearn.data_sources import DataSource, DataSourceContext, Item
23
23
  from rslearn.data_sources.utils import match_candidate_items_to_window
24
24
  from rslearn.dataset import Window
25
25
  from rslearn.dataset.materialize import RasterMaterializer
@@ -81,10 +81,11 @@ class EarthDaily(DataSource, TileStore):
81
81
  sort_ascending: bool = True,
82
82
  timeout: timedelta = timedelta(seconds=10),
83
83
  skip_items_missing_assets: bool = False,
84
- cache_dir: UPath | None = None,
84
+ cache_dir: str | None = None,
85
85
  max_retries: int = 3,
86
86
  retry_backoff_factor: float = 5.0,
87
87
  service_name: Literal["platform"] = "platform",
88
+ context: DataSourceContext = DataSourceContext(),
88
89
  ):
89
90
  """Initialize a new EarthDaily instance.
90
91
 
@@ -108,6 +109,7 @@ class EarthDaily(DataSource, TileStore):
108
109
  `(retry_backoff_factor * (2 ** (retry_count - 1)))` seconds.
109
110
  service_name: the service name, only "platform" is supported, the other
110
111
  services "legacy" and "internal" are not supported.
112
+ context: the data source context.
111
113
  """
112
114
  self.collection_name = collection_name
113
115
  self.asset_bands = asset_bands
@@ -116,51 +118,25 @@ class EarthDaily(DataSource, TileStore):
116
118
  self.sort_ascending = sort_ascending
117
119
  self.timeout = timeout
118
120
  self.skip_items_missing_assets = skip_items_missing_assets
119
- self.cache_dir = cache_dir
120
121
  self.max_retries = max_retries
121
122
  self.retry_backoff_factor = retry_backoff_factor
122
123
  self.service_name = service_name
123
124
 
124
125
  if cache_dir is not None:
125
- self.cache_dir = cache_dir
126
+ # Use dataset path as root if provided.
127
+ if context.ds_path is not None:
128
+ self.cache_dir = join_upath(context.ds_path, cache_dir)
129
+ else:
130
+ self.cache_dir = UPath(cache_dir)
131
+
126
132
  self.cache_dir.mkdir(parents=True, exist_ok=True)
133
+ else:
134
+ self.cache_dir = None
127
135
 
128
136
  self.eds_client: EDSClient | None = None
129
137
  self.client: pystac_client.Client | None = None
130
138
  self.collection: pystac_client.CollectionClient | None = None
131
139
 
132
- @staticmethod
133
- def from_config(config: RasterLayerConfig, ds_path: UPath) -> "EarthDaily":
134
- """Creates a new EarthDaily instance from a configuration dictionary."""
135
- if config.data_source is None:
136
- raise ValueError("config.data_source is required")
137
- d = config.data_source.config_dict
138
-
139
- kwargs: dict[str, Any] = dict(
140
- collection_name=d["collection_name"],
141
- service_name=d["service_name"],
142
- asset_bands=d["asset_bands"],
143
- )
144
-
145
- if "timeout_seconds" in d:
146
- kwargs["timeout"] = timedelta(seconds=d["timeout_seconds"])
147
-
148
- if "cache_dir" in d:
149
- kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])
150
-
151
- if "max_retries" in d:
152
- kwargs["max_retries"] = d["max_retries"]
153
-
154
- if "retry_backoff_factor" in d:
155
- kwargs["retry_backoff_factor"] = d["retry_backoff_factor"]
156
-
157
- simple_optionals = ["query", "sort_by", "sort_ascending"]
158
- for k in simple_optionals:
159
- if k in d:
160
- kwargs[k] = d[k]
161
-
162
- return EarthDaily(**kwargs)
163
-
164
140
  def _load_client(
165
141
  self,
166
142
  ) -> tuple[EDSClient, pystac_client.Client, pystac_client.CollectionClient]:
@@ -499,7 +475,6 @@ class EarthDaily(DataSource, TileStore):
499
475
  layer_name: the name of this layer
500
476
  layer_cfg: the config of this layer
501
477
  """
502
- assert isinstance(layer_cfg, RasterLayerConfig)
503
478
  RasterMaterializer().materialize(
504
479
  TileStoreWithLayer(self, layer_name),
505
480
  window,
@@ -12,9 +12,9 @@ import requests.auth
12
12
  import shapely
13
13
  from upath import UPath
14
14
 
15
- from rslearn.config import QueryConfig, RasterLayerConfig, SpaceMode
15
+ from rslearn.config import QueryConfig, SpaceMode
16
16
  from rslearn.const import WGS84_PROJECTION
17
- from rslearn.data_sources import DataSource, Item
17
+ from rslearn.data_sources import DataSource, DataSourceContext, Item
18
18
  from rslearn.log_utils import get_logger
19
19
  from rslearn.tile_stores import TileStoreWithLayer
20
20
  from rslearn.utils.geometry import STGeometry
@@ -47,8 +47,8 @@ class SRTM(DataSource):
47
47
  self,
48
48
  username: str | None = None,
49
49
  password: str | None = None,
50
- band_name: str = "srtm",
51
50
  timeout: timedelta = timedelta(seconds=10),
51
+ context: DataSourceContext = DataSourceContext(),
52
52
  ):
53
53
  """Initialize a new SRTM instance.
54
54
 
@@ -57,10 +57,19 @@ class SRTM(DataSource):
57
57
  NASA_EARTHDATA_USERNAME environment variable.
58
58
  password: NASA Earthdata account password. If not set, it is read from the
59
59
  NASA_EARTHDATA_PASSWORD environment variable.
60
- band_name: what to call the band.
61
60
  timeout: timeout for requests.
61
+ context: the data source context.
62
62
  """
63
- self.band_name = band_name
63
+ # Get band name from context if possible, falling back to "srtm".
64
+ if context.layer_config is not None:
65
+ if len(context.layer_config.band_sets) != 1:
66
+ raise ValueError("expected a single band set")
67
+ if len(context.layer_config.band_sets[0].bands) != 1:
68
+ raise ValueError("expected band set to have a single band")
69
+ self.band_name = context.layer_config.band_sets[0].bands[0]
70
+ else:
71
+ self.band_name = "srtm"
72
+
64
73
  self.timeout = timeout
65
74
 
66
75
  if username is None:
@@ -73,33 +82,6 @@ class SRTM(DataSource):
73
82
 
74
83
  self.session = requests.session()
75
84
 
76
- @staticmethod
77
- def from_config(config: RasterLayerConfig, ds_path: UPath) -> "SRTM":
78
- """Creates a new SRTM instance from a configuration dictionary."""
79
- if config.data_source is None:
80
- raise ValueError("config.data_source is required")
81
- d = config.data_source.config_dict
82
-
83
- # Get the band name chosen by the user.
84
- # There should be a single band set with a single band.
85
- if len(config.band_sets) != 1:
86
- raise ValueError("expected a single band set")
87
- if len(config.band_sets[0].bands) != 1:
88
- raise ValueError("expected band set to have a single band")
89
- kwargs: dict[str, Any] = {
90
- "band_name": config.band_sets[0].bands[0],
91
- }
92
-
93
- if "timeout_seconds" in d:
94
- kwargs["timeout"] = timedelta(seconds=d["timeout_seconds"])
95
-
96
- simple_optionals = ["username", "password"]
97
- for k in simple_optionals:
98
- if k in d:
99
- kwargs[k] = d[k]
100
-
101
- return SRTM(**kwargs)
102
-
103
85
  def get_item_by_name(self, name: str) -> Item:
104
86
  """Gets an item by name.
105
87
 
@@ -10,11 +10,10 @@ from typing import Any
10
10
  import fiona
11
11
  import requests
12
12
  from rasterio.crs import CRS
13
- from upath import UPath
14
13
 
15
- from rslearn.config import QueryConfig, VectorLayerConfig
14
+ from rslearn.config import QueryConfig
16
15
  from rslearn.const import WGS84_PROJECTION
17
- from rslearn.data_sources import DataSource, Item
16
+ from rslearn.data_sources import DataSource, DataSourceContext, Item
18
17
  from rslearn.data_sources.utils import match_candidate_items_to_window
19
18
  from rslearn.log_utils import get_logger
20
19
  from rslearn.tile_stores import TileStoreWithLayer
@@ -108,12 +107,9 @@ class EuroCrops(DataSource[EuroCropsItem]):
108
107
  }
109
108
  TIMEOUT = timedelta(seconds=10)
110
109
 
111
- @staticmethod
112
- def from_config(config: VectorLayerConfig, ds_path: UPath) -> "EuroCrops":
113
- """Creates a new EuroCrops instance from a configuration dictionary."""
114
- if config.data_source is None:
115
- raise ValueError("data_source is required")
116
- return EuroCrops()
110
+ def __init__(self, context: DataSourceContext = DataSourceContext()):
111
+ """Create a new EuroCrops."""
112
+ pass
117
113
 
118
114
  def _get_all_items(self) -> list[EuroCropsItem]:
119
115
  """Get a list of all available items in the data source."""
@@ -18,10 +18,9 @@ import tqdm
18
18
  from google.cloud import bigquery, storage
19
19
  from upath import UPath
20
20
 
21
- from rslearn.config import QueryConfig, RasterLayerConfig
21
+ from rslearn.config import QueryConfig
22
22
  from rslearn.const import WGS84_PROJECTION
23
- from rslearn.data_sources import DataSource, Item
24
- from rslearn.data_sources.raster_source import is_raster_needed
23
+ from rslearn.data_sources import DataSource, DataSourceContext, Item
25
24
  from rslearn.data_sources.utils import match_candidate_items_to_window
26
25
  from rslearn.log_utils import get_logger
27
26
  from rslearn.tile_stores import TileStoreWithLayer
@@ -157,19 +156,19 @@ class Sentinel2(DataSource):
157
156
 
158
157
  def __init__(
159
158
  self,
160
- config: RasterLayerConfig,
161
- index_cache_dir: UPath,
159
+ index_cache_dir: str,
162
160
  sort_by: str | None = None,
163
161
  use_rtree_index: bool = True,
164
162
  harmonize: bool = False,
165
163
  rtree_time_range: tuple[datetime, datetime] | None = None,
166
- rtree_cache_dir: UPath | None = None,
164
+ rtree_cache_dir: str | None = None,
167
165
  use_bigquery: bool | None = None,
166
+ bands: list[str] | None = None,
167
+ context: DataSourceContext = DataSourceContext(),
168
168
  ):
169
169
  """Initialize a new Sentinel2 instance.
170
170
 
171
171
  Args:
172
- config: the LayerConfig of the layer containing this data source.
173
172
  index_cache_dir: local directory to cache the index contents, as well as
174
173
  individual product metadata files.
175
174
  sort_by: can be "cloud_cover", default arbitrary order; only has effect for
@@ -193,6 +192,9 @@ class Sentinel2(DataSource):
193
192
  credentials, set use_bigquery=False and use_rtree_index=False. The
194
193
  default value is None which enables BigQuery when use_rtree_index=True
195
194
  and disables when use_rtree_index=False.
195
+ bands: the bands to download, or None to download all bands. This is only
196
+ used if the layer config is not in the context.
197
+ context: the data source context.
196
198
  """
197
199
  if use_bigquery is None:
198
200
  use_bigquery = use_rtree_index
@@ -201,22 +203,52 @@ class Sentinel2(DataSource):
201
203
  "use_bigquery must be enabled if use_rtree_index is enabled"
202
204
  )
203
205
 
204
- self.config = config
205
- self.index_cache_dir = index_cache_dir
206
+ # Resolve index_cache_dir and rtree_cache_dir depending on dataset context.
207
+ if context.ds_path is not None:
208
+ self.index_cache_dir = join_upath(context.ds_path, index_cache_dir)
209
+ else:
210
+ self.index_cache_dir = UPath(index_cache_dir)
211
+
212
+ if rtree_cache_dir is None:
213
+ self.rtree_cache_dir = self.index_cache_dir
214
+ elif context.ds_path is not None:
215
+ self.rtree_cache_dir = join_upath(context.ds_path, rtree_cache_dir)
216
+ else:
217
+ self.rtree_cache_dir = UPath(rtree_cache_dir)
218
+
206
219
  self.sort_by = sort_by
207
220
  self.harmonize = harmonize
208
221
  self.use_bigquery = use_bigquery
209
222
 
210
223
  self.index_cache_dir.mkdir(parents=True, exist_ok=True)
211
224
 
225
+ # Determine the subset of bands that are needed based on the layer config.
226
+ self.needed_bands: list[tuple[str, list[str]]]
227
+ if context.layer_config is not None:
228
+ self.needed_bands = []
229
+ for fname, cur_bands in self.BANDS:
230
+ # See if the bands provided by this file intersect with the bands in at
231
+ # least one configured band set.
232
+ for band_set in context.layer_config.band_sets:
233
+ if not set(band_set.bands).intersection(cur_bands):
234
+ continue
235
+ self.needed_bands.append((fname, cur_bands))
236
+ break
237
+ elif bands is not None:
238
+ self.needed_bands = []
239
+ for fname, cur_bands in self.BANDS:
240
+ if not set(bands).intersection(cur_bands):
241
+ continue
242
+ self.needed_bands.append((fname, cur_bands))
243
+ else:
244
+ self.needed_bands = list(self.BANDS)
245
+
212
246
  self.bucket = storage.Client.create_anonymous_client().bucket(self.BUCKET_NAME)
213
247
  self.rtree_index: Any | None = None
214
248
  if use_rtree_index:
215
249
  from rslearn.utils.rtree_index import RtreeIndex, get_cached_rtree
216
250
 
217
- if rtree_cache_dir is None:
218
- rtree_cache_dir = self.index_cache_dir
219
- rtree_cache_dir.mkdir(parents=True, exist_ok=True)
251
+ self.rtree_cache_dir.mkdir(parents=True, exist_ok=True)
220
252
 
221
253
  def build_fn(index: RtreeIndex) -> None:
222
254
  """Build the RtreeIndex from items in the data source."""
@@ -226,34 +258,7 @@ class Sentinel2(DataSource):
226
258
  for shp in flatten_shape(item.geometry.shp):
227
259
  index.insert(shp.bounds, json.dumps(item.serialize()))
228
260
 
229
- self.rtree_index = get_cached_rtree(rtree_cache_dir, build_fn)
230
-
231
- @staticmethod
232
- def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Sentinel2":
233
- """Creates a new Sentinel2 instance from a configuration dictionary."""
234
- if config.data_source is None:
235
- raise ValueError("config.data_source is required")
236
- d = config.data_source.config_dict
237
- kwargs = dict(
238
- config=config,
239
- index_cache_dir=join_upath(ds_path, d["index_cache_dir"]),
240
- )
241
-
242
- if "rtree_time_range" in d:
243
- kwargs["rtree_time_range"] = (
244
- datetime.fromisoformat(d["rtree_time_range"][0]),
245
- datetime.fromisoformat(d["rtree_time_range"][1]),
246
- )
247
-
248
- if "rtree_cache_dir" in d:
249
- kwargs["rtree_cache_dir"] = join_upath(ds_path, d["rtree_cache_dir"])
250
-
251
- simple_optionals = ["sort_by", "use_rtree_index", "harmonize", "use_bigquery"]
252
- for k in simple_optionals:
253
- if k in d:
254
- kwargs[k] = d[k]
255
-
256
- return Sentinel2(**kwargs)
261
+ self.rtree_index = get_cached_rtree(self.rtree_cache_dir, build_fn)
257
262
 
258
263
  def _read_bigquery(
259
264
  self,
@@ -833,9 +838,7 @@ class Sentinel2(DataSource):
833
838
  geometries: a list of geometries needed for each item
834
839
  """
835
840
  for item in items:
836
- for suffix, band_names in self.BANDS:
837
- if not is_raster_needed(band_names, self.config.band_sets):
838
- continue
841
+ for suffix, band_names in self.needed_bands:
839
842
  if tile_store.is_raster_ready(item.name, band_names):
840
843
  continue
841
844
 
@@ -20,7 +20,7 @@ from google.cloud import storage
20
20
  from upath import UPath
21
21
 
22
22
  import rslearn.data_sources.utils
23
- from rslearn.config import DType, LayerConfig, RasterLayerConfig
23
+ from rslearn.config import DType, LayerConfig
24
24
  from rslearn.const import WGS84_PROJECTION
25
25
  from rslearn.dataset.materialize import RasterMaterializer
26
26
  from rslearn.dataset.window import Window
@@ -36,7 +36,7 @@ from rslearn.utils.raster_format import (
36
36
  )
37
37
  from rslearn.utils.rtree_index import RtreeIndex, get_cached_rtree
38
38
 
39
- from .data_source import DataSource, Item, QueryConfig
39
+ from .data_source import DataSource, DataSourceContext, Item, QueryConfig
40
40
 
41
41
  logger = get_logger(__name__)
42
42
 
@@ -61,34 +61,55 @@ class GEE(DataSource, TileStore):
61
61
  self,
62
62
  collection_name: str,
63
63
  gcs_bucket_name: str,
64
- bands: list[str],
65
- index_cache_dir: UPath,
64
+ index_cache_dir: str,
66
65
  service_account_name: str,
67
66
  service_account_credentials: str,
67
+ bands: list[str] | None = None,
68
68
  filters: list[tuple[str, Any]] | None = None,
69
69
  dtype: DType | None = None,
70
+ context: DataSourceContext = DataSourceContext(),
70
71
  ) -> None:
71
72
  """Initialize a new GEE instance.
72
73
 
73
74
  Args:
74
75
  collection_name: the Earth Engine ImageCollection to ingest images from
75
76
  gcs_bucket_name: the Cloud Storage bucket to export GEE images to
76
- bands: the list of bands to ingest
77
77
  index_cache_dir: cache directory to store rtree index
78
78
  service_account_name: name of the service account to use for authentication
79
79
  service_account_credentials: service account credentials filename
80
+ bands: the list of bands to ingest, in case the layer config is not present
81
+ in the context.
80
82
  filters: optional list of tuples (property_name, property_value) to filter
81
83
  images (using ee.Filter.eq)
82
84
  dtype: optional desired array data type. If the data obtained from GEE does
83
85
  not match this type, then it is converted.
86
+ context: the data source context.
84
87
  """
85
88
  self.collection_name = collection_name
86
89
  self.gcs_bucket_name = gcs_bucket_name
87
- self.bands = bands
88
- self.index_cache_dir = index_cache_dir
89
90
  self.filters = filters
90
91
  self.dtype = dtype
91
92
 
93
+ # Get index cache dir depending on dataset path.
94
+ if context.ds_path is not None:
95
+ self.index_cache_dir = join_upath(context.ds_path, index_cache_dir)
96
+ else:
97
+ self.index_cache_dir = UPath(index_cache_dir)
98
+
99
+ # Get bands we need to export.
100
+ if context.layer_config is not None:
101
+ self.bands = [
102
+ band
103
+ for band_set in context.layer_config.band_sets
104
+ for band in band_set.bands
105
+ ]
106
+ elif bands is not None:
107
+ self.bands = bands
108
+ else:
109
+ raise ValueError(
110
+ "bands must be specified if layer_config is not present in the context"
111
+ )
112
+
92
113
  self.bucket = storage.Client().bucket(self.gcs_bucket_name)
93
114
 
94
115
  credentials = ee.ServiceAccountCredentials(
@@ -99,27 +120,6 @@ class GEE(DataSource, TileStore):
99
120
  self.index_cache_dir.mkdir(parents=True, exist_ok=True)
100
121
  self.rtree_index = get_cached_rtree(self.index_cache_dir, self._build_index)
101
122
 
102
- @staticmethod
103
- def from_config(config: RasterLayerConfig, ds_path: UPath) -> "GEE":
104
- """Creates a new GEE instance from a configuration dictionary."""
105
- if config.data_source is None:
106
- raise ValueError("data_source is required in config")
107
- d = config.data_source.config_dict
108
- bands = [band for band_set in config.band_sets for band in band_set.bands]
109
- kwargs = {
110
- "collection_name": d["collection_name"],
111
- "gcs_bucket_name": d["gcs_bucket_name"],
112
- "bands": bands,
113
- "service_account_name": d["service_account_name"],
114
- "service_account_credentials": d["service_account_credentials"],
115
- "filters": d.get("filters"),
116
- "index_cache_dir": join_upath(ds_path, d["index_cache_dir"]),
117
- }
118
- if "dtype" in d:
119
- kwargs["dtype"] = DType(d["dtype"])
120
-
121
- return GEE(**kwargs)
122
-
123
123
  def get_collection(self) -> ee.ImageCollection:
124
124
  """Returns the Earth Engine image collection for this data source."""
125
125
  image_collection = ee.ImageCollection(self.collection_name)
@@ -578,7 +578,6 @@ class GEE(DataSource, TileStore):
578
578
  layer_name: the name of this layer
579
579
  layer_cfg: the config of this layer
580
580
  """
581
- assert isinstance(layer_cfg, RasterLayerConfig)
582
581
  RasterMaterializer().materialize(
583
582
  TileStoreWithLayer(self, layer_name),
584
583
  window,
@@ -600,9 +599,10 @@ class GoogleSatelliteEmbeddings(GEE):
600
599
  def __init__(
601
600
  self,
602
601
  gcs_bucket_name: str,
603
- index_cache_dir: UPath,
602
+ index_cache_dir: str,
604
603
  service_account_name: str,
605
604
  service_account_credentials: str,
605
+ context: DataSourceContext = DataSourceContext(),
606
606
  ):
607
607
  """Create a new GoogleSatelliteEmbeddings. See GEE for the arguments."""
608
608
  super().__init__(
@@ -612,22 +612,9 @@ class GoogleSatelliteEmbeddings(GEE):
612
612
  index_cache_dir=index_cache_dir,
613
613
  service_account_name=service_account_name,
614
614
  service_account_credentials=service_account_credentials,
615
+ context=context,
615
616
  )
616
617
 
617
- @staticmethod
618
- def from_config(config: RasterLayerConfig, ds_path: UPath) -> "GEE":
619
- """Creates a new GEE instance from a configuration dictionary."""
620
- if config.data_source is None:
621
- raise ValueError("data_source is required in config")
622
- d = config.data_source.config_dict
623
- kwargs = {
624
- "gcs_bucket_name": d["gcs_bucket_name"],
625
- "index_cache_dir": join_upath(ds_path, d["index_cache_dir"]),
626
- "service_account_name": d["service_account_name"],
627
- "service_account_credentials": d["service_account_credentials"],
628
- }
629
- return GoogleSatelliteEmbeddings(**kwargs)
630
-
631
618
  # Override to add conversion to uint16.
632
619
  def item_to_image(self, item: Item) -> ee.image.Image:
633
620
  """Get the Image corresponding to the Item."""