euler-inference 2.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- euler_inference/__init__.py +8 -0
- euler_inference/__main__.py +5 -0
- euler_inference/_outputs.py +409 -0
- euler_inference/api.py +165 -0
- euler_inference/config.py +297 -0
- euler_inference/inference.py +332 -0
- euler_inference/model_card.py +259 -0
- euler_inference/models/__init__.py +59 -0
- euler_inference/models/external_model.py +208 -0
- euler_inference-2.0.1.dist-info/METADATA +13 -0
- euler_inference-2.0.1.dist-info/RECORD +14 -0
- euler_inference-2.0.1.dist-info/WHEEL +5 -0
- euler_inference-2.0.1.dist-info/entry_points.txt +2 -0
- euler_inference-2.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
"""Model inference pipeline using euler-loading."""
|
|
2
|
+
|
|
3
|
+
# Note: We don't import submodules here to avoid RuntimeWarning when running
|
|
4
|
+
# `python -m euler_inference`. Import directly from submodules instead:
|
|
5
|
+
# from euler_inference.config import InferenceConfig
|
|
6
|
+
# from euler_inference.inference import run_inference
|
|
7
|
+
|
|
8
|
+
__all__ = ["config", "inference", "models"]
|
|
@@ -0,0 +1,409 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
import logging
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import IO, Any
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from euler_loading import MultiModalDataset
|
|
10
|
+
|
|
11
|
+
from euler_inference.config import InferenceConfig, OutputConfig
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def get_save_path(output_base: Path, full_id: str, suffix: str, extension: str) -> Path:
    """
    Build the file path for one sample's output.

    ``full_id`` is a "/"-separated identifier mirroring the hierarchy the
    sample came from (e.g., "Scene01/clone/Camera_0/00001"). Leading slashes
    are stripped so the identifier is always joined *below* ``output_base``
    instead of replacing it.

    Returns:
        output_base/Scene01/clone/Camera_0/00001{suffix}.{extension}
    """
    relative_id = full_id.lstrip("/")
    filename = f"{relative_id}{suffix}.{extension}"
    return output_base / filename
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _normalize_image_array(data: np.ndarray) -> np.ndarray:
    """
    Convert an array to ``uint8`` so it can be encoded as an 8-bit image.

    Args:
        data: Pixel data of any numeric dtype.

    Returns:
        * floating-point input (any precision, including ``float16``): values
          are treated as intensities in [0, 1], clipped to that range and
          scaled to 0..255;
        * other non-``uint8`` input: cast directly with ``astype`` (no
          rescaling is applied);
        * ``uint8`` input: returned unchanged (same object, no copy).
    """
    # issubdtype covers every float width; comparing against float32/float64
    # only would send float16 data down the plain-cast branch and collapse
    # [0, 1] intensities to 0/1.
    if np.issubdtype(data.dtype, np.floating):
        return (np.clip(data, 0, 1) * 255).astype(np.uint8)
    if data.dtype != np.uint8:
        return data.astype(np.uint8)
    return data
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _build_image(data: np.ndarray):
    """
    Wrap an array in a PIL image.

    The array is first converted to ``uint8`` via ``_normalize_image_array``.
    Supported layouts are (H, W) -> mode "L", (H, W, 3) -> mode "RGB" and
    (H, W, 4) -> mode "RGBA".

    Raises:
        ValueError: If the array has any other shape.
    """
    from PIL import Image

    pixels = _normalize_image_array(data)
    modes_by_channel_count = {3: "RGB", 4: "RGBA"}

    if pixels.ndim == 2:
        mode = "L"
    elif pixels.ndim == 3 and pixels.shape[2] in modes_by_channel_count:
        mode = modes_by_channel_count[pixels.shape[2]]
    else:
        raise ValueError(f"Unsupported array shape for image: {data.shape}")

    return Image.fromarray(pixels, mode=mode)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _save_image_output(data: np.ndarray, target: Path | IO[bytes], output_type: str) -> None:
    """
    Encode ``data`` as PNG or JPEG and write it to ``target``.

    Args:
        data: Pixel array accepted by ``_build_image``.
        target: Destination path or writable binary file object.
        output_type: One of "png", "jpg" or "jpeg".

    Raises:
        KeyError: If ``output_type`` is not one of the names above.
        ValueError: If ``data`` has a shape ``_build_image`` does not support.
    """
    format_name = {"png": "PNG", "jpg": "JPEG", "jpeg": "JPEG"}[output_type]
    image = _build_image(data)
    if format_name == "JPEG" and image.mode == "RGBA":
        # JPEG has no alpha channel and Pillow refuses to write RGBA as JPEG,
        # so drop the alpha channel instead of failing on 4-channel data.
        image = image.convert("RGB")
    image.save(target, format=format_name)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _save_exr_output(data: np.ndarray, target: Path) -> None:
    """
    Write ``data`` to ``target`` as an OpenEXR file with 32-bit float channels.

    Args:
        data: Either an (H, W) array, stored as a single "Y" channel, or an
            (H, W, 3) array, stored as "R", "G" and "B" channels.
        target: Destination file path.

    Raises:
        ImportError: If the OpenEXR / Imath bindings are not installed.
        ValueError: If ``data`` has any other shape.
    """
    try:
        import Imath
        import OpenEXR
    except ImportError as exc:
        raise ImportError(
            "OpenEXR package required for EXR output. Install with: pip install OpenEXR"
        ) from exc

    # Map channel name -> 2-D plane; both supported layouts share the
    # header/write logic below.
    if data.ndim == 2:
        planes = {"Y": data}
    elif data.ndim == 3 and data.shape[2] == 3:
        planes = {name: data[:, :, index] for index, name in enumerate("RGB")}
    else:
        raise ValueError(f"Unsupported array shape for EXR: {data.shape}")

    height, width = data.shape[:2]
    float_chan = Imath.Channel(Imath.PixelType(Imath.PixelType.FLOAT))
    header = OpenEXR.Header(width, height)
    header["channels"] = {name: float_chan for name in planes}

    exr = OpenEXR.OutputFile(str(target), header)
    try:
        exr.writePixels({
            name: plane.astype(np.float32).tobytes()
            for name, plane in planes.items()
        })
    finally:
        # Always release the file handle, even when writing fails.
        exr.close()
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def save_output(data: np.ndarray, target: Path | IO[bytes], output_type: str) -> None:
    """
    Serialize ``data`` to ``target`` using the requested format.

    Args:
        data: numpy array to save
        target: Path to save to, or a writable file-like object (for zip streaming).
            When a Path is given, its parent directory is created if needed.
        output_type: File format ("npy", "png", "jpg", "jpeg", "exr")

    Raises:
        ValueError: If output_type is not supported, or if "exr" is requested
            for a file-like target (EXR needs a real path).
    """
    writes_to_path = isinstance(target, Path)
    if writes_to_path:
        target.parent.mkdir(parents=True, exist_ok=True)

    if output_type == "npy":
        np.save(target, data)
    elif output_type in ("png", "jpg", "jpeg"):
        _save_image_output(data, target, output_type)
    elif output_type == "exr":
        if not writes_to_path:
            raise ValueError("EXR output is not supported in zip mode")
        _save_exr_output(data, target)
    else:
        raise ValueError(f"Unsupported output type: {output_type}")
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def _get_writer_kwargs(
    output_config: OutputConfig,
    *,
    strict: bool = True,
    dataset_name: str | None = None,
) -> dict[str, Any]:
    """
    Derive DatasetWriter/ZipDatasetWriter kwargs from an OutputConfig.

    Values come from ``output_config.writer`` when present, otherwise:

    * ``name`` falls back to ``dataset_name`` and then to the output key;
    * ``type`` falls back to ``writer["euler_train"]["modality_type"]`` and
      then to the output key;
    * ``euler_train`` falls back to a "target" entry whose modality type is
      the output key when ``strict`` is true and "other" otherwise.

    ``meta`` is forwarded only when set. ``euler_loading`` is forwarded only
    when set, pre-seeded with ``used_as`` / ``modality_type`` from
    ``euler_train`` so explicit ``euler_loading`` entries win.
    """
    spec = output_config.writer or {}
    output_key = output_config.key

    train_modality = spec.get("euler_train", {}).get("modality_type", output_key)
    train_block = spec.get(
        "euler_train",
        {
            "used_as": "target",
            "modality_type": output_key if strict else "other",
        },
    )

    kwargs: dict[str, Any] = {
        "name": spec.get("name", dataset_name or output_key),
        "type": spec.get("type", train_modality),
        "euler_train": train_block,
        "separator": None,
    }

    meta = spec.get("meta")
    if meta is not None:
        kwargs["meta"] = meta

    loading_overrides = spec.get("euler_loading")
    if loading_overrides is not None:
        inherited = ("used_as", "modality_type")
        combined = {
            field: value
            for field, value in train_block.items()
            if field in inherited
        }
        combined.update(loading_overrides)
        kwargs["euler_loading"] = combined

    return kwargs
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def create_writers(
    config: InferenceConfig,
    dataset_name: str | None = None,
    outputs: list[OutputConfig] | None = None,
) -> dict[str, Any]:
    """
    Create one DatasetWriter or ZipDatasetWriter per output key.

    Args:
        config: Inference configuration; ``output_base_path``, ``zip``,
            ``strict`` and (by default) ``outputs`` are read from it.
        dataset_name: Fallback writer name passed to ``_get_writer_kwargs``.
        outputs: Subset of outputs to create writers for. ``None`` means all
            of ``config.outputs``; an empty list yields no writers.

    Returns:
        Mapping of output key to its writer instance.
    """
    from ds_crawler import DatasetWriter, ZipDatasetWriter

    writer_cls = ZipDatasetWriter if config.zip else DatasetWriter
    selected_outputs = config.outputs if outputs is None else outputs

    writers: dict[str, Any] = {}
    for output_config in selected_outputs:
        kwargs = _get_writer_kwargs(
            output_config,
            strict=config.strict,
            dataset_name=dataset_name,
        )
        # Use the shared helper so the directory-vs-".zip" rule lives in one
        # place and cannot drift from what the output plans log.
        root = _output_root(
            config.output_base_path,
            output_config.key,
            zip_mode=config.zip,
        )
        writers[output_config.key] = writer_cls(root, **kwargs)
    return writers
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
@dataclass(frozen=True)
class _WriterBackedOutputPlan:
    """Immutable description of one output that is written through a dataset-provided writer."""

    # The configured output this plan fulfils.
    output_config: OutputConfig
    # Name of the dataset modality whose writer is used for this output.
    source_modality: str
    # Destination built by ``_output_root``: a directory, or a ``.zip`` path in zip mode.
    output_root: Path
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def _resolve_source_modality(
    output_config: OutputConfig,
    modality_names: set[str],
) -> tuple[str | None, bool]:
    """
    Pick the dataset modality an output should mirror, if any.

    Returns:
        ``(modality, explicit)`` where ``explicit`` is True only when the
        modality came from ``output_config.source_modality``. Without an
        explicit setting, the output key is used when it names a known
        modality; otherwise the modality is ``None``.
    """
    requested = output_config.source_modality
    if requested:
        return requested, True

    key = output_config.key
    return (key, False) if key in modality_names else (None, False)
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def _output_root(base_path: str, output_key: str, *, zip_mode: bool) -> Path:
    """
    Return where an output's files go under ``base_path``.

    In zip mode this is ``<base_path>/<output_key>.zip``; otherwise it is the
    directory ``<base_path>/<output_key>``.
    """
    leaf = f"{output_key}.zip" if zip_mode else output_key
    return Path(base_path) / leaf
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def _build_output_plans(
    config: InferenceConfig,
    dataset: MultiModalDataset,
    logger: logging.Logger,
) -> tuple[list[_WriterBackedOutputPlan], list[OutputConfig]]:
    """
    Split outputs into euler-loading-backed vs legacy serialization.

    An output is writer-backed when it resolves to a dataset modality (see
    ``_resolve_source_modality``) for which ``dataset.get_writer`` succeeds.
    Everything else goes to the legacy serializer.

    An explicitly requested ``source_modality`` that is unknown, or that has
    no writer, raises ValueError. An implicit match (output key equals a
    modality name) without a writer only logs a warning and falls back.

    Returns:
        ``(writer_backed_plans, legacy_outputs)``, each in config order.
    """
    known_modalities = set(config.dataset.modalities)
    mirrored: list[_WriterBackedOutputPlan] = []
    fallback: list[OutputConfig] = []

    for spec in config.outputs:
        modality, is_explicit = _resolve_source_modality(spec, known_modalities)

        if modality is None:
            fallback.append(spec)
            continue

        if modality not in known_modalities:
            if not is_explicit:
                fallback.append(spec)
                continue
            raise ValueError(
                f"Output '{spec.key}' references source_modality="
                f"{modality!r}, but that modality is not present in "
                "dataset.modalities."
            )

        # Probe for a writer up front so failures surface before inference.
        try:
            dataset.get_writer(modality)
        except (KeyError, ValueError) as exc:
            if is_explicit:
                raise ValueError(
                    f"Output '{spec.key}' requested source_modality="
                    f"{modality!r}, but no euler-loading writer is "
                    f"available: {exc}"
                ) from exc
            logger.warning(
                "Output '%s' matches modality '%s', but no euler-loading "
                "writer is available (%s). Falling back to the legacy "
                "serializer.",
                spec.key,
                modality,
                exc,
            )
            fallback.append(spec)
            continue

        if spec.writer:
            logger.warning(
                "Output '%s' is source-backed via modality '%s'; "
                "outputs[].writer metadata is ignored for this mode.",
                spec.key,
                modality,
            )

        destination = _output_root(
            config.output_base_path,
            spec.key,
            zip_mode=config.zip,
        )
        logger.info(
            "Output '%s': mirroring source modality '%s' into %s. Source "
            "filenames and extensions are preserved.",
            spec.key,
            modality,
            destination,
        )
        mirrored.append(
            _WriterBackedOutputPlan(
                output_config=spec,
                source_modality=modality,
                output_root=destination,
            )
        )

    # Legacy destinations are reported after all writer-backed decisions.
    for spec in fallback:
        logger.info(
            "Output '%s': using the legacy serializer into %s.",
            spec.key,
            _output_root(
                config.output_base_path,
                spec.key,
                zip_mode=config.zip,
            ),
        )

    return mirrored, fallback
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def _create_source_backed_writers(
    plans: list[_WriterBackedOutputPlan],
    *,
    dataset: MultiModalDataset,
) -> dict[str, Any]:
    """
    Ask the dataset for one output writer per plan.

    Zip mode is inferred per plan from the ``.zip`` suffix of its output root
    (case-insensitive).

    Returns:
        Mapping of output key to the writer returned by
        ``dataset.create_output_writer``.
    """
    return {
        plan.output_config.key: dataset.create_output_writer(
            plan.source_modality,
            plan.output_root,
            zip=plan.output_root.suffix.lower() == ".zip",
        )
        for plan in plans
    }
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def _full_id_leaf(full_id: str) -> str:
    """
    Return the last non-empty "/"-separated segment of ``full_id``.

    If the identifier has no non-empty segment (empty string or only
    slashes), it is returned unchanged.
    """
    trimmed = full_id.strip("/")
    if not trimmed:
        return full_id
    # After trimming, the text following the final "/" is never empty.
    return trimmed.rsplit("/", 1)[-1]
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
@dataclass
class PreparedOutputs:
    """
    Bundle of everything needed to persist model predictions.

    Holds the dataset, the configured output keys, and two groups of writers:
    source-backed writers created by the dataset itself, and legacy writers
    that receive files serialized by ``save_output``.
    """

    dataset: MultiModalDataset
    expected_output_keys: tuple[str, ...]
    writer_backed_plans: list[_WriterBackedOutputPlan]
    legacy_outputs: list[OutputConfig]
    source_backed_writers: dict[str, Any]
    legacy_writers: dict[str, Any]
    zip_mode: bool

    @classmethod
    def prepare(
        cls,
        *,
        config: InferenceConfig,
        dataset: MultiModalDataset,
        logger: logging.Logger,
    ) -> "PreparedOutputs":
        """Plan every configured output and create the writers it needs."""
        plans, legacy = _build_output_plans(config, dataset, logger)

        # Each writer group is only built when it has work to do.
        mirrored_writers: dict[str, Any] = {}
        if plans:
            mirrored_writers = _create_source_backed_writers(
                plans,
                dataset=dataset,
            )

        serializer_writers: dict[str, Any] = {}
        if legacy:
            serializer_writers = create_writers(
                config,
                dataset_name=dataset.get_dataset_name(),
                outputs=legacy,
            )

        return cls(
            dataset=dataset,
            expected_output_keys=tuple(entry.key for entry in config.outputs),
            writer_backed_plans=plans,
            legacy_outputs=legacy,
            source_backed_writers=mirrored_writers,
            legacy_writers=serializer_writers,
            zip_mode=config.zip,
        )

    def validate_prediction_keys(self, outputs: dict[str, Any]) -> None:
        """Raise ValueError if ``outputs`` lacks any configured output key."""
        absent = [key for key in self.expected_output_keys if key not in outputs]
        if not absent:
            return
        raise ValueError(
            "Model did not return the configured output keys "
            f"{absent}. Got: {list(outputs.keys())}"
        )

    def write_predictions(
        self,
        *,
        sample_index: int,
        full_id: str,
        outputs: dict[str, Any],
    ) -> None:
        """
        Persist one sample's predictions through the prepared writers.

        Writer-backed outputs are handed to ``dataset.write_sample``. Legacy
        outputs are serialized with ``save_output`` under a basename built
        from the last segment of ``full_id``, the output's effective suffix
        and its type.
        """
        for plan in self.writer_backed_plans:
            key = plan.output_config.key
            modality = plan.source_modality
            predictions = {modality: outputs[key]}
            writers = {modality: self.source_backed_writers[key]}
            self.dataset.write_sample(sample_index, predictions, writers)

        if not self.legacy_writers:
            return

        leaf = _full_id_leaf(full_id)
        for entry in self.legacy_outputs:
            payload = outputs[entry.key]
            basename = f"{leaf}{entry.effective_suffix}.{entry.type}"
            writer = self.legacy_writers[entry.key]
            if not self.zip_mode:
                save_output(payload, writer.get_path(full_id, basename), entry.type)
                continue
            with writer.open(full_id, basename) as handle:
                save_output(payload, handle, entry.type)

    def finalize(self) -> None:
        """Call ``save_index`` on every writer, legacy writers first."""
        for group in (self.legacy_writers, self.source_backed_writers):
            for writer in group.values():
                writer.save_index()
