luminarycloud 0.21.1__py3-none-any.whl → 0.21.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- luminarycloud/_helpers/download.py +3 -1
- luminarycloud/_proto/api/v0/luminarycloud/geometry/geometry_pb2.py +124 -124
- luminarycloud/_proto/api/v0/luminarycloud/geometry/geometry_pb2.pyi +8 -1
- luminarycloud/_proto/api/v0/luminarycloud/mesh/mesh_pb2.py +11 -11
- luminarycloud/_proto/api/v0/luminarycloud/mesh/mesh_pb2.pyi +9 -2
- luminarycloud/_proto/api/v0/luminarycloud/physics_ai/physics_ai_pb2.py +33 -20
- luminarycloud/_proto/api/v0/luminarycloud/physics_ai/physics_ai_pb2.pyi +21 -1
- luminarycloud/_proto/api/v0/luminarycloud/project/project_pb2.py +16 -16
- luminarycloud/_proto/api/v0/luminarycloud/project/project_pb2.pyi +7 -3
- luminarycloud/_proto/api/v0/luminarycloud/simulation/simulation_pb2.py +55 -55
- luminarycloud/_proto/api/v0/luminarycloud/simulation/simulation_pb2.pyi +4 -0
- luminarycloud/_proto/api/v0/luminarycloud/vis/vis_pb2.py +28 -26
- luminarycloud/_proto/api/v0/luminarycloud/vis/vis_pb2.pyi +7 -1
- luminarycloud/_proto/assistant/assistant_pb2.py +47 -34
- luminarycloud/_proto/assistant/assistant_pb2.pyi +21 -1
- luminarycloud/_proto/base/base_pb2.py +17 -7
- luminarycloud/_proto/base/base_pb2.pyi +26 -0
- luminarycloud/_proto/cad/transformation_pb2.py +60 -16
- luminarycloud/_proto/cad/transformation_pb2.pyi +138 -32
- luminarycloud/_proto/hexmesh/hexmesh_pb2.py +20 -18
- luminarycloud/_proto/hexmesh/hexmesh_pb2.pyi +7 -2
- luminarycloud/_proto/quantity/quantity_options_pb2.py +6 -6
- luminarycloud/_proto/quantity/quantity_options_pb2.pyi +10 -1
- luminarycloud/_proto/quantity/quantity_pb2.py +166 -166
- luminarycloud/_proto/quantity/quantity_pb2.pyi +1 -1
- luminarycloud/enum/gpu_type.py +2 -0
- luminarycloud/feature_modification.py +13 -34
- luminarycloud/physics_ai/solution.py +3 -1
- luminarycloud/pipelines/__init__.py +3 -3
- luminarycloud/pipelines/api.py +81 -0
- luminarycloud/pipelines/core.py +103 -96
- luminarycloud/pipelines/{operators.py → stages.py} +28 -28
- luminarycloud/project.py +10 -1
- luminarycloud/types/matrix3.py +12 -0
- luminarycloud/volume_selection.py +18 -60
- {luminarycloud-0.21.1.dist-info → luminarycloud-0.21.2.dist-info}/METADATA +1 -1
- {luminarycloud-0.21.1.dist-info → luminarycloud-0.21.2.dist-info}/RECORD +38 -38
- {luminarycloud-0.21.1.dist-info → luminarycloud-0.21.2.dist-info}/WHEEL +0 -0
luminarycloud/enum/gpu_type.py
CHANGED

@@ -10,3 +10,5 @@ class GPUType(IntEnum):
     UNSPECIFIED = simulationpb.SimulationOptions.GPU_TYPE_UNSPECIFIED
     V100 = simulationpb.SimulationOptions.GPU_TYPE_V100
     A100 = simulationpb.SimulationOptions.GPU_TYPE_A100
+    T4 = simulationpb.SimulationOptions.GPU_TYPE_T4
+    H100 = simulationpb.SimulationOptions.GPU_TYPE_H100
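The hunk above just adds two members, T4 and H100, to the GPUType IntEnum, mirroring new values in the simulation proto. A minimal usage sketch, assuming only that the module is importable at the path shown in the file list (luminarycloud/enum/gpu_type.py):

    # Sketch: GPUType is an IntEnum, so the new members behave exactly like
    # the existing ones (iteration, name lookup, value round-trips).
    from luminarycloud.enum.gpu_type import GPUType

    print([m.name for m in GPUType])  # now includes "T4" and "H100"
    assert GPUType.H100 == GPUType["H100"] == GPUType(GPUType.H100.value)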
luminarycloud/feature_modification.py
CHANGED

@@ -1,12 +1,12 @@
 # Copyright 2025 Luminary Cloud, Inc. All Rights Reserved.
 from enum import Enum, auto
-from typing import Dict, Iterable, List, Optional
+from typing import Dict, Iterable, List, Optional
 from copy import deepcopy

-from luminarycloud.types.adfloat import _to_ad_proto
+from luminarycloud.types.adfloat import _to_ad_proto
 from ._proto.geometry import geometry_pb2 as gpb
 from .types import Vector3Like
-from .types.vector3 import
+from .types.vector3 import _to_vector3_ad_proto
 from .params.geometry import Shape, Sphere, Cube, Cylinder, Torus, Cone, HalfSphere, Volume
 from google.protobuf.internal.containers import RepeatedScalarFieldContainer

@@ -483,10 +483,7 @@ def modify_translate(
     _update_repeated_field(transform_op.body, vol_ids)

     if displacement is not None:
-
-        transform_op.translation.vector.x = vec.x
-        transform_op.translation.vector.y = vec.y
-        transform_op.translation.vector.z = vec.z
+        transform_op.translation.vector.CopyFrom(_to_vector3_ad_proto(displacement))

     if keep is not None:
         transform_op.keep = keep

@@ -532,19 +529,13 @@ def modify_rotate(

     # Update existing rotation
     if angle is not None:
-        transform_op.rotation.angle
+        transform_op.rotation.angle.CopyFrom(_to_ad_proto(angle))

     if axis is not None:
-
-        transform_op.rotation.arbitrary.direction.x = axis_vec.x
-        transform_op.rotation.arbitrary.direction.y = axis_vec.y
-        transform_op.rotation.arbitrary.direction.z = axis_vec.z
+        transform_op.rotation.arbitrary.direction.CopyFrom(_to_vector3_ad_proto(axis))

     if origin is not None:
-
-        transform_op.rotation.arbitrary.origin.x = origin_vec.x
-        transform_op.rotation.arbitrary.origin.y = origin_vec.y
-        transform_op.rotation.arbitrary.origin.z = origin_vec.z
+        transform_op.rotation.arbitrary.origin.CopyFrom(_to_vector3_ad_proto(origin))

     if keep is not None:
         transform_op.keep = keep

@@ -587,13 +578,10 @@ def modify_scale(
     _update_repeated_field(transform_op.body, vol_ids)

     if scale_factor is not None:
-        transform_op.scaling.isotropic
+        transform_op.scaling.isotropic.CopyFrom(_to_ad_proto(scale_factor))

     if origin is not None:
-
-        transform_op.scaling.arbitrary.x = origin_vec.x
-        transform_op.scaling.arbitrary.y = origin_vec.y
-        transform_op.scaling.arbitrary.z = origin_vec.z
+        transform_op.scaling.arbitrary.CopyFrom(_to_vector3_ad_proto(origin))

     if keep is not None:
         transform_op.keep = keep

@@ -769,10 +757,7 @@ def modify_linear_pattern(

     if direction is not None:
         # Update existing linear pattern direction
-
-        pattern_op.direction.linear_spacing.vector.x = dir_vec.x
-        pattern_op.direction.linear_spacing.vector.y = dir_vec.y
-        pattern_op.direction.linear_spacing.vector.z = dir_vec.z
+        pattern_op.direction.linear_spacing.vector.CopyFrom(_to_vector3_ad_proto(direction))

     if quantity is not None:
         pattern_op.direction.quantity = quantity

@@ -832,19 +817,13 @@ def modify_circular_pattern(
     circular = pattern_op.direction.circular_distribution

     if angle is not None:
-        circular.rotation.angle
+        circular.rotation.angle.CopyFrom(_to_ad_proto(angle))

     if axis is not None:
-
-        circular.rotation.arbitrary.direction.x = axis_vec.x
-        circular.rotation.arbitrary.direction.y = axis_vec.y
-        circular.rotation.arbitrary.direction.z = axis_vec.z
+        circular.rotation.arbitrary.direction.CopyFrom(_to_vector3_ad_proto(axis))

     if origin is not None:
-
-        circular.rotation.arbitrary.origin.x = origin_vec.x
-        circular.rotation.arbitrary.origin.y = origin_vec.y
-        circular.rotation.arbitrary.origin.z = origin_vec.z
+        circular.rotation.arbitrary.origin.CopyFrom(_to_vector3_ad_proto(origin))

     if full_rotation is not None:
         circular.full = full_rotation
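Every hunk in this file makes the same swap: instead of assigning x/y/z scalars onto a proto submessage one at a time, the new code builds the AD-typed proto once (via _to_ad_proto / _to_vector3_ad_proto) and copies it in with protobuf's CopyFrom. For reference, a self-contained illustration of the CopyFrom pattern using a protobuf well-known type rather than Luminary's own messages:

    # Proto message fields can't be rebound with "=", so you either mutate
    # scalar subfields in place (the old code) or build a message and copy
    # it in wholesale (the new code). Timestamp stands in for the AD protos.
    from google.protobuf import timestamp_pb2

    src = timestamp_pb2.Timestamp(seconds=1700000000, nanos=500)
    dst = timestamp_pb2.Timestamp()
    dst.CopyFrom(src)  # replaces all of dst's fields at once
    assert (dst.seconds, dst.nanos) == (src.seconds, src.nanos)

The new form also funnels every write through the AD conversion helpers, presumably so first-class AD (differentiable) values survive instead of being flattened to plain floats.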
luminarycloud/physics_ai/solution.py
CHANGED

@@ -1,7 +1,7 @@
 from __future__ import annotations

 import tarfile
-from typing import List, Optional, BinaryIO, cast
+from typing import List, Optional, BinaryIO, cast, Dict

 from .._client import get_default_client
 from .._helpers.download import download_solution_physics_ai as _download_solution_physics_ai

@@ -16,6 +16,7 @@ def _download_processed_solution_physics_ai( # noqa: F841
     volume_fields_to_keep: Optional[List[QuantityType]] = None,
     process_volume: bool = False,
     single_precision: bool = True,
+    internal_options: Optional[Dict[str, str]] = None,
 ) -> tarfile.TarFile:
     """
     Download solution data with physics AI processing applied.

@@ -50,6 +51,7 @@ def _download_processed_solution_physics_ai( # noqa: F841
         volume_fields_to_keep=volume_fields_to_keep,
         process_volume=process_volume,
         single_precision=single_precision,
+        internal_options=internal_options,
     )

     assert stream is not None, "Failed to download solution data"
luminarycloud/pipelines/__init__.py
CHANGED

@@ -2,6 +2,8 @@
 from .core import (
     Pipeline as Pipeline,
     PipelineParameter as PipelineParameter,
+    # Stage base class, mainly exported for testing
+    Stage as Stage,
 )

 from .parameters import (

@@ -11,9 +13,7 @@ from .parameters import (
     BoolPipelineParameter as BoolPipelineParameter,
 )

-from .operators import (
-    # Operator base class, mainly exported for testing
-    Operator as Operator,
+from .stages import (
     # PipelineOutputs, i.e. things that "flow" in a Pipeline
     PipelineOutputGeometry as PipelineOutputGeometry,
     PipelineOutputMesh as PipelineOutputMesh,
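Net effect for downstream imports (the re-exports are exactly as written above):

    # Stage replaces Operator in the public surface of luminarycloud.pipelines;
    # code that did `from luminarycloud.pipelines import Operator` must rename.
    from luminarycloud.pipelines import Stage, PipelineOutputGeometry, PipelineOutputMesh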
luminarycloud/pipelines/api.py
CHANGED

@@ -3,11 +3,15 @@ from typing import Any, Literal
 from dataclasses import dataclass

 from datetime import datetime
+from time import time, sleep
+import logging

 from .arguments import PipelineArgValueType
 from ..pipelines import Pipeline, PipelineArgs
 from .._client import get_default_client

+logger = logging.getLogger(__name__)
+

 @dataclass
 class LogLine:

@@ -152,6 +156,83 @@ class PipelineJobRecord:
         res = get_default_client().http.get(f"/rest/v0/pipeline_jobs/{self.id}/artifacts")
         return res["data"]

+    def wait(
+        self,
+        *,
+        interval_seconds: float = 5,
+        timeout_seconds: float = float("inf"),
+        print_logs: bool = False,
+    ) -> Literal["completed", "failed"]:
+        """
+        Wait for the pipeline job to complete or fail.
+
+        This method polls the pipeline job status at regular intervals until it reaches
+        a terminal state (completed or failed).
+
+        Parameters
+        ----------
+        interval_seconds : float
+            Number of seconds between status polls. Default is 5 seconds.
+        timeout_seconds : float
+            Number of seconds before the operation times out. Default is infinity.
+        print_logs : bool
+            If True, prints new log lines as they become available. Default is False.
+
+        Returns
+        -------
+        Literal["completed", "failed"]
+            The final status of the pipeline job.
+
+        Raises
+        ------
+        TimeoutError
+            If the pipeline job does not complete within the specified timeout.
+
+        Examples
+        --------
+        >>> pipeline_job = pipelines.create_pipeline_job(pipeline.id, args, "My Job")
+        >>> final_status = pipeline_job.wait(timeout_seconds=3600)
+        >>> print(f"Pipeline job finished with status: {final_status}")
+        """
+        deadline = time() + timeout_seconds
+        last_log_count = 0
+
+        while True:
+            # Refresh the pipeline job status
+            updated_job = get_pipeline_job(self.id)
+
+            # Print new logs if requested
+            if print_logs:
+                logs = updated_job.logs()
+                if len(logs) > last_log_count:
+                    for log_line in logs[last_log_count:]:
+                        print(f"[{log_line.timestamp}] {log_line.message}")
+                    last_log_count = len(logs)
+
+            # Check if we've reached a terminal state
+            if updated_job.status == "completed":
+                logger.info(f"Pipeline job {self.id} completed successfully")
+                return "completed"
+            elif updated_job.status == "failed":
+                logger.warning(f"Pipeline job {self.id} failed")
+                return "failed"
+
+            # Check timeout
+            if time() >= deadline:
+                raise TimeoutError(
+                    f"Timed out waiting for pipeline job {self.id} to complete. "
+                    f"Current status: {updated_job.status}"
+                )
+
+            # Wait before next poll
+            sleep(max(0, min(interval_seconds, deadline - time())))
+
+            # Update self with the latest status
+            self.status = updated_job.status
+            self.updated_at = updated_job.updated_at
+            self.started_at = updated_job.started_at
+            self.completed_at = updated_job.completed_at
+

 @dataclass
 class PipelineJobRunRecord:
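The new wait() is a plain poll loop: refetch the job record, optionally print log lines that appeared since the last poll, return on a terminal status, and raise TimeoutError once the deadline passes. A usage sketch mirroring the docstring's own example (create_pipeline_job is taken from that example, not defined in this diff):

    # Usage sketch for PipelineJobRecord.wait(), following the docstring above.
    job = pipelines.create_pipeline_job(pipeline.id, args, "My Job")
    try:
        status = job.wait(interval_seconds=10, timeout_seconds=3600, print_logs=True)
        print(f"Pipeline job finished with status: {status}")
    except TimeoutError:
        print("Job still running after an hour; poll again or raise the timeout.")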
luminarycloud/pipelines/core.py
CHANGED

@@ -88,48 +88,47 @@ class PipelineParameter(ABC):

 class PipelineInput:
     """
-    A named input for an Operator. Explicitly connected to a PipelineOutput.
+    A named input for a Stage. Explicitly connected to a PipelineOutput.
     """

-    def __init__(self, upstream_output: "PipelineOutput", owner: "Operator", name: str):
+    def __init__(self, upstream_output: "PipelineOutput", owner: "Stage", name: str):
         self.upstream_output = upstream_output
         self.owner = owner
         self.name = name

-    def _to_dict(self,
-        if self.upstream_output.owner not in
+    def _to_dict(self, id_for_stage: dict) -> dict:
+        if self.upstream_output.owner not in id_for_stage:
             raise ValueError(
-                f"
+                f"Stage {self.owner} depends on a stage, {self.upstream_output.owner}, that isn't in the Pipeline. Did you forget to add it?"
             )
-
+        upstream_stage_id = id_for_stage[self.upstream_output.owner]
         upstream_output_name = self.upstream_output.name
-        return {self.name: f"{
+        return {self.name: f"{upstream_stage_id}.{upstream_output_name}"}


 class PipelineOutput(ABC):
     """
-    A named output for an Operator. Can be used to spawn any number of
-    connected PipelineInputs.
+    A named output for a Stage. Can be used to spawn any number of connected PipelineInputs.
     """

-    def __init__(self, owner: "Operator", name: str):
+    def __init__(self, owner: "Stage", name: str):
         self.owner = owner
         self.name = name
         self.downstream_inputs: list[PipelineInput] = []

-    def _spawn_input(self, owner: "Operator", name: str) -> PipelineInput:
+    def _spawn_input(self, owner: "Stage", name: str) -> PipelineInput:
         input = PipelineInput(self, owner, name)
         self.downstream_inputs.append(input)
         return input


-class OperatorInputs:
+class StageInputs:
     """
-    A collection of all PipelineInputs for an Operator.
+    A collection of all PipelineInputs for a Stage.
     """

     def __init__(
-        self, owner: "Operator", **input_descriptors: tuple[Type[PipelineOutput], PipelineOutput]
+        self, owner: "Stage", **input_descriptors: tuple[Type[PipelineOutput], PipelineOutput]
     ):
         """
         input_descriptors is a dict of input name -> (required_upstream_output_type, upstream_output)

@@ -144,26 +143,26 @@ class OperatorInputs:
             )
             self.inputs.add(upstream_output._spawn_input(owner, name))

-    def _to_dict(self,
+    def _to_dict(self, id_for_stage: dict) -> dict[str, str]:
         d: dict[str, str] = {}
         for input in self.inputs:
-            d |= input._to_dict(
+            d |= input._to_dict(id_for_stage)
         return d


-T = TypeVar("T", bound="OperatorOutputs")
+T = TypeVar("T", bound="StageOutputs")


-class OperatorOutputs(ABC):
+class StageOutputs(ABC):
     """
-    A collection of all PipelineOutputs for
-
-
-
+    A collection of all PipelineOutputs for a Stage. Must be subclassed, and the subclass must also
+    be a dataclass whose fields are all PipelineOutput subclasses. Then that subclass should be
+    instantiated with `_instantiate_for`. Sounds a little complicated, perhaps, but it's not bad.
+    See the existing subclasses in `./stages.py` for examples.
     """

     @classmethod
-    def _instantiate_for(cls: type[T], owner: "Operator") -> T:
+    def _instantiate_for(cls: type[T], owner: "Stage") -> T:
         # create an instance with all fields instantiated with the given owner, and named by the
         # field name.
         # Also validate here that we are a dataclass, and all our fields are PipelineOutput types.

@@ -188,41 +187,41 @@ class OperatorOutputs(ABC):
         return inputs


-class OperatorRegistry:
+class StageRegistry:
     def __init__(self):
-        self.
+        self.stages = {}

-    def register(self,
-        self.
+    def register(self, stage_class: Type["Stage"]) -> None:
+        self.stages[stage_class.__name__] = stage_class

-    def get(self,
-        if
-            raise ValueError(f"Unknown
-        return self.
+    def get(self, stage_type_name: str) -> Type["Stage"]:
+        if stage_type_name not in self.stages:
+            raise ValueError(f"Unknown stage type: {stage_type_name}")
+        return self.stages[stage_type_name]


-TOutputs = TypeVar("TOutputs", bound=OperatorOutputs)
+TOutputs = TypeVar("TOutputs", bound=StageOutputs)


-class Operator(Generic[TOutputs], ABC):
+class Stage(Generic[TOutputs], ABC):
     def __init__(
         self,
-
+        stage_name: str | None,
         params: dict,
-        inputs: OperatorInputs,
+        inputs: StageInputs,
         outputs: TOutputs,
     ):
-        self.
-        self.
+        self._stage_type_name = self.__class__.__name__
+        self._name = stage_name if stage_name is not None else self._stage_type_name
         self._params = params
         self._inputs = inputs
         self.outputs = outputs
-        ensure_yamlizable(self._params_dict()[0], "
+        ensure_yamlizable(self._params_dict()[0], "Stage parameters")

     def is_source(self) -> bool:
         return len(self._inputs.inputs) == 0

-    def inputs_dict(self) -> dict[str, tuple["
+    def inputs_dict(self) -> dict[str, tuple["Stage", str]]:
         inputs = {}
         for pipeline_input in self._inputs.inputs:
             inputs[pipeline_input.name] = (

@@ -231,16 +230,16 @@ class Operator(Generic[TOutputs], ABC):
         )
         return inputs

-    def
+    def downstream_stages(self) -> list["Stage"]:
         return [input.owner for input in self.outputs.downstream_inputs()]

-    def _to_dict(self,
+    def _to_dict(self, id_for_stage: dict) -> tuple[dict, set[PipelineParameter]]:
         params, pipeline_params_set = self._params_dict()
         d = {
-            "name": self.
-            "operator": self.
+            "name": self._name,
+            "operator": self._stage_type_name, # TODO: change key to "stage_type" when we're ready to bump the yaml schema version
             "params": params,
-            "inputs": self._inputs._to_dict(
+            "inputs": self._inputs._to_dict(id_for_stage),
         }
         return d, pipeline_params_set

@@ -256,64 +255,64 @@ class Operator(Generic[TOutputs], ABC):
         return d, pipeline_params

     def __str__(self) -> str:
-        return f'{self.
+        return f'{self._stage_type_name}(name="{self._name}")'

-    _registry = OperatorRegistry()
+    _registry = StageRegistry()

     def __init_subclass__(cls, **kwargs):
         super().__init_subclass__(**kwargs)
-
+        Stage._registry.register(cls)

     @classmethod
-    def _get_subclass(cls,
-        return cls._registry.get(
+    def _get_subclass(cls, stage_type_name: str) -> Type["Stage"]:
+        return cls._registry.get(stage_type_name)

     @classmethod
     def _parse_params(cls, params: dict) -> dict:
-        # Operators with params that are just primitives or PipelineParams have no parsing to do.
-        # Operators with more complicated params should override this method.
+        # Stages with params that are just primitives or PipelineParams have no parsing to do.
+        # Stages with more complicated params should override this method.
         return params


 class Pipeline:
-    def __init__(self,
-        self.
-        self.
+    def __init__(self, stages: list[Stage]):
+        self.stages = stages
+        self._stage_ids = self._assign_ids_to_stages()

     def to_yaml(self) -> str:
         return yaml.safe_dump(self._to_dict())

     def pipeline_params(self) -> set[PipelineParameter]:
-        return self.
+        return self._stages_dict_and_params()[1]

-    def
-        return self.
+    def _get_stage_id(self, stage: Stage) -> str:
+        return self._stage_ids[stage]

-    def
-
-
+    def _stages_dict_and_params(self) -> tuple[dict, set[PipelineParameter]]:
+        id_for_stage = self._stage_ids
+        stages = {}
         params = set()
-        for
-
-
+        for stage in id_for_stage.keys():
+            stage_dict, referenced_params = stage._to_dict(id_for_stage)
+            stages[id_for_stage[stage]] = stage_dict
             params.update(referenced_params)
-        return
+        return stages, params

     def _to_dict(self) -> dict:
-
+        stages, params = self._stages_dict_and_params()

         d = {
             "lc_pipeline": {
                 "schema_version": 1,
                 "params": self._pipeline_params_dict(params),
-                "tasks":
+                "tasks": stages, # TODO: change key to "stages" when we're ready to bump the yaml schema version
             }
         }
         ensure_yamlizable(d, "Pipeline")
         return d

-    def
-        return {
+    def _assign_ids_to_stages(self) -> dict[Stage, str]:
+        return {stage: f"s{i + 1}-{stage._stage_type_name}" for i, stage in enumerate(self.stages)}

     def _pipeline_params_dict(self, params: set[PipelineParameter]) -> dict:
         d: dict[str, dict] = {}

@@ -334,7 +333,9 @@ class Pipeline:
         d = d["lc_pipeline"]
         if "schema_version" not in d:
             raise ValueError("Invalid pipeline YAML: missing 'schema_version' key")
-        if
+        if (
+            "tasks" not in d
+        ): # TODO: change key to "stages" when we're ready to bump the yaml schema version
             raise ValueError("Invalid pipeline YAML: missing 'tasks' key")

         if d["schema_version"] != 1:

@@ -347,18 +348,20 @@ class Pipeline:
             param_name
         )

-        # ...and use them as replacements for any references in the operators' parameters
-        for
-
-
+        # ...and use them as replacements for any references in the stages' parameters
+        for stage_dict in d[
+            "tasks"
+        ].values(): # TODO: change key to "stages" when we're ready to bump the yaml schema version
+            stage_dict["params"] = _recursive_replace_pipeline_params(
+                stage_dict["params"], parsed_params
             )

-        # then, finish parsing the operators
-
-        for
-
+        # then, finish parsing the stages
+        parsed_stages = {}
+        for stage_id in d["tasks"]:
+            _parse_stage(d, stage_id, parsed_stages)

-        return cls(list(
+        return cls(list(parsed_stages.values()))

@@ -368,7 +371,7 @@ def _recursive_replace_pipeline_params(d: Any, parsed_params: dict) -> Any:
         pp_name = d["$pipeline_param"]
         if pp_name not in parsed_params:
             raise ValueError(
-                f'Pipeline parameter "{pp_name}" referenced in a pipeline operator, but not found in pipeline\'s declared parameters'
+                f'Pipeline parameter "{pp_name}" referenced in a pipeline stage, but not found in pipeline\'s declared parameters'
             )
         return parsed_params[pp_name]
     else:

@@ -382,28 +385,32 @@ def _recursive_replace_pipeline_params(d: Any, parsed_params: dict) -> Any:
         return d


-def
-
-
-
-
-
-
+def _parse_stage(pipeline_dict: dict, stage_id: str, all_stages: dict[str, Stage]) -> Stage:
+    all_stages_dict = pipeline_dict[
+        "tasks"
+    ] # TODO: change key to "stages" when we're ready to bump the yaml schema version
+    if stage_id in all_stages:
+        return all_stages[stage_id]
+    stage_dict = all_stages_dict[stage_id]
+    stage_type_name = stage_dict[
+        "operator"
+    ] # TODO: change key to "stage_type" when we're ready to bump the yaml schema version
+    stage_class = Stage._get_subclass(stage_type_name)

     parsed_inputs = {}
-    for input_name, input_value in
-
-
-        source_output = getattr(
+    for input_name, input_value in stage_dict["inputs"].items():
+        source_stage_id, source_output_name = input_value.split(".")
+        source_stage = _parse_stage(pipeline_dict, source_stage_id, all_stages)
+        source_output = getattr(source_stage.outputs, source_output_name)
         parsed_inputs[input_name] = source_output

-    parsed_params =
+    parsed_params = stage_class._parse_params(stage_dict["params"])

-
-    "
+    stage_params = {
+        "stage_name": stage_dict["name"],
         **parsed_params,
         **parsed_inputs,
     }
-
-
-    return
+    stage = stage_class(**stage_params)
+    all_stages[stage_id] = stage
+    return stage
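Two notes on the architecture this rename preserves. First, Stage subclasses still self-register by class name via __init_subclass__, which is what lets from_yaml map the YAML's "operator" field back to a class. A standalone sketch of that pattern (simplified stand-ins, not the SDK's actual classes):

    # Self-registration at class-definition time, as Stage does above.
    from typing import Dict


    class Registry:
        def __init__(self) -> None:
            self.classes: Dict[str, type] = {}

        def register(self, cls: type) -> None:
            self.classes[cls.__name__] = cls

        def get(self, name: str) -> type:
            if name not in self.classes:
                raise ValueError(f"Unknown stage type: {name}")
            return self.classes[name]


    class Base:
        _registry = Registry()

        def __init_subclass__(cls, **kwargs):
            super().__init_subclass__(**kwargs)
            Base._registry.register(cls)  # runs when a subclass is defined


    class MyStage(Base):  # hypothetical subclass, for the demo only
        pass


    assert Base._registry.get("MyStage") is MyStage

Second, the serialized schema is deliberately frozen: the keys stay "tasks" and "operator" until the schema version bumps (see the TODOs above), so YAML written by 0.21.1 still round-trips. The dict shape implied by Pipeline._to_dict() and Stage._to_dict(), with made-up stage type names, ids, and output names:

    # Illustrative only; real stage types and outputs come from stages.py.
    pipeline_dict = {
        "lc_pipeline": {
            "schema_version": 1,
            "params": {},
            "tasks": {  # ids come from _assign_ids_to_stages: "s{i+1}-{TypeName}"
                "s1-ImportGeometry": {
                    "name": "ImportGeometry",
                    "operator": "ImportGeometry",
                    "params": {},
                    "inputs": {},
                },
                "s2-Mesh": {
                    "name": "Mesh",
                    "operator": "Mesh",
                    "params": {},
                    # input references are "<upstream stage id>.<output name>"
                    "inputs": {"geometry": "s1-ImportGeometry.geometry"},
                },
            },
        }
    }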