rs-embed 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rs_embed/__init__.py +72 -0
- rs_embed/__main__.py +8 -0
- rs_embed/_version.py +1 -0
- rs_embed/api.py +500 -0
- rs_embed/cli.py +308 -0
- rs_embed/core/__init__.py +17 -0
- rs_embed/core/embedding.py +26 -0
- rs_embed/core/errors.py +22 -0
- rs_embed/core/registry.py +152 -0
- rs_embed/core/specs.py +518 -0
- rs_embed/core/types.py +362 -0
- rs_embed/core/validation.py +123 -0
- rs_embed/embedders/__init__.py +39 -0
- rs_embed/embedders/_vendor/LICENSE.agrifm +201 -0
- rs_embed/embedders/_vendor/LICENSE.dofa +21 -0
- rs_embed/embedders/_vendor/LICENSE.fomo +24 -0
- rs_embed/embedders/_vendor/LICENSE.galileo +24 -0
- rs_embed/embedders/_vendor/LICENSE.prithvi +12 -0
- rs_embed/embedders/_vendor/LICENSE.satmaepp +201 -0
- rs_embed/embedders/_vendor/LICENSE.terrafm +201 -0
- rs_embed/embedders/_vendor/LICENSE.thor +23 -0
- rs_embed/embedders/_vendor/LICENSE.thor_terratorch_ext +201 -0
- rs_embed/embedders/_vendor/LICENSE.torchgeo +21 -0
- rs_embed/embedders/_vendor/__init__.py +1 -0
- rs_embed/embedders/_vendor/agrifm_video_swin_transformer.py +1058 -0
- rs_embed/embedders/_vendor/anysat/LICENSE +21 -0
- rs_embed/embedders/_vendor/anysat/__init__.py +0 -0
- rs_embed/embedders/_vendor/anysat/hubconf.py +315 -0
- rs_embed/embedders/_vendor/anysat/src/__init__.py +0 -0
- rs_embed/embedders/_vendor/anysat/src/models/__init__.py +0 -0
- rs_embed/embedders/_vendor/anysat/src/models/networks/__init__.py +0 -0
- rs_embed/embedders/_vendor/anysat/src/models/networks/encoder/Any_multi.py +277 -0
- rs_embed/embedders/_vendor/anysat/src/models/networks/encoder/Transformer.py +427 -0
- rs_embed/embedders/_vendor/anysat/src/models/networks/encoder/__init__.py +0 -0
- rs_embed/embedders/_vendor/anysat/src/models/networks/encoder/utils/__init__.py +0 -0
- rs_embed/embedders/_vendor/anysat/src/models/networks/encoder/utils/irpe.py +961 -0
- rs_embed/embedders/_vendor/anysat/src/models/networks/encoder/utils/ltae.py +337 -0
- rs_embed/embedders/_vendor/anysat/src/models/networks/encoder/utils/patch_embeddings.py +81 -0
- rs_embed/embedders/_vendor/anysat/src/models/networks/encoder/utils/pos_embed.py +122 -0
- rs_embed/embedders/_vendor/anysat/src/models/networks/encoder/utils/utils.py +193 -0
- rs_embed/embedders/_vendor/anysat/src/models/networks/encoder/utils/utils_ViT.py +551 -0
- rs_embed/embedders/_vendor/copernicus_embed.py +339 -0
- rs_embed/embedders/_vendor/dofa_vit.py +132 -0
- rs_embed/embedders/_vendor/dofa_wave_dynamic_layer.py +165 -0
- rs_embed/embedders/_vendor/fomo_multimodal_mae.py +164 -0
- rs_embed/embedders/_vendor/galileo_single_file.py +1185 -0
- rs_embed/embedders/_vendor/prithvi_mae.py +766 -0
- rs_embed/embedders/_vendor/satmaepp_s2/__init__.py +1 -0
- rs_embed/embedders/_vendor/satmaepp_s2/models_mae_group_channels.py +465 -0
- rs_embed/embedders/_vendor/satmaepp_s2/util/__init__.py +1 -0
- rs_embed/embedders/_vendor/satmaepp_s2/util/pos_embed.py +117 -0
- rs_embed/embedders/_vendor/terrafm.py +377 -0
- rs_embed/embedders/_vendor/thor/__init__.py +2 -0
- rs_embed/embedders/_vendor/thor/core/__init__.py +2 -0
- rs_embed/embedders/_vendor/thor/core/model_registry.py +165 -0
- rs_embed/embedders/_vendor/thor/models/__init__.py +5 -0
- rs_embed/embedders/_vendor/thor/models/patch_timm.py +110 -0
- rs_embed/embedders/_vendor/thor/models/thor_vit.py +1103 -0
- rs_embed/embedders/_vendor/thor/utils/__init__.py +2 -0
- rs_embed/embedders/_vendor/thor/utils/helper.py +23 -0
- rs_embed/embedders/_vendor/thor/utils/patch_embed.py +1311 -0
- rs_embed/embedders/_vendor/thor/utils/pos_embed.py +162 -0
- rs_embed/embedders/_vendor/thor_vit.py +556 -0
- rs_embed/embedders/_vit_mae_utils.py +241 -0
- rs_embed/embedders/base.py +328 -0
- rs_embed/embedders/catalog.py +52 -0
- rs_embed/embedders/config_utils.py +31 -0
- rs_embed/embedders/meta_utils.py +137 -0
- rs_embed/embedders/onthefly_agrifm.py +852 -0
- rs_embed/embedders/onthefly_anysat.py +702 -0
- rs_embed/embedders/onthefly_dofa.py +1157 -0
- rs_embed/embedders/onthefly_fomo.py +753 -0
- rs_embed/embedders/onthefly_galileo.py +778 -0
- rs_embed/embedders/onthefly_prithvi.py +1015 -0
- rs_embed/embedders/onthefly_remoteclip.py +1028 -0
- rs_embed/embedders/onthefly_satmae.py +524 -0
- rs_embed/embedders/onthefly_satmaepp.py +766 -0
- rs_embed/embedders/onthefly_satmaepp_s2.py +1108 -0
- rs_embed/embedders/onthefly_satvision_toa.py +1306 -0
- rs_embed/embedders/onthefly_scalemae.py +798 -0
- rs_embed/embedders/onthefly_terrafm.py +1043 -0
- rs_embed/embedders/onthefly_terramind.py +998 -0
- rs_embed/embedders/onthefly_thor.py +871 -0
- rs_embed/embedders/onthefly_wildsat.py +958 -0
- rs_embed/embedders/precomputed_copernicus_embed.py +294 -0
- rs_embed/embedders/precomputed_gse_annual.py +170 -0
- rs_embed/embedders/precomputed_tessera.py +451 -0
- rs_embed/embedders/runtime_utils.py +615 -0
- rs_embed/export.py +56 -0
- rs_embed/inspect.py +104 -0
- rs_embed/model.py +246 -0
- rs_embed/pipelines/__init__.py +33 -0
- rs_embed/pipelines/checkpoint.py +285 -0
- rs_embed/pipelines/combined_flow.py +329 -0
- rs_embed/pipelines/exporter.py +781 -0
- rs_embed/pipelines/inference.py +674 -0
- rs_embed/pipelines/point_payload.py +320 -0
- rs_embed/pipelines/prefetch.py +288 -0
- rs_embed/pipelines/runner.py +90 -0
- rs_embed/providers/__init__.py +93 -0
- rs_embed/providers/base.py +334 -0
- rs_embed/providers/gee.py +620 -0
- rs_embed/providers/gee_utils.py +392 -0
- rs_embed/providers/prefetch_plan.py +157 -0
- rs_embed/tools/__init__.py +21 -0
- rs_embed/tools/checkpoint_utils.py +135 -0
- rs_embed/tools/export_requests.py +166 -0
- rs_embed/tools/inspection.py +343 -0
- rs_embed/tools/manifest.py +146 -0
- rs_embed/tools/model_defaults.py +307 -0
- rs_embed/tools/normalization.py +112 -0
- rs_embed/tools/output.py +107 -0
- rs_embed/tools/progress.py +56 -0
- rs_embed/tools/runtime.py +475 -0
- rs_embed/tools/serialization.py +101 -0
- rs_embed/tools/temporal.py +65 -0
- rs_embed/tools/tiling.py +683 -0
- rs_embed/writers.py +235 -0
- rs_embed-0.1.1.dist-info/METADATA +241 -0
- rs_embed-0.1.1.dist-info/RECORD +123 -0
- rs_embed-0.1.1.dist-info/WHEEL +4 -0
- rs_embed-0.1.1.dist-info/entry_points.txt +2 -0
- rs_embed-0.1.1.dist-info/licenses/LICENSE +201 -0
rs_embed/__init__.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
"""rs-embed — location embeddings from remote-sensing imagery.
|
|
2
|
+
|
|
3
|
+
This top-level package re-exports the public API surface so that users can
|
|
4
|
+
write ``from rs_embed import Model, BBox`` without reaching into subpackages.
|
|
5
|
+
|
|
6
|
+
See :mod:`rs_embed.api` for the function-based API and :class:`rs_embed.Model`
|
|
7
|
+
for the class-based (stateful) interface.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from ._version import __version__
|
|
11
|
+
from .api import (
|
|
12
|
+
describe_model,
|
|
13
|
+
export_batch,
|
|
14
|
+
get_embedding,
|
|
15
|
+
get_embeddings_batch,
|
|
16
|
+
list_models,
|
|
17
|
+
reset_runtime,
|
|
18
|
+
)
|
|
19
|
+
from .core.specs import (
|
|
20
|
+
BBox,
|
|
21
|
+
FetchSpec,
|
|
22
|
+
InputPrepSpec,
|
|
23
|
+
OutputSpec,
|
|
24
|
+
PointBuffer,
|
|
25
|
+
SensorSpec,
|
|
26
|
+
TemporalSpec,
|
|
27
|
+
)
|
|
28
|
+
from .core.types import (
|
|
29
|
+
ExportConfig,
|
|
30
|
+
ExportLayout,
|
|
31
|
+
ExportModelRequest,
|
|
32
|
+
ExportTarget,
|
|
33
|
+
ModelConfig,
|
|
34
|
+
)
|
|
35
|
+
from .export import export_npz
|
|
36
|
+
from .inspect import inspect_gee_patch, inspect_provider_patch
|
|
37
|
+
from .model import Model
|
|
38
|
+
from .pipelines.exporter import BatchExporter
|
|
39
|
+
|
|
40
|
+
__all__ = [
|
|
41
|
+
# Specs
|
|
42
|
+
"BBox",
|
|
43
|
+
"PointBuffer",
|
|
44
|
+
"TemporalSpec",
|
|
45
|
+
"SensorSpec",
|
|
46
|
+
"FetchSpec",
|
|
47
|
+
"OutputSpec",
|
|
48
|
+
"InputPrepSpec",
|
|
49
|
+
# Types
|
|
50
|
+
"ExportConfig",
|
|
51
|
+
"ExportLayout",
|
|
52
|
+
"ExportModelRequest",
|
|
53
|
+
"ExportTarget",
|
|
54
|
+
"ModelConfig",
|
|
55
|
+
# Embedding API (class-based)
|
|
56
|
+
"Model",
|
|
57
|
+
"BatchExporter",
|
|
58
|
+
# Embedding API (function-based, backward compat)
|
|
59
|
+
"get_embedding",
|
|
60
|
+
"get_embeddings_batch",
|
|
61
|
+
"list_models",
|
|
62
|
+
"describe_model",
|
|
63
|
+
"reset_runtime",
|
|
64
|
+
# Export API
|
|
65
|
+
"export_batch",
|
|
66
|
+
"export_npz",
|
|
67
|
+
# Inspection
|
|
68
|
+
"inspect_provider_patch",
|
|
69
|
+
# Backward-compatible alias for inspect_provider_patch
|
|
70
|
+
"inspect_gee_patch",
|
|
71
|
+
"__version__",
|
|
72
|
+
]
|
rs_embed/__main__.py
ADDED
rs_embed/_version.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = "0.1.1"
|
rs_embed/api.py
ADDED
|
@@ -0,0 +1,500 @@
|
|
|
1
|
+
"""Public Facade and Entry Point.
|
|
2
|
+
|
|
3
|
+
This module is the single boundary between callers and the internal pipeline
|
|
4
|
+
stack. It should contain **no** heavy execution logic — only configuration
|
|
5
|
+
and delegation.
|
|
6
|
+
|
|
7
|
+
Responsibilities
|
|
8
|
+
----------------
|
|
9
|
+
1. **Validation** — ensure all input Specs (Spatial, Temporal, Output) are
|
|
10
|
+
valid before any processing begins.
|
|
11
|
+
2. **Normalisation** — convert user-friendly strings (e.g. ``"sentinel-2"``,
|
|
12
|
+
``"cuda"``) into strict internal objects.
|
|
13
|
+
3. **Context Resolution** — route each request to the correct backend
|
|
14
|
+
(Provider vs. Precomputed) and device.
|
|
15
|
+
4. **Delegation** — instantiate the appropriate Pipeline or Embedder and
|
|
16
|
+
hand off execution.
|
|
17
|
+
|
|
18
|
+
Flow summary
|
|
19
|
+
------------
|
|
20
|
+
1. Validate and normalise user inputs / specs.
|
|
21
|
+
2. Resolve request context (model / backend / device / sensor / input prep).
|
|
22
|
+
3. Execute single / batch embedding, or delegate batch export to
|
|
23
|
+
:class:`BatchExporter`.
|
|
24
|
+
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
from __future__ import annotations
|
|
28
|
+
|
|
29
|
+
from dataclasses import replace
|
|
30
|
+
from typing import Any
|
|
31
|
+
|
|
32
|
+
from .core.embedding import Embedding
|
|
33
|
+
from .core.errors import ModelError
|
|
34
|
+
from .core.registry import get_embedder_cls as _get_embedder_cls
|
|
35
|
+
from .core.specs import (
|
|
36
|
+
FetchSpec,
|
|
37
|
+
InputPrepSpec,
|
|
38
|
+
OutputSpec,
|
|
39
|
+
SensorSpec,
|
|
40
|
+
SpatialSpec,
|
|
41
|
+
TemporalSpec,
|
|
42
|
+
)
|
|
43
|
+
from .core.types import (
|
|
44
|
+
ExportConfig,
|
|
45
|
+
ExportModelRequest,
|
|
46
|
+
ExportTarget,
|
|
47
|
+
)
|
|
48
|
+
from .core.validation import (
|
|
49
|
+
assert_supported as _assert_supported,
|
|
50
|
+
)
|
|
51
|
+
from .core.validation import (
|
|
52
|
+
validate_spatial_list as _validate_spatial_list,
|
|
53
|
+
)
|
|
54
|
+
from .core.validation import (
|
|
55
|
+
validate_specs as _validate_specs,
|
|
56
|
+
)
|
|
57
|
+
from .embedders.catalog import MODEL_ALIASES, MODEL_SPECS
|
|
58
|
+
from .tools.export_requests import (
|
|
59
|
+
maybe_return_completed_combined_resume as _maybe_return_completed_combined_resume,
|
|
60
|
+
)
|
|
61
|
+
from .tools.export_requests import (
|
|
62
|
+
normalize_export_format as _normalize_export_format,
|
|
63
|
+
)
|
|
64
|
+
from .tools.export_requests import (
|
|
65
|
+
normalize_export_target as _normalize_export_target,
|
|
66
|
+
)
|
|
67
|
+
from .tools.export_requests import (
|
|
68
|
+
resolve_export_model_configs as _resolve_export_model_configs,
|
|
69
|
+
)
|
|
70
|
+
from .tools.model_defaults import (
|
|
71
|
+
resolve_sensor_for_model as _resolve_sensor_for_model,
|
|
72
|
+
)
|
|
73
|
+
from .tools.normalization import (
|
|
74
|
+
# Re-exported so `from rs_embed.api import ...` in tests/downstream still works.
|
|
75
|
+
_default_provider_backend_for_api, # noqa: F401
|
|
76
|
+
_probe_model_describe, # noqa: F401
|
|
77
|
+
_resolve_embedding_api_backend, # noqa: F401
|
|
78
|
+
)
|
|
79
|
+
from .tools.normalization import (
|
|
80
|
+
normalize_backend_name as _normalize_backend_name,
|
|
81
|
+
)
|
|
82
|
+
from .tools.normalization import (
|
|
83
|
+
normalize_device_name as _normalize_device_name,
|
|
84
|
+
)
|
|
85
|
+
from .tools.normalization import (
|
|
86
|
+
normalize_model_name as _normalize_model_name,
|
|
87
|
+
)
|
|
88
|
+
from .tools.progress import create_progress as _create_progress
|
|
89
|
+
from .tools.runtime import (
|
|
90
|
+
_prepare_embedding_request_context,
|
|
91
|
+
provider_factory_for_backend,
|
|
92
|
+
)
|
|
93
|
+
from .tools.runtime import (
|
|
94
|
+
reset_runtime as _reset_runtime_shared,
|
|
95
|
+
)
|
|
96
|
+
from .tools.runtime import (
|
|
97
|
+
run_embedding_request as _run_embedding_request_shared,
|
|
98
|
+
)
|
|
99
|
+
|
|
100
|
+
# -----------------------------------------------------------------------------
|
|
101
|
+
# Public: embeddings
|
|
102
|
+
# -----------------------------------------------------------------------------
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def list_models(*, include_aliases: bool = False) -> list[str]:
    """List the model names known to the static catalog.

    The result is built purely from the ``MODEL_SPECS`` / ``MODEL_ALIASES``
    catalog mappings, so it does not depend on which embedders have been
    imported or instantiated so far.

    Parameters
    ----------
    include_aliases : bool
        When ``True``, alias names from ``MODEL_ALIASES`` are returned
        together with the canonical ids from ``MODEL_SPECS``.

    Returns
    -------
    list[str]
        De-duplicated model names in sorted order.
    """
    names = list(MODEL_SPECS.keys())
    if include_aliases:
        names.extend(MODEL_ALIASES.keys())
    # A name can appear both as a canonical id and as an alias; collapse
    # duplicates before sorting.
    return sorted(set(names))
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def describe_model(model: str) -> dict[str, Any]:
    """Describe a model by asking its embedder class for static metadata.

    The model name is normalised, the matching embedder class is looked up
    in the registry, a default instance is created with no arguments, and
    its ``describe()`` result is returned unchanged.

    Parameters
    ----------
    model : str
        Canonical model id or a registered alias (e.g. ``"prithvi"``,
        ``"satmae"``, ``"galileo"``). :func:`list_models` enumerates the
        available ids.

    Returns
    -------
    dict[str, Any]
        Whatever the embedder's ``describe()`` method reports. The keys are
        model specific.

    Raises
    ------
    ModelError
        Expected when *model* is not a known id or alias; the error
        originates in the name-normalisation / registry helpers.

    Examples
    --------
    >>> from rs_embed import describe_model
    >>> info = describe_model("galileo")  # doctest: +SKIP
    """
    canonical_name = _normalize_model_name(model)
    embedder_cls = _get_embedder_cls(canonical_name)
    embedder = embedder_cls()
    return embedder.describe()
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def reset_runtime() -> dict[str, int]:
    """Reset the shared runtime state for the current Python process.

    Thin public wrapper around the ``reset_runtime`` helper in
    ``rs_embed.tools.runtime``. Typical use is in a notebook, after a model
    import failed or when fresh embedder instances are wanted without
    restarting the kernel.

    Returns
    -------
    dict[str, int]
        The summary produced by the shared helper: counts of the
        runtime/import caches that were cleared.
    """
    cleared_counts = _reset_runtime_shared()
    return cleared_counts
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def get_embedding(
    model: str,
    *,
    spatial: SpatialSpec,
    temporal: TemporalSpec | None = None,
    sensor: SensorSpec | None = None,
    fetch: FetchSpec | None = None,
    modality: str | None = None,
    output: OutputSpec = OutputSpec.pooled(),
    backend: str = "auto",
    device: str = "auto",
    input_prep: InputPrepSpec | str | None = "resize",
    **model_kwargs: Any,
) -> Embedding:
    """Compute the embedding for one spatial request.

    Steps, in order: validate the specs, resolve the effective sensor for
    the (normalised) model name, build the request context, then run the
    shared embedding routine on a one-element list and return its single
    result.

    Parameters
    ----------
    model : str
        Model identifier or alias.
    spatial : SpatialSpec
        Location/extent to embed.
    temporal : TemporalSpec or None
        Optional temporal filter.
    sensor : SensorSpec or None
        Optional sensor override.
    fetch : FetchSpec or None
        Lightweight fetch-policy override applied on top of the model's
        default sensor. Must not be given together with ``sensor``.
    modality : str or None
        Optional modality selector for models with several input branches.
    output : OutputSpec
        Output representation policy (pooled by default).
    backend : str
        Backend/provider selector, e.g. ``"auto"`` or ``"gee"``.
    device : str
        Target inference device.
    input_prep : InputPrepSpec or str or None
        Optional API-side input preprocessing policy.
    **model_kwargs
        Model-specific settings, forwarded as the request's model config
        (``None`` is forwarded when no extra keywords are given). For
        instance ``variant="large"`` picks the large DOFA variant. See
        :func:`describe_model` for the per-model ``"model_config"`` schema.

    Returns
    -------
    Embedding
        The embedding for ``spatial``.

    Raises
    ------
    ModelError
        If inputs/specs are invalid or the requested model/backend
        configuration is unsupported.
    SpecError
        If spatial or temporal specifications fail validation.

    Examples
    --------
    >>> emb = get_embedding("dofa", spatial=point, temporal=t, variant="large")
    """
    # Validation runs first so that bad specs fail before any model lookup.
    _validate_specs(spatial=spatial, temporal=temporal, output=output)

    canonical_name = _normalize_model_name(model)
    effective_sensor = _resolve_sensor_for_model(
        canonical_name,
        sensor=sensor,
        fetch=fetch,
        modality=modality,
        default_when_missing=False,
    )

    request_ctx = _prepare_embedding_request_context(
        model=model,
        temporal=temporal,
        sensor=effective_sensor,
        # An empty kwargs dict is forwarded as "no model config".
        model_config=model_kwargs if model_kwargs else None,
        output=output,
        backend=backend,
        device=device,
        input_prep=input_prep,
    )

    # The shared runner is batch-oriented; wrap the single request and unwrap
    # the single result.
    results = _run_embedding_request_shared(
        spatials=[spatial],
        temporal=temporal,
        sensor=effective_sensor,
        output=output,
        ctx=request_ctx,
    )
    return results[0]
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
def get_embeddings_batch(
    model: str,
    *,
    spatials: list[SpatialSpec],
    temporal: TemporalSpec | None = None,
    sensor: SensorSpec | None = None,
    fetch: FetchSpec | None = None,
    modality: str | None = None,
    output: OutputSpec = OutputSpec.pooled(),
    backend: str = "auto",
    device: str = "auto",
    input_prep: InputPrepSpec | str | None = "resize",
    **model_kwargs: Any,
) -> list[Embedding]:
    """Compute embeddings for several spatial requests in one call.

    All requests share one resolved sensor and one request context, which
    are built once and handed to the shared embedding routine together
    with the full ``spatials`` list.

    Parameters
    ----------
    model : str
        Model identifier or alias.
    spatials : list[SpatialSpec]
        Spatial requests to embed.
    temporal : TemporalSpec or None
        Optional temporal filter.
    sensor : SensorSpec or None
        Optional sensor override.
    fetch : FetchSpec or None
        Lightweight fetch-policy override applied on top of the model's
        default sensor. Must not be given together with ``sensor``.
    modality : str or None
        Optional modality selector for models with several input branches.
    output : OutputSpec
        Output representation policy (pooled by default).
    backend : str
        Backend/provider selector.
    device : str
        Target inference device.
    input_prep : InputPrepSpec or str or None
        Optional API-side input preprocessing policy.
    **model_kwargs
        Model-specific settings, forwarded as the request's model config
        (``None`` is forwarded when no extra keywords are given). For
        instance ``variant="large"`` picks the large DOFA variant. See
        :func:`describe_model` for the per-model ``"model_config"`` schema.

    Returns
    -------
    list[Embedding]
        Embeddings in the same order as ``spatials``.

    Raises
    ------
    ModelError
        If inputs/specs are invalid or the requested model/backend
        configuration is unsupported.
    SpecError
        If spatial or temporal specifications fail validation.

    Examples
    --------
    >>> embs = get_embeddings_batch("dofa", spatials=points, temporal=t, variant="large")
    """
    # Validate the whole list up front, before any model lookup happens.
    _validate_spatial_list(spatials=spatials, temporal=temporal, output=output)

    canonical_name = _normalize_model_name(model)
    effective_sensor = _resolve_sensor_for_model(
        canonical_name,
        sensor=sensor,
        fetch=fetch,
        modality=modality,
        default_when_missing=False,
    )

    request_ctx = _prepare_embedding_request_context(
        model=model,
        temporal=temporal,
        sensor=effective_sensor,
        # An empty kwargs dict is forwarded as "no model config".
        model_config=model_kwargs if model_kwargs else None,
        output=output,
        backend=backend,
        device=device,
        input_prep=input_prep,
    )

    embeddings = _run_embedding_request_shared(
        spatials=spatials,
        temporal=temporal,
        sensor=effective_sensor,
        output=output,
        ctx=request_ctx,
    )
    return embeddings
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
# -----------------------------------------------------------------------------
|
|
364
|
+
# Public: batch export (core)
|
|
365
|
+
# -----------------------------------------------------------------------------
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
def export_batch(
    *,
    spatials: list[SpatialSpec],
    temporal: TemporalSpec | None,
    models: list[str | ExportModelRequest],
    target: ExportTarget,
    config: ExportConfig = ExportConfig(),
    backend: str = "auto",
    device: str = "auto",
    output: OutputSpec = OutputSpec.pooled(),
    sensor: SensorSpec | None = None,
    fetch: FetchSpec | None = None,
    modality: str | None = None,
    per_model_sensors: dict[str, SensorSpec] | None = None,
    per_model_fetches: dict[str, FetchSpec] | None = None,
    per_model_modalities: dict[str, str] | None = None,
) -> Any:
    """Export inputs and embeddings for many spatials across many models.

    High-level entry point for batch export. After normalising and
    validating its arguments it hands the work to
    :class:`~rs_embed.pipelines.exporter.BatchExporter` and returns the
    result of its ``run()`` call. If the resume helper reports an already
    completed combined export, that manifest is returned instead and no
    exporter is created.

    Parameters
    ----------
    spatials : list[SpatialSpec]
        Spatial requests to export. Must be a non-empty ``list``.
    temporal : TemporalSpec or None
        Optional temporal filter applied to all spatial requests.
    models : list[str | ExportModelRequest]
        Model identifiers or per-model request objects. For model-specific
        settings (e.g. variant selection) pass
        :meth:`ExportModelRequest.configure` results rather than plain
        strings::

            models=[ExportModelRequest.configure("dofa", variant="large")]
    target : ExportTarget
        Output destination: :meth:`ExportTarget.per_item` for per-item
        directory exports or :meth:`ExportTarget.combined` for one file.
    config : ExportConfig
        Runtime configuration (format, workers, resume, etc.). The caller's
        object is not modified; a copy with the normalised format is used.
    backend : str
        Backend/provider selector.
    device : str
        Target inference device.
    output : OutputSpec
        Embedding output representation policy.
    sensor : SensorSpec or None
        Default sensor for all models unless overridden.
    fetch : FetchSpec or None
        Default fetch-policy override for all models unless overridden.
        Must not be given together with ``sensor``.
    modality : str or None
        Optional global modality selector for models that expose public
        modality switching.
    per_model_sensors : dict[str, SensorSpec] or None
        Per-model sensor overrides keyed by model name.
    per_model_fetches : dict[str, FetchSpec] or None
        Per-model fetch-policy overrides keyed by model name. Must not be
        combined with a sensor override for the same model.
    per_model_modalities : dict[str, str] or None
        Per-model modality overrides keyed by model name.

    Returns
    -------
    Any
        The object returned by ``BatchExporter.run()``, or the resume
        manifest when a completed combined export is detected.

    Raises
    ------
    ModelError
        If arguments are invalid or unsupported (for example empty inputs,
        unsupported format, or incompatible model/backend settings).
    SpecError
        If spatial or temporal specifications fail validation.
    """
    # Imported inside the function rather than at module top; kept as the
    # first statement so import failures surface before argument checks.
    from .pipelines.exporter import BatchExporter

    if not (isinstance(spatials, list) and len(spatials) != 0):
        raise ModelError("spatials must be a non-empty list[SpatialSpec].")

    backend_n = _normalize_backend_name(backend)
    device_n = _normalize_device_name(device)

    # Normalise the format and work on a copy of the config that carries it.
    normalized_format, file_ext = _normalize_export_format(config.format)
    export_config = replace(config, format=normalized_format)

    export_target = _normalize_export_target(
        n_spatials=len(spatials),
        ext=file_ext,
        target=target,
    )

    _validate_spatial_list(spatials=spatials, temporal=temporal, output=output)

    # Short-circuit: a non-None manifest means there is nothing left to run.
    completed_manifest = _maybe_return_completed_combined_resume(
        target=export_target,
        config=export_config,
        spatials=spatials,
        temporal=temporal,
        output=output,
        backend=backend_n,
        device=device_n,
    )
    if completed_manifest is not None:
        return completed_manifest

    model_configs, resolved_backend = _resolve_export_model_configs(
        models=models,
        backend_n=backend_n,
        temporal=temporal,
        output=output,
        sensor=sensor,
        fetch=fetch,
        modality=modality,
        per_model_sensors=per_model_sensors,
        per_model_fetches=per_model_fetches,
        per_model_modalities=per_model_modalities,
    )

    # The provider factory is derived from the normalised backend name, not
    # from ``resolved_backend``; both are handed to the exporter.
    provider_factory = provider_factory_for_backend(backend_n)

    return BatchExporter(
        spatials=spatials,
        temporal=temporal,
        models=model_configs,
        target=export_target,
        output=output,
        config=export_config,
        backend=backend_n,
        resolved_backend=resolved_backend,
        device=device_n,
        provider_factory=provider_factory,
        progress_factory=_create_progress,
    ).run()
|