|
euler_inference/api.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Python API for programmatic inference.
|
|
3
|
+
|
|
4
|
+
Provides a simple `infer()` function so pipelines can call inference
|
|
5
|
+
from Python without constructing JSON configs manually.
|
|
6
|
+
|
|
7
|
+
Usage (model card):
|
|
8
|
+
from euler_inference.api import infer
|
|
9
|
+
|
|
10
|
+
# With a card file and placeholder bindings
|
|
11
|
+
infer(
|
|
12
|
+
model_card="model_card.json",
|
|
13
|
+
bindings={"weights": "/path/to/checkpoint.pt"},
|
|
14
|
+
data={"rgb": "/data/rgb"},
|
|
15
|
+
output_base_path="/output",
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
# With an already-resolved card dict (from server)
|
|
19
|
+
infer(
|
|
20
|
+
model_card={"model": "/abs/path/model.py", ...},
|
|
21
|
+
data={"rgb": "/data/rgb"},
|
|
22
|
+
output_base_path="/output",
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
Usage (direct):
|
|
26
|
+
from euler_inference.api import infer
|
|
27
|
+
|
|
28
|
+
infer(
|
|
29
|
+
model_path="/path/to/model.py",
|
|
30
|
+
output_base_path="/output",
|
|
31
|
+
dataset_modalities={"rgb": "/data/rgb"},
|
|
32
|
+
)
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
from pathlib import Path
|
|
36
|
+
from typing import Optional, Union
|
|
37
|
+
|
|
38
|
+
from euler_inference.config import (
|
|
39
|
+
DatasetConfig,
|
|
40
|
+
ExternalModelConfig,
|
|
41
|
+
InferenceConfig,
|
|
42
|
+
OutputConfig,
|
|
43
|
+
)
|
|
44
|
+
from euler_inference.inference import run_inference
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def infer(
    # Direct args (optional when model_card is used)
    model_path: Optional[str] = None,
    output_base_path: str = "",
    dataset_modalities: Optional[dict] = None,
    *,
    # Model card args
    model_card: Optional[Union[str, Path, dict]] = None,
    bindings: Optional[dict[str, str]] = None,
    data: Optional[dict[str, str]] = None,
    hierarchical_data: Optional[dict[str, str]] = None,
    # Shared kwargs
    model_config: Optional[dict] = None,
    outputs: Optional[list[dict]] = None,
    dataset_hierarchical_modalities: Optional[dict] = None,
    device: Optional[str] = None,
    max_samples: Optional[int] = None,
    zip: bool = False,
    strict: bool = True,
    verbose: bool = False,
) -> None:
    """
    Run inference with an external model.

    Supports two modes:

    1. **Model card mode** (preferred): Pass `model_card` as a path to a
       model_card.json file or an already-resolved dict. Placeholder bindings
       are provided via `bindings`, `data`, and `hierarchical_data`.

    2. **Direct mode**: Pass `model_path`, `output_base_path`, and
       `dataset_modalities` directly.

    Args:
        model_path: (Direct) Absolute path to the model.py file
        output_base_path: Directory to save predictions
        dataset_modalities: (Direct) Mapping of modality names to paths
        model_card: Path to model_card.json (str or pathlib.Path) or
            already-resolved card dict. When provided, direct args are
            ignored. Placeholders are resolved only for a path; a dict is
            used as given.
        bindings: Placeholder bindings for model card resolution
            (e.g., {"weights": "/path/to/checkpoint.pt"})
        data: Input modality path bindings
            (e.g., {"rgb": "/data/rgb"})
        hierarchical_data: Hierarchical input path bindings
            (e.g., {"textgt": "/data/textgt"})
        model_config: (Direct) Model-specific config dict
        outputs: (Direct) List of output dicts. ``None`` or an empty list
            means the configuration's default outputs are used.
        dataset_hierarchical_modalities: (Direct) Hierarchical modality paths
        device: Device string ("cuda", "cpu", "mps"). Auto-detected if None.
        max_samples: Max samples to process. None for all.
        zip: Forwarded to the inference configuration as its ``zip`` flag
            (zip output mode).
        strict: Forwarded to the inference configuration as its ``strict``
            flag.
        verbose: Enable verbose logging.

    Raises:
        TypeError: If `model_card` is not a str, Path, or dict.
        ValueError: In direct mode, if `model_path` or `dataset_modalities`
            is missing.
    """
    if model_card is not None:
        from euler_inference.model_card import (
            load_model_card,
            model_card_to_config,
            resolve_placeholders,
        )

        if isinstance(model_card, (str, Path)):
            # Path to card file — load and resolve placeholders. Later
            # sources override earlier ones on key clashes.
            card_dict, card_dir = load_model_card(str(model_card))
            all_bindings: dict[str, str] = {
                **(bindings or {}),
                **(data or {}),
                **(hierarchical_data or {}),
            }
            card_dict = resolve_placeholders(card_dict, all_bindings)
        elif isinstance(model_card, dict):
            # Already-resolved dict (from server)
            card_dict = model_card
            card_dir = Path(".")
        else:
            raise TypeError(
                f"model_card must be str, Path or dict, got {type(model_card).__name__}"
            )

        config = model_card_to_config(
            card_dict,
            card_dir,
            output_base_path=output_base_path,
            device=device,
            max_samples=max_samples,
            zip=zip,
            strict=strict,
        )
        run_inference(config, verbose=verbose)
        return

    # Direct code path
    if model_path is None:
        raise ValueError("Either model_card or model_path must be provided")
    if dataset_modalities is None:
        raise ValueError("dataset_modalities is required when not using model_card")

    parsed_outputs = (
        [OutputConfig.from_dict(o) for o in outputs] if outputs is not None else None
    )
    # An empty list is not forwarded below, so the config falls back to its
    # default outputs; the flag must reflect that instead of claiming the
    # caller supplied them.
    outputs_from_default = not parsed_outputs

    config = InferenceConfig(
        external_model=ExternalModelConfig(
            model_path=model_path,
            model_config=model_config,
        ),
        dataset=DatasetConfig(
            modalities=dataset_modalities,
            hierarchical_modalities=dataset_hierarchical_modalities,
        ),
        output_base_path=output_base_path,
        device=device,
        max_samples=max_samples,
        zip=zip,
        strict=strict,
        _outputs_from_default=outputs_from_default,
        **({"outputs": parsed_outputs} if parsed_outputs else {}),
    )

    run_inference(config, verbose=verbose)
|