luminarycloud 0.21.0__py3-none-any.whl → 0.21.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
Files changed (41)
  1. luminarycloud/_client/http_client.py +10 -8
  2. luminarycloud/_helpers/_upload_mesh.py +1 -0
  3. luminarycloud/_helpers/download.py +3 -1
  4. luminarycloud/_helpers/upload.py +15 -6
  5. luminarycloud/_proto/api/v0/luminarycloud/geometry/geometry_pb2.py +124 -124
  6. luminarycloud/_proto/api/v0/luminarycloud/geometry/geometry_pb2.pyi +8 -1
  7. luminarycloud/_proto/api/v0/luminarycloud/mesh/mesh_pb2.py +11 -11
  8. luminarycloud/_proto/api/v0/luminarycloud/mesh/mesh_pb2.pyi +9 -2
  9. luminarycloud/_proto/api/v0/luminarycloud/physics_ai/physics_ai_pb2.py +33 -20
  10. luminarycloud/_proto/api/v0/luminarycloud/physics_ai/physics_ai_pb2.pyi +21 -1
  11. luminarycloud/_proto/api/v0/luminarycloud/project/project_pb2.py +16 -16
  12. luminarycloud/_proto/api/v0/luminarycloud/project/project_pb2.pyi +7 -3
  13. luminarycloud/_proto/api/v0/luminarycloud/simulation/simulation_pb2.py +55 -55
  14. luminarycloud/_proto/api/v0/luminarycloud/simulation/simulation_pb2.pyi +4 -0
  15. luminarycloud/_proto/api/v0/luminarycloud/vis/vis_pb2.py +28 -26
  16. luminarycloud/_proto/api/v0/luminarycloud/vis/vis_pb2.pyi +7 -1
  17. luminarycloud/_proto/assistant/assistant_pb2.py +47 -34
  18. luminarycloud/_proto/assistant/assistant_pb2.pyi +21 -1
  19. luminarycloud/_proto/base/base_pb2.py +17 -7
  20. luminarycloud/_proto/base/base_pb2.pyi +26 -0
  21. luminarycloud/_proto/cad/transformation_pb2.py +60 -16
  22. luminarycloud/_proto/cad/transformation_pb2.pyi +138 -32
  23. luminarycloud/_proto/hexmesh/hexmesh_pb2.py +20 -18
  24. luminarycloud/_proto/hexmesh/hexmesh_pb2.pyi +7 -2
  25. luminarycloud/_proto/quantity/quantity_options_pb2.py +6 -6
  26. luminarycloud/_proto/quantity/quantity_options_pb2.pyi +10 -1
  27. luminarycloud/_proto/quantity/quantity_pb2.py +166 -166
  28. luminarycloud/_proto/quantity/quantity_pb2.pyi +1 -1
  29. luminarycloud/enum/gpu_type.py +2 -0
  30. luminarycloud/feature_modification.py +13 -34
  31. luminarycloud/physics_ai/solution.py +3 -1
  32. luminarycloud/pipelines/__init__.py +3 -3
  33. luminarycloud/pipelines/api.py +81 -0
  34. luminarycloud/pipelines/core.py +103 -96
  35. luminarycloud/pipelines/{operators.py → stages.py} +28 -28
  36. luminarycloud/project.py +10 -1
  37. luminarycloud/types/matrix3.py +12 -0
  38. luminarycloud/volume_selection.py +18 -60
  39. {luminarycloud-0.21.0.dist-info → luminarycloud-0.21.2.dist-info}/METADATA +1 -1
  40. {luminarycloud-0.21.0.dist-info → luminarycloud-0.21.2.dist-info}/RECORD +41 -41
  41. {luminarycloud-0.21.0.dist-info → luminarycloud-0.21.2.dist-info}/WHEEL +0 -0
@@ -1,7 +1,7 @@
  """
  @generated by mypy-protobuf. Do not edit manually!
  isort:skip_file
- Copyright 2020-2025 Luminary Cloud, Inc. All Rights Reserved.
+ Copyright 2025 Luminary Cloud, Inc. All Rights Reserved.
  Generated by quantities.py. DO NOT EDIT
  """
  import builtins
@@ -10,3 +10,5 @@ class GPUType(IntEnum):
  UNSPECIFIED = simulationpb.SimulationOptions.GPU_TYPE_UNSPECIFIED
  V100 = simulationpb.SimulationOptions.GPU_TYPE_V100
  A100 = simulationpb.SimulationOptions.GPU_TYPE_A100
+ T4 = simulationpb.SimulationOptions.GPU_TYPE_T4
+ H100 = simulationpb.SimulationOptions.GPU_TYPE_H100
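The new enum members are purely additive, so existing code is unaffected. A minimal sketch of referencing them, assuming the enum is imported from the module path shown in the file list above (luminarycloud/enum/gpu_type.py):

    from luminarycloud.enum.gpu_type import GPUType

    # T4 and H100 are new in 0.21.2; UNSPECIFIED, V100, and A100 are unchanged.
    requested = GPUType.H100
    print(requested, int(requested))  # IntEnum members carry the underlying proto enum value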
@@ -1,12 +1,12 @@
  # Copyright 2025 Luminary Cloud, Inc. All Rights Reserved.
  from enum import Enum, auto
- from typing import Dict, Iterable, List, Optional, cast
+ from typing import Dict, Iterable, List, Optional
  from copy import deepcopy

- from luminarycloud.types.adfloat import _to_ad_proto, _from_ad_proto
+ from luminarycloud.types.adfloat import _to_ad_proto
  from ._proto.geometry import geometry_pb2 as gpb
  from .types import Vector3Like
- from .types.vector3 import _to_vector3
+ from .types.vector3 import _to_vector3_ad_proto
  from .params.geometry import Shape, Sphere, Cube, Cylinder, Torus, Cone, HalfSphere, Volume
  from google.protobuf.internal.containers import RepeatedScalarFieldContainer

@@ -483,10 +483,7 @@ def modify_translate(
  _update_repeated_field(transform_op.body, vol_ids)

  if displacement is not None:
- vec = _to_vector3(displacement)
- transform_op.translation.vector.x = vec.x
- transform_op.translation.vector.y = vec.y
- transform_op.translation.vector.z = vec.z
+ transform_op.translation.vector.CopyFrom(_to_vector3_ad_proto(displacement))

  if keep is not None:
  transform_op.keep = keep
@@ -532,19 +529,13 @@ def modify_rotate(

  # Update existing rotation
  if angle is not None:
- transform_op.rotation.angle = angle
+ transform_op.rotation.angle.CopyFrom(_to_ad_proto(angle))

  if axis is not None:
- axis_vec = _to_vector3(axis)
- transform_op.rotation.arbitrary.direction.x = axis_vec.x
- transform_op.rotation.arbitrary.direction.y = axis_vec.y
- transform_op.rotation.arbitrary.direction.z = axis_vec.z
+ transform_op.rotation.arbitrary.direction.CopyFrom(_to_vector3_ad_proto(axis))

  if origin is not None:
- origin_vec = _to_vector3(origin)
- transform_op.rotation.arbitrary.origin.x = origin_vec.x
- transform_op.rotation.arbitrary.origin.y = origin_vec.y
- transform_op.rotation.arbitrary.origin.z = origin_vec.z
+ transform_op.rotation.arbitrary.origin.CopyFrom(_to_vector3_ad_proto(origin))

  if keep is not None:
  transform_op.keep = keep
@@ -587,13 +578,10 @@ def modify_scale(
  _update_repeated_field(transform_op.body, vol_ids)

  if scale_factor is not None:
- transform_op.scaling.isotropic = scale_factor
+ transform_op.scaling.isotropic.CopyFrom(_to_ad_proto(scale_factor))

  if origin is not None:
- origin_vec = _to_vector3(origin)
- transform_op.scaling.arbitrary.x = origin_vec.x
- transform_op.scaling.arbitrary.y = origin_vec.y
- transform_op.scaling.arbitrary.z = origin_vec.z
+ transform_op.scaling.arbitrary.CopyFrom(_to_vector3_ad_proto(origin))

  if keep is not None:
  transform_op.keep = keep
@@ -769,10 +757,7 @@ def modify_linear_pattern(

  if direction is not None:
  # Update existing linear pattern direction
- dir_vec = _to_vector3(direction)
- pattern_op.direction.linear_spacing.vector.x = dir_vec.x
- pattern_op.direction.linear_spacing.vector.y = dir_vec.y
- pattern_op.direction.linear_spacing.vector.z = dir_vec.z
+ pattern_op.direction.linear_spacing.vector.CopyFrom(_to_vector3_ad_proto(direction))

  if quantity is not None:
  pattern_op.direction.quantity = quantity
@@ -832,19 +817,13 @@ def modify_circular_pattern(
  circular = pattern_op.direction.circular_distribution

  if angle is not None:
- circular.rotation.angle = angle
+ circular.rotation.angle.CopyFrom(_to_ad_proto(angle))

  if axis is not None:
- axis_vec = _to_vector3(axis)
- circular.rotation.arbitrary.direction.x = axis_vec.x
- circular.rotation.arbitrary.direction.y = axis_vec.y
- circular.rotation.arbitrary.direction.z = axis_vec.z
+ circular.rotation.arbitrary.direction.CopyFrom(_to_vector3_ad_proto(axis))

  if origin is not None:
- origin_vec = _to_vector3(origin)
- circular.rotation.arbitrary.origin.x = origin_vec.x
- circular.rotation.arbitrary.origin.y = origin_vec.y
- circular.rotation.arbitrary.origin.z = origin_vec.z
+ circular.rotation.arbitrary.origin.CopyFrom(_to_vector3_ad_proto(origin))

  if full_rotation is not None:
  circular.full = full_rotation
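All of the modify_* hunks above follow the same pattern: build the AD-aware proto once with _to_ad_proto or _to_vector3_ad_proto and copy it into the target field, rather than assigning x/y/z components one at a time. For readers less familiar with protobuf, here is a minimal, self-contained sketch of the CopyFrom semantics being relied on, using a standard well-known type rather than the SDK's own messages:

    from google.protobuf import wrappers_pb2

    source = wrappers_pb2.DoubleValue(value=2.0)
    target = wrappers_pb2.DoubleValue()
    target.CopyFrom(source)      # replaces target's contents wholesale with source's
    assert target.value == 2.0
    # The diff applies the same call to sub-message fields, e.g.
    # transform_op.translation.vector.CopyFrom(_to_vector3_ad_proto(displacement))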
@@ -1,7 +1,7 @@
  from __future__ import annotations

  import tarfile
- from typing import List, Optional, BinaryIO, cast
+ from typing import List, Optional, BinaryIO, cast, Dict

  from .._client import get_default_client
  from .._helpers.download import download_solution_physics_ai as _download_solution_physics_ai
@@ -16,6 +16,7 @@ def _download_processed_solution_physics_ai( # noqa: F841
  volume_fields_to_keep: Optional[List[QuantityType]] = None,
  process_volume: bool = False,
  single_precision: bool = True,
+ internal_options: Optional[Dict[str, str]] = None,
  ) -> tarfile.TarFile:
  """
  Download solution data with physics AI processing applied.
@@ -50,6 +51,7 @@ def _download_processed_solution_physics_ai( # noqa: F841
  volume_fields_to_keep=volume_fields_to_keep,
  process_volume=process_volume,
  single_precision=single_precision,
+ internal_options=internal_options,
  )

  assert stream is not None, "Failed to download solution data"
@@ -2,6 +2,8 @@
  from .core import (
  Pipeline as Pipeline,
  PipelineParameter as PipelineParameter,
+ # Stage base class, mainly exported for testing
+ Stage as Stage,
  )

  from .parameters import (
@@ -11,9 +13,7 @@ from .parameters import (
  BoolPipelineParameter as BoolPipelineParameter,
  )

- from .operators import (
- # Operator base class, mainly exported for testing
- Operator as Operator,
+ from .stages import (
  # PipelineOutputs, i.e. things that "flow" in a Pipeline
  PipelineOutputGeometry as PipelineOutputGeometry,
  PipelineOutputMesh as PipelineOutputMesh,
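With operators.py renamed to stages.py, the package's public import surface shifts accordingly. A hedged sketch of what user code imports in 0.21.2, using only the names visible in the hunks above (the full export list is not shown in this diff):

    from luminarycloud.pipelines import (
        Pipeline,
        PipelineParameter,
        Stage,                   # base class, mainly exported for testing
        PipelineOutputGeometry,  # things that "flow" in a Pipeline
        PipelineOutputMesh,
    )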
@@ -3,11 +3,15 @@ from typing import Any, Literal
  from dataclasses import dataclass

  from datetime import datetime
+ from time import time, sleep
+ import logging

  from .arguments import PipelineArgValueType
  from ..pipelines import Pipeline, PipelineArgs
  from .._client import get_default_client

+ logger = logging.getLogger(__name__)
+

  @dataclass
  class LogLine:
@@ -152,6 +156,83 @@ class PipelineJobRecord:
  res = get_default_client().http.get(f"/rest/v0/pipeline_jobs/{self.id}/artifacts")
  return res["data"]

+ def wait(
+ self,
+ *,
+ interval_seconds: float = 5,
+ timeout_seconds: float = float("inf"),
+ print_logs: bool = False,
+ ) -> Literal["completed", "failed"]:
+ """
+ Wait for the pipeline job to complete or fail.
+
+ This method polls the pipeline job status at regular intervals until it reaches
+ a terminal state (completed or failed).
+
+ Parameters
+ ----------
+ interval_seconds : float
+ Number of seconds between status polls. Default is 5 seconds.
+ timeout_seconds : float
+ Number of seconds before the operation times out. Default is infinity.
+ print_logs : bool
+ If True, prints new log lines as they become available. Default is False.
+
+ Returns
+ -------
+ Literal["completed", "failed"]
+ The final status of the pipeline job.
+
+ Raises
+ ------
+ TimeoutError
+ If the pipeline job does not complete within the specified timeout.
+
+ Examples
+ --------
+ >>> pipeline_job = pipelines.create_pipeline_job(pipeline.id, args, "My Job")
+ >>> final_status = pipeline_job.wait(timeout_seconds=3600)
+ >>> print(f"Pipeline job finished with status: {final_status}")
+ """
+ deadline = time() + timeout_seconds
+ last_log_count = 0
+
+ while True:
+ # Refresh the pipeline job status
+ updated_job = get_pipeline_job(self.id)
+
+ # Print new logs if requested
+ if print_logs:
+ logs = updated_job.logs()
+ if len(logs) > last_log_count:
+ for log_line in logs[last_log_count:]:
+ print(f"[{log_line.timestamp}] {log_line.message}")
+ last_log_count = len(logs)
+
+ # Check if we've reached a terminal state
+ if updated_job.status == "completed":
+ logger.info(f"Pipeline job {self.id} completed successfully")
+ return "completed"
+ elif updated_job.status == "failed":
+ logger.warning(f"Pipeline job {self.id} failed")
+ return "failed"
+
+ # Check timeout
+ if time() >= deadline:
+ raise TimeoutError(
+ f"Timed out waiting for pipeline job {self.id} to complete. "
+ f"Current status: {updated_job.status}"
+ )
+
+ # Wait before next poll
+ sleep(max(0, min(interval_seconds, deadline - time())))
+
+ # Update self with the latest status
+ self.status = updated_job.status
+ self.updated_at = updated_job.updated_at
+ self.started_at = updated_job.started_at
+ self.completed_at = updated_job.completed_at
+

  @dataclass
  class PipelineJobRunRecord:
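Beyond the docstring example above, the new polling loop also supports streaming logs and bounding the wait. A hedged usage sketch, assuming a job has been created exactly as in the docstring example (the pipelines module and create_pipeline_job call are taken from that example):

    job = pipelines.create_pipeline_job(pipeline.id, args, "My Job")
    try:
        final_status = job.wait(interval_seconds=10, timeout_seconds=1800, print_logs=True)
        print(f"finished: {final_status}")  # "completed" or "failed"
    except TimeoutError:
        # wait() raises after the deadline; it does not cancel the remote job
        print("still running after 30 minutes")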
@@ -88,48 +88,47 @@ class PipelineParameter(ABC):

  class PipelineInput:
  """
- A named input for an Operator instance (i.e. a Task). Explicitly connected to a PipelineOutput.
+ A named input for a Stage. Explicitly connected to a PipelineOutput.
  """

- def __init__(self, upstream_output: "PipelineOutput", owner: "Operator", name: str):
+ def __init__(self, upstream_output: "PipelineOutput", owner: "Stage", name: str):
  self.upstream_output = upstream_output
  self.owner = owner
  self.name = name

- def _to_dict(self, id_for_task: dict) -> dict:
- if self.upstream_output.owner not in id_for_task:
+ def _to_dict(self, id_for_stage: dict) -> dict:
+ if self.upstream_output.owner not in id_for_stage:
  raise ValueError(
- f"Task {self.owner} depends on a task, {self.upstream_output.owner}, that isn't in the Pipeline. Did you forget to add it?"
+ f"Stage {self.owner} depends on a stage, {self.upstream_output.owner}, that isn't in the Pipeline. Did you forget to add it?"
  )
- upstream_task_id = id_for_task[self.upstream_output.owner]
+ upstream_stage_id = id_for_stage[self.upstream_output.owner]
  upstream_output_name = self.upstream_output.name
- return {self.name: f"{upstream_task_id}.{upstream_output_name}"}
+ return {self.name: f"{upstream_stage_id}.{upstream_output_name}"}


  class PipelineOutput(ABC):
  """
- A named output for an Operator instance (i.e. a Task). Can be used to spawn any number of
- connected PipelineInputs.
+ A named output for a Stage. Can be used to spawn any number of connected PipelineInputs.
  """

- def __init__(self, owner: "Operator", name: str):
+ def __init__(self, owner: "Stage", name: str):
  self.owner = owner
  self.name = name
  self.downstream_inputs: list[PipelineInput] = []

- def _spawn_input(self, owner: "Operator", name: str) -> PipelineInput:
+ def _spawn_input(self, owner: "Stage", name: str) -> PipelineInput:
  input = PipelineInput(self, owner, name)
  self.downstream_inputs.append(input)
  return input


- class OperatorInputs:
+ class StageInputs:
  """
- A collection of all PipelineInputs for an Operator instance (i.e. a Task).
+ A collection of all PipelineInputs for a Stage.
  """

  def __init__(
- self, owner: "Operator", **input_descriptors: tuple[Type[PipelineOutput], PipelineOutput]
+ self, owner: "Stage", **input_descriptors: tuple[Type[PipelineOutput], PipelineOutput]
  ):
  """
  input_descriptors is a dict of input name -> (required_upstream_output_type, upstream_output)
@@ -144,26 +143,26 @@ class OperatorInputs:
  )
  self.inputs.add(upstream_output._spawn_input(owner, name))

- def _to_dict(self, id_for_task: dict) -> dict[str, str]:
+ def _to_dict(self, id_for_stage: dict) -> dict[str, str]:
  d: dict[str, str] = {}
  for input in self.inputs:
- d |= input._to_dict(id_for_task)
+ d |= input._to_dict(id_for_stage)
  return d


- T = TypeVar("T", bound="OperatorOutputs")
+ T = TypeVar("T", bound="StageOutputs")


- class OperatorOutputs(ABC):
+ class StageOutputs(ABC):
  """
- A collection of all PipelineOutputs for an Operator instance (i.e. a Task). Must be subclassed,
- and the subclass must also be a dataclass whose fields are all PipelineOutput subclasses. Then
- that subclass should be instantiated with `_instantiate_for`. Sounds a little complicated,
- perhaps, but it's not bad. See the existing subclasses in `./operators.py` for examples.
+ A collection of all PipelineOutputs for a Stage. Must be subclassed, and the subclass must also
+ be a dataclass whose fields are all PipelineOutput subclasses. Then that subclass should be
+ instantiated with `_instantiate_for`. Sounds a little complicated, perhaps, but it's not bad.
+ See the existing subclasses in `./stages.py` for examples.
  """

  @classmethod
- def _instantiate_for(cls: type[T], owner: "Operator") -> T:
+ def _instantiate_for(cls: type[T], owner: "Stage") -> T:
  # create an instance with all fields instantiated with the given owner, and named by the
  # field name.
  # Also validate here that we are a dataclass, and all our fields are PipelineOutput types.
@@ -188,41 +187,41 @@ class OperatorOutputs(ABC):
  return inputs


- class OperatorRegistry:
+ class StageRegistry:
  def __init__(self):
- self.operators = {}
+ self.stages = {}

- def register(self, operator_class: Type["Operator"]) -> None:
- self.operators[operator_class.__name__] = operator_class
+ def register(self, stage_class: Type["Stage"]) -> None:
+ self.stages[stage_class.__name__] = stage_class

- def get(self, operator_name: str) -> Type["Operator"]:
- if operator_name not in self.operators:
- raise ValueError(f"Unknown operator: {operator_name}")
- return self.operators[operator_name]
+ def get(self, stage_type_name: str) -> Type["Stage"]:
+ if stage_type_name not in self.stages:
+ raise ValueError(f"Unknown stage type: {stage_type_name}")
+ return self.stages[stage_type_name]


- TOutputs = TypeVar("TOutputs", bound=OperatorOutputs)
+ TOutputs = TypeVar("TOutputs", bound=StageOutputs)


- class Operator(Generic[TOutputs], ABC):
+ class Stage(Generic[TOutputs], ABC):
  def __init__(
  self,
- task_name: str | None,
+ stage_name: str | None,
  params: dict,
- inputs: OperatorInputs,
+ inputs: StageInputs,
  outputs: TOutputs,
  ):
- self._operator_name = self.__class__.__name__
- self._task_name = task_name if task_name is not None else self._operator_name
+ self._stage_type_name = self.__class__.__name__
+ self._name = stage_name if stage_name is not None else self._stage_type_name
  self._params = params
  self._inputs = inputs
  self.outputs = outputs
- ensure_yamlizable(self._params_dict()[0], "Operator parameters")
+ ensure_yamlizable(self._params_dict()[0], "Stage parameters")

  def is_source(self) -> bool:
  return len(self._inputs.inputs) == 0

- def inputs_dict(self) -> dict[str, tuple["Operator", str]]:
+ def inputs_dict(self) -> dict[str, tuple["Stage", str]]:
  inputs = {}
  for pipeline_input in self._inputs.inputs:
  inputs[pipeline_input.name] = (
@@ -231,16 +230,16 @@ class Operator(Generic[TOutputs], ABC):
  )
  return inputs

- def downstream_tasks(self) -> list["Operator"]:
+ def downstream_stages(self) -> list["Stage"]:
  return [input.owner for input in self.outputs.downstream_inputs()]

- def _to_dict(self, id_for_task: dict) -> tuple[dict, set[PipelineParameter]]:
+ def _to_dict(self, id_for_stage: dict) -> tuple[dict, set[PipelineParameter]]:
  params, pipeline_params_set = self._params_dict()
  d = {
- "name": self._task_name,
- "operator": self._operator_name,
+ "name": self._name,
+ "operator": self._stage_type_name, # TODO: change key to "stage_type" when we're ready to bump the yaml schema version
  "params": params,
- "inputs": self._inputs._to_dict(id_for_task),
+ "inputs": self._inputs._to_dict(id_for_stage),
  }
  return d, pipeline_params_set

@@ -256,64 +255,64 @@ class Operator(Generic[TOutputs], ABC):
  return d, pipeline_params

  def __str__(self) -> str:
- return f'{self._operator_name}(name="{self._task_name}")'
+ return f'{self._stage_type_name}(name="{self._name}")'

- _registry = OperatorRegistry()
+ _registry = StageRegistry()

  def __init_subclass__(cls, **kwargs):
  super().__init_subclass__(**kwargs)
- Operator._registry.register(cls)
+ Stage._registry.register(cls)

  @classmethod
- def _get_subclass(cls, operator_name: str) -> Type["Operator"]:
- return cls._registry.get(operator_name)
+ def _get_subclass(cls, stage_type_name: str) -> Type["Stage"]:
+ return cls._registry.get(stage_type_name)

  @classmethod
  def _parse_params(cls, params: dict) -> dict:
- # Operators with params that are just primitives or PipelineParams have no parsing to do.
- # Operators with more complicated params should override this method.
+ # Stages with params that are just primitives or PipelineParams have no parsing to do.
+ # Stages with more complicated params should override this method.
  return params


  class Pipeline:
- def __init__(self, tasks: list[Operator]):
- self.tasks = tasks
- self._task_ids = self._assign_ids_to_tasks()
+ def __init__(self, stages: list[Stage]):
+ self.stages = stages
+ self._stage_ids = self._assign_ids_to_stages()

  def to_yaml(self) -> str:
  return yaml.safe_dump(self._to_dict())

  def pipeline_params(self) -> set[PipelineParameter]:
- return self._tasks_dict_and_params()[1]
+ return self._stages_dict_and_params()[1]

- def _get_task_id(self, task: Operator) -> str:
- return self._task_ids[task]
+ def _get_stage_id(self, stage: Stage) -> str:
+ return self._stage_ids[stage]

- def _tasks_dict_and_params(self) -> tuple[dict, set[PipelineParameter]]:
- id_for_task = self._task_ids
- tasks = {}
+ def _stages_dict_and_params(self) -> tuple[dict, set[PipelineParameter]]:
+ id_for_stage = self._stage_ids
+ stages = {}
  params = set()
- for task in id_for_task.keys():
- task_dict, referenced_params = task._to_dict(id_for_task)
- tasks[id_for_task[task]] = task_dict
+ for stage in id_for_stage.keys():
+ stage_dict, referenced_params = stage._to_dict(id_for_stage)
+ stages[id_for_stage[stage]] = stage_dict
  params.update(referenced_params)
- return tasks, params
+ return stages, params

  def _to_dict(self) -> dict:
- tasks, params = self._tasks_dict_and_params()
+ stages, params = self._stages_dict_and_params()

  d = {
  "lc_pipeline": {
  "schema_version": 1,
  "params": self._pipeline_params_dict(params),
- "tasks": tasks,
+ "tasks": stages, # TODO: change key to "stages" when we're ready to bump the yaml schema version
  }
  }
  ensure_yamlizable(d, "Pipeline")
  return d

- def _assign_ids_to_tasks(self) -> dict[Operator, str]:
- return {task: f"t{i + 1}-{task._operator_name}" for i, task in enumerate(self.tasks)}
+ def _assign_ids_to_stages(self) -> dict[Stage, str]:
+ return {stage: f"s{i + 1}-{stage._stage_type_name}" for i, stage in enumerate(self.stages)}

  def _pipeline_params_dict(self, params: set[PipelineParameter]) -> dict:
  d: dict[str, dict] = {}
@@ -334,7 +333,9 @@ class Pipeline:
  d = d["lc_pipeline"]
  if "schema_version" not in d:
  raise ValueError("Invalid pipeline YAML: missing 'schema_version' key")
- if "tasks" not in d:
+ if (
+ "tasks" not in d
+ ): # TODO: change key to "stages" when we're ready to bump the yaml schema version
  raise ValueError("Invalid pipeline YAML: missing 'tasks' key")

  if d["schema_version"] != 1:
@@ -347,18 +348,20 @@ class Pipeline:
  param_name
  )

- # ...and use them as replacements for any references in the tasks' parameters
- for task_dict in d["tasks"].values():
- task_dict["params"] = _recursive_replace_pipeline_params(
- task_dict["params"], parsed_params
+ # ...and use them as replacements for any references in the stages' parameters
+ for stage_dict in d[
+ "tasks"
+ ].values(): # TODO: change key to "stages" when we're ready to bump the yaml schema version
+ stage_dict["params"] = _recursive_replace_pipeline_params(
+ stage_dict["params"], parsed_params
  )

- # then, finish parsing the tasks
- parsed_tasks = {}
- for task_id in d["tasks"]:
- _parse_task(d, task_id, parsed_tasks)
+ # then, finish parsing the stages
+ parsed_stages = {}
+ for stage_id in d["tasks"]:
+ _parse_stage(d, stage_id, parsed_stages)

- return cls(list(parsed_tasks.values()))
+ return cls(list(parsed_stages.values()))


  def _recursive_replace_pipeline_params(d: Any, parsed_params: dict) -> Any:
@@ -368,7 +371,7 @@ def _recursive_replace_pipeline_params(d: Any, parsed_params: dict) -> Any:
  pp_name = d["$pipeline_param"]
  if pp_name not in parsed_params:
  raise ValueError(
- f'Pipeline parameter "{pp_name}" referenced in a pipeline task, but not found in pipeline\'s declared parameters'
+ f'Pipeline parameter "{pp_name}" referenced in a pipeline stage, but not found in pipeline\'s declared parameters'
  )
  return parsed_params[pp_name]
  else:
@@ -382,28 +385,32 @@ def _recursive_replace_pipeline_params(d: Any, parsed_params: dict) -> Any:
  return d


- def _parse_task(pipeline_dict: dict, task_id: str, all_tasks: dict[str, Operator]) -> Operator:
- all_tasks_dict = pipeline_dict["tasks"]
- if task_id in all_tasks:
- return all_tasks[task_id]
- task_dict = all_tasks_dict[task_id]
- operator_name = task_dict["operator"]
- operator_class = Operator._get_subclass(operator_name)
+ def _parse_stage(pipeline_dict: dict, stage_id: str, all_stages: dict[str, Stage]) -> Stage:
+ all_stages_dict = pipeline_dict[
+ "tasks"
+ ] # TODO: change key to "stages" when we're ready to bump the yaml schema version
+ if stage_id in all_stages:
+ return all_stages[stage_id]
+ stage_dict = all_stages_dict[stage_id]
+ stage_type_name = stage_dict[
+ "operator"
+ ] # TODO: change key to "stage_type" when we're ready to bump the yaml schema version
+ stage_class = Stage._get_subclass(stage_type_name)

  parsed_inputs = {}
- for input_name, input_value in task_dict["inputs"].items():
- source_task_id, source_output_name = input_value.split(".")
- source_task = _parse_task(pipeline_dict, source_task_id, all_tasks)
- source_output = getattr(source_task.outputs, source_output_name)
+ for input_name, input_value in stage_dict["inputs"].items():
+ source_stage_id, source_output_name = input_value.split(".")
+ source_stage = _parse_stage(pipeline_dict, source_stage_id, all_stages)
+ source_output = getattr(source_stage.outputs, source_output_name)
  parsed_inputs[input_name] = source_output

- parsed_params = operator_class._parse_params(task_dict["params"])
+ parsed_params = stage_class._parse_params(stage_dict["params"])

- op_params = {
- "task_name": task_dict["name"],
+ stage_params = {
+ "stage_name": stage_dict["name"],
  **parsed_params,
  **parsed_inputs,
  }
- operator = operator_class(**op_params)
- all_tasks[task_id] = operator
- return operator
+ stage = stage_class(**stage_params)
+ all_stages[stage_id] = stage
+ return stage
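Note that the on-disk YAML schema stays at schema_version 1: serialized pipelines still use the "tasks" and "operator" keys (see the TODO comments above), and only the generated stage IDs change from a "t" prefix to an "s" prefix. A hedged sketch of a round trip under those assumptions; some_stage is a placeholder for an instance of a concrete Stage subclass, which this diff does not show:

    import yaml
    from luminarycloud.pipelines import Pipeline

    pipeline = Pipeline([some_stage])
    doc = yaml.safe_load(pipeline.to_yaml())

    assert doc["lc_pipeline"]["schema_version"] == 1          # schema unchanged
    assert list(doc["lc_pipeline"]["tasks"])[0].startswith("s1-")  # was "t1-" before 0.21.2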