runnable 0.34.0a1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of runnable might be problematic. Click here for more details.
- extensions/catalog/any_path.py +13 -2
- extensions/job_executor/__init__.py +7 -5
- extensions/job_executor/emulate.py +106 -0
- extensions/job_executor/k8s.py +8 -8
- extensions/job_executor/local_container.py +13 -14
- extensions/nodes/__init__.py +0 -0
- extensions/nodes/conditional.py +243 -0
- extensions/nodes/fail.py +72 -0
- extensions/nodes/map.py +350 -0
- extensions/nodes/parallel.py +159 -0
- extensions/nodes/stub.py +89 -0
- extensions/nodes/success.py +72 -0
- extensions/nodes/task.py +92 -0
- extensions/pipeline_executor/__init__.py +27 -27
- extensions/pipeline_executor/argo.py +52 -46
- extensions/pipeline_executor/emulate.py +112 -0
- extensions/pipeline_executor/local.py +4 -4
- extensions/pipeline_executor/local_container.py +19 -79
- extensions/pipeline_executor/mocked.py +5 -9
- extensions/pipeline_executor/retry.py +6 -10
- runnable/__init__.py +2 -11
- runnable/catalog.py +6 -23
- runnable/cli.py +145 -48
- runnable/context.py +520 -28
- runnable/datastore.py +51 -54
- runnable/defaults.py +12 -34
- runnable/entrypoints.py +82 -440
- runnable/exceptions.py +35 -34
- runnable/executor.py +13 -20
- runnable/gantt.py +1141 -0
- runnable/graph.py +1 -1
- runnable/names.py +1 -1
- runnable/nodes.py +20 -16
- runnable/parameters.py +108 -51
- runnable/sdk.py +125 -204
- runnable/tasks.py +62 -85
- runnable/utils.py +6 -268
- runnable-1.0.0.dist-info/METADATA +122 -0
- runnable-1.0.0.dist-info/RECORD +73 -0
- {runnable-0.34.0a1.dist-info → runnable-1.0.0.dist-info}/entry_points.txt +9 -8
- extensions/nodes/nodes.py +0 -778
- extensions/nodes/torch.py +0 -273
- extensions/nodes/torch_config.py +0 -76
- extensions/tasks/torch.py +0 -286
- extensions/tasks/torch_config.py +0 -76
- runnable-0.34.0a1.dist-info/METADATA +0 -267
- runnable-0.34.0a1.dist-info/RECORD +0 -67
- {runnable-0.34.0a1.dist-info → runnable-1.0.0.dist-info}/WHEEL +0 -0
- {runnable-0.34.0a1.dist-info → runnable-1.0.0.dist-info}/licenses/LICENSE +0 -0
runnable/sdk.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import inspect
|
|
3
4
|
import logging
|
|
4
|
-
import os
|
|
5
5
|
import re
|
|
6
6
|
from abc import ABC, abstractmethod
|
|
7
7
|
from pathlib import Path
|
|
@@ -16,26 +16,17 @@ from pydantic import (
|
|
|
16
16
|
field_validator,
|
|
17
17
|
model_validator,
|
|
18
18
|
)
|
|
19
|
-
from rich.progress import (
|
|
20
|
-
BarColumn,
|
|
21
|
-
Progress,
|
|
22
|
-
SpinnerColumn,
|
|
23
|
-
TextColumn,
|
|
24
|
-
TimeElapsedColumn,
|
|
25
|
-
)
|
|
26
|
-
from rich.table import Column
|
|
27
19
|
from typing_extensions import Self
|
|
28
20
|
|
|
29
|
-
from extensions.nodes.
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
from runnable import
|
|
38
|
-
from runnable.executor import BaseJobExecutor, BasePipelineExecutor
|
|
21
|
+
from extensions.nodes.conditional import ConditionalNode
|
|
22
|
+
from extensions.nodes.fail import FailNode
|
|
23
|
+
from extensions.nodes.map import MapNode
|
|
24
|
+
from extensions.nodes.parallel import ParallelNode
|
|
25
|
+
from extensions.nodes.stub import StubNode
|
|
26
|
+
from extensions.nodes.success import SuccessNode
|
|
27
|
+
from extensions.nodes.task import TaskNode
|
|
28
|
+
from runnable import defaults, graph
|
|
29
|
+
from runnable.executor import BaseJobExecutor
|
|
39
30
|
from runnable.nodes import TraversalNode
|
|
40
31
|
from runnable.tasks import BaseTaskType as RunnableTask
|
|
41
32
|
from runnable.tasks import TaskReturns
|
|
@@ -49,7 +40,7 @@ StepType = Union[
|
|
|
49
40
|
"ShellTask",
|
|
50
41
|
"Parallel",
|
|
51
42
|
"Map",
|
|
52
|
-
"
|
|
43
|
+
"Conditional",
|
|
53
44
|
]
|
|
54
45
|
|
|
55
46
|
|
|
@@ -69,6 +60,7 @@ class Catalog(BaseModel):
|
|
|
69
60
|
Attributes:
|
|
70
61
|
get (List[str]): List of glob patterns to get from central catalog to the compute data folder.
|
|
71
62
|
put (List[str]): List of glob patterns to put into central catalog from the compute data folder.
|
|
63
|
+
store_copy (bool): Whether to store a copy of the data in the central catalog.
|
|
72
64
|
|
|
73
65
|
Examples:
|
|
74
66
|
>>> from runnable import Catalog
|
|
@@ -83,6 +75,7 @@ class Catalog(BaseModel):
|
|
|
83
75
|
# compute_data_folder: str = Field(default="", alias="compute_data_folder")
|
|
84
76
|
get: List[str] = Field(default_factory=list, alias="get")
|
|
85
77
|
put: List[str] = Field(default_factory=list, alias="put")
|
|
78
|
+
store_copy: bool = Field(default=True, alias="store_copy")
|
|
86
79
|
|
|
87
80
|
|
|
88
81
|
class BaseTraversal(ABC, BaseModel):
|
|
@@ -193,6 +186,9 @@ class BaseTask(BaseTraversal):
|
|
|
193
186
|
"This method should be implemented in the child class"
|
|
194
187
|
)
|
|
195
188
|
|
|
189
|
+
def as_pipeline(self) -> "Pipeline":
|
|
190
|
+
return Pipeline(steps=[self], name=self.internal_name) # type: ignore
|
|
191
|
+
|
|
196
192
|
|
|
197
193
|
class PythonTask(BaseTask):
|
|
198
194
|
"""
|
|
@@ -282,26 +278,6 @@ class PythonTask(BaseTask):
|
|
|
282
278
|
return node.executable
|
|
283
279
|
|
|
284
280
|
|
|
285
|
-
class TorchTask(BaseTask):
|
|
286
|
-
entrypoint: str = Field(
|
|
287
|
-
alias="entrypoint", default="torch.distributed.run", frozen=True
|
|
288
|
-
)
|
|
289
|
-
args_to_torchrun: Dict[str, Any] = Field(
|
|
290
|
-
default_factory=dict, alias="args_to_torchrun"
|
|
291
|
-
)
|
|
292
|
-
|
|
293
|
-
script_to_call: str
|
|
294
|
-
|
|
295
|
-
@computed_field
|
|
296
|
-
def command_type(self) -> str:
|
|
297
|
-
return "torch"
|
|
298
|
-
|
|
299
|
-
def create_job(self) -> RunnableTask:
|
|
300
|
-
self.terminate_with_success = True
|
|
301
|
-
node = self.create_node()
|
|
302
|
-
return node.executable
|
|
303
|
-
|
|
304
|
-
|
|
305
281
|
class NotebookTask(BaseTask):
|
|
306
282
|
"""
|
|
307
283
|
An execution node of the pipeline of notebook.
|
|
@@ -481,6 +457,9 @@ class Stub(BaseTraversal):
|
|
|
481
457
|
|
|
482
458
|
return StubNode.parse_from_config(self.model_dump(exclude_none=True))
|
|
483
459
|
|
|
460
|
+
def as_pipeline(self) -> "Pipeline":
|
|
461
|
+
return Pipeline(steps=[self])
|
|
462
|
+
|
|
484
463
|
|
|
485
464
|
class Parallel(BaseTraversal):
|
|
486
465
|
"""
|
|
@@ -520,6 +499,53 @@ class Parallel(BaseTraversal):
|
|
|
520
499
|
return node
|
|
521
500
|
|
|
522
501
|
|
|
502
|
+
class Conditional(BaseTraversal):
|
|
503
|
+
branches: Dict[str, "Pipeline"]
|
|
504
|
+
parameter: str # the name of the parameter should be isalnum
|
|
505
|
+
|
|
506
|
+
@field_validator("parameter")
|
|
507
|
+
@classmethod
|
|
508
|
+
def validate_parameter(cls, parameter: str) -> str:
|
|
509
|
+
if not parameter.isalnum():
|
|
510
|
+
raise AssertionError(
|
|
511
|
+
"The parameter name should be alphanumeric and not empty"
|
|
512
|
+
)
|
|
513
|
+
return parameter
|
|
514
|
+
|
|
515
|
+
@field_validator("branches")
|
|
516
|
+
@classmethod
|
|
517
|
+
def validate_branches(
|
|
518
|
+
cls, branches: Dict[str, "Pipeline"]
|
|
519
|
+
) -> Dict[str, "Pipeline"]:
|
|
520
|
+
for branch_name in branches.keys():
|
|
521
|
+
if not branch_name.isalnum():
|
|
522
|
+
raise ValueError(f"Branch '{branch_name}' must be alphanumeric.")
|
|
523
|
+
return branches
|
|
524
|
+
|
|
525
|
+
@computed_field # type: ignore
|
|
526
|
+
@property
|
|
527
|
+
def graph_branches(self) -> Dict[str, graph.Graph]:
|
|
528
|
+
return {
|
|
529
|
+
name: pipeline._dag.model_copy() for name, pipeline in self.branches.items()
|
|
530
|
+
}
|
|
531
|
+
|
|
532
|
+
def create_node(self) -> ConditionalNode:
|
|
533
|
+
if not self.next_node:
|
|
534
|
+
if not (self.terminate_with_failure or self.terminate_with_success):
|
|
535
|
+
raise AssertionError(
|
|
536
|
+
"A node not being terminated must have a user defined next node"
|
|
537
|
+
)
|
|
538
|
+
|
|
539
|
+
node = ConditionalNode(
|
|
540
|
+
name=self.name,
|
|
541
|
+
branches=self.graph_branches,
|
|
542
|
+
internal_name="",
|
|
543
|
+
next_node=self.next_node,
|
|
544
|
+
parameter=self.parameter,
|
|
545
|
+
)
|
|
546
|
+
return node
|
|
547
|
+
|
|
548
|
+
|
|
523
549
|
class Map(BaseTraversal):
|
|
524
550
|
"""
|
|
525
551
|
A node that iterates over a list of items and executes a pipeline for each item.
|
|
@@ -543,7 +569,6 @@ class Map(BaseTraversal):
|
|
|
543
569
|
iterate_on: str
|
|
544
570
|
iterate_as: str
|
|
545
571
|
reducer: Optional[str] = Field(default=None, alias="reducer")
|
|
546
|
-
overrides: Dict[str, Any] = Field(default_factory=dict)
|
|
547
572
|
|
|
548
573
|
@computed_field # type: ignore
|
|
549
574
|
@property
|
|
@@ -564,7 +589,6 @@ class Map(BaseTraversal):
|
|
|
564
589
|
next_node=self.next_node,
|
|
565
590
|
iterate_on=self.iterate_on,
|
|
566
591
|
iterate_as=self.iterate_as,
|
|
567
|
-
overrides=self.overrides,
|
|
568
592
|
reducer=self.reducer,
|
|
569
593
|
)
|
|
570
594
|
|
|
@@ -735,6 +759,15 @@ class Pipeline(BaseModel):
|
|
|
735
759
|
return False
|
|
736
760
|
return True
|
|
737
761
|
|
|
762
|
+
def get_caller(self) -> str:
|
|
763
|
+
caller_stack = inspect.stack()[2]
|
|
764
|
+
relative_to_root = str(Path(caller_stack.filename).relative_to(Path.cwd()))
|
|
765
|
+
|
|
766
|
+
module_name = re.sub(r"\b.py\b", "", relative_to_root.replace("/", "."))
|
|
767
|
+
module_to_call = f"{module_name}.{caller_stack.function}"
|
|
768
|
+
|
|
769
|
+
return module_to_call
|
|
770
|
+
|
|
738
771
|
def execute(
|
|
739
772
|
self,
|
|
740
773
|
configuration_file: str = "",
|
|
@@ -752,106 +785,31 @@ class Pipeline(BaseModel):
|
|
|
752
785
|
# Immediately return as this call is only for getting the pipeline definition
|
|
753
786
|
return {}
|
|
754
787
|
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
run_id = utils.generate_run_id(run_id=run_id)
|
|
788
|
+
from runnable import context
|
|
758
789
|
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
tag = os.environ.get("RUNNABLE_tag", tag)
|
|
790
|
+
logger.setLevel(log_level)
|
|
762
791
|
|
|
763
|
-
|
|
764
|
-
"RUNNABLE_CONFIGURATION_FILE", configuration_file
|
|
765
|
-
)
|
|
766
|
-
run_context = entrypoints.prepare_configurations(
|
|
792
|
+
service_configurations = context.ServiceConfigurations(
|
|
767
793
|
configuration_file=configuration_file,
|
|
768
|
-
|
|
769
|
-
tag=tag,
|
|
770
|
-
parameters_file=parameters_file,
|
|
794
|
+
execution_context=context.ExecutionContext.PIPELINE,
|
|
771
795
|
)
|
|
772
796
|
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
console.print("Working with context:")
|
|
784
|
-
console.print(run_context)
|
|
785
|
-
console.rule(style="[dark orange]")
|
|
786
|
-
|
|
787
|
-
if not run_context.executor._is_local:
|
|
788
|
-
# We are not working with executor that does not work in local environment
|
|
789
|
-
import inspect
|
|
790
|
-
|
|
791
|
-
caller_stack = inspect.stack()[1]
|
|
792
|
-
relative_to_root = str(Path(caller_stack.filename).relative_to(Path.cwd()))
|
|
793
|
-
|
|
794
|
-
module_name = re.sub(r"\b.py\b", "", relative_to_root.replace("/", "."))
|
|
795
|
-
module_to_call = f"{module_name}.{caller_stack.function}"
|
|
796
|
-
|
|
797
|
-
run_context.pipeline_file = f"{module_to_call}.py"
|
|
798
|
-
run_context.from_sdk = True
|
|
799
|
-
|
|
800
|
-
# Prepare for graph execution
|
|
801
|
-
run_context.executor._set_up_run_log(exists_ok=False)
|
|
802
|
-
|
|
803
|
-
with Progress(
|
|
804
|
-
SpinnerColumn(spinner_name="runner"),
|
|
805
|
-
TextColumn(
|
|
806
|
-
"[progress.description]{task.description}", table_column=Column(ratio=2)
|
|
807
|
-
),
|
|
808
|
-
BarColumn(table_column=Column(ratio=1), style="dark_orange"),
|
|
809
|
-
TimeElapsedColumn(table_column=Column(ratio=1)),
|
|
810
|
-
console=console,
|
|
811
|
-
expand=True,
|
|
812
|
-
) as progress:
|
|
813
|
-
pipeline_execution_task = progress.add_task(
|
|
814
|
-
"[dark_orange] Starting execution .. ", total=1
|
|
815
|
-
)
|
|
816
|
-
try:
|
|
817
|
-
run_context.progress = progress
|
|
797
|
+
configurations = {
|
|
798
|
+
"pipeline_definition_file": self.get_caller(),
|
|
799
|
+
"parameters_file": parameters_file,
|
|
800
|
+
"tag": tag,
|
|
801
|
+
"run_id": run_id,
|
|
802
|
+
"execution_mode": context.ExecutionMode.PYTHON,
|
|
803
|
+
"configuration_file": configuration_file,
|
|
804
|
+
**service_configurations.services,
|
|
805
|
+
}
|
|
818
806
|
|
|
819
|
-
|
|
807
|
+
run_context = context.PipelineContext.model_validate(configurations)
|
|
808
|
+
context.run_context = run_context
|
|
820
809
|
|
|
821
|
-
|
|
822
|
-
# non local executors just traverse the graph and do nothing
|
|
823
|
-
return {}
|
|
810
|
+
assert isinstance(run_context, context.PipelineContext)
|
|
824
811
|
|
|
825
|
-
|
|
826
|
-
run_id=run_context.run_id, full=False
|
|
827
|
-
)
|
|
828
|
-
|
|
829
|
-
if run_log.status == defaults.SUCCESS:
|
|
830
|
-
progress.update(
|
|
831
|
-
pipeline_execution_task,
|
|
832
|
-
description="[green] Success",
|
|
833
|
-
completed=True,
|
|
834
|
-
)
|
|
835
|
-
else:
|
|
836
|
-
progress.update(
|
|
837
|
-
pipeline_execution_task,
|
|
838
|
-
description="[red] Failed",
|
|
839
|
-
completed=True,
|
|
840
|
-
)
|
|
841
|
-
raise exceptions.ExecutionFailedError(run_context.run_id)
|
|
842
|
-
except Exception as e: # noqa: E722
|
|
843
|
-
console.print(e, style=defaults.error_style)
|
|
844
|
-
progress.update(
|
|
845
|
-
pipeline_execution_task,
|
|
846
|
-
description="[red] Errored execution",
|
|
847
|
-
completed=True,
|
|
848
|
-
)
|
|
849
|
-
raise
|
|
850
|
-
|
|
851
|
-
if run_context.executor._is_local:
|
|
852
|
-
return run_context.run_log_store.get_run_log_by_id(
|
|
853
|
-
run_id=run_context.run_id
|
|
854
|
-
)
|
|
812
|
+
run_context.execute()
|
|
855
813
|
|
|
856
814
|
|
|
857
815
|
class BaseJob(BaseModel):
|
|
@@ -875,11 +833,25 @@ class BaseJob(BaseModel):
|
|
|
875
833
|
def get_task(self) -> RunnableTask:
|
|
876
834
|
raise NotImplementedError
|
|
877
835
|
|
|
836
|
+
def get_caller(self) -> str:
|
|
837
|
+
caller_stack = inspect.stack()[2]
|
|
838
|
+
relative_to_root = str(Path(caller_stack.filename).relative_to(Path.cwd()))
|
|
839
|
+
|
|
840
|
+
module_name = re.sub(r"\b.py\b", "", relative_to_root.replace("/", "."))
|
|
841
|
+
module_to_call = f"{module_name}.{caller_stack.function}"
|
|
842
|
+
|
|
843
|
+
return module_to_call
|
|
844
|
+
|
|
878
845
|
def return_catalog_settings(self) -> Optional[List[str]]:
|
|
879
846
|
if self.catalog is None:
|
|
880
847
|
return []
|
|
881
848
|
return self.catalog.put
|
|
882
849
|
|
|
850
|
+
def return_bool_catalog_store_copy(self) -> bool:
|
|
851
|
+
if self.catalog is None:
|
|
852
|
+
return True
|
|
853
|
+
return self.catalog.store_copy
|
|
854
|
+
|
|
883
855
|
def _is_called_for_definition(self) -> bool:
|
|
884
856
|
"""
|
|
885
857
|
If the run context is set, we are coming in only to get the pipeline definition.
|
|
@@ -901,65 +873,33 @@ class BaseJob(BaseModel):
|
|
|
901
873
|
if self._is_called_for_definition():
|
|
902
874
|
# Immediately return as this call is only for getting the job definition
|
|
903
875
|
return {}
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
run_id = utils.generate_run_id(run_id=job_id)
|
|
907
|
-
|
|
908
|
-
parameters_file = os.environ.get("RUNNABLE_PARAMETERS_FILE", parameters_file)
|
|
909
|
-
|
|
910
|
-
tag = os.environ.get("RUNNABLE_tag", tag)
|
|
876
|
+
from runnable import context
|
|
911
877
|
|
|
912
|
-
|
|
913
|
-
"RUNNABLE_CONFIGURATION_FILE", configuration_file
|
|
914
|
-
)
|
|
878
|
+
logger.setLevel(log_level)
|
|
915
879
|
|
|
916
|
-
|
|
880
|
+
service_configurations = context.ServiceConfigurations(
|
|
917
881
|
configuration_file=configuration_file,
|
|
918
|
-
|
|
919
|
-
tag=tag,
|
|
920
|
-
parameters_file=parameters_file,
|
|
921
|
-
is_job=True,
|
|
922
|
-
)
|
|
923
|
-
|
|
924
|
-
assert isinstance(run_context.executor, BaseJobExecutor)
|
|
925
|
-
run_context.from_sdk = True
|
|
926
|
-
|
|
927
|
-
utils.set_runnable_environment_variables(
|
|
928
|
-
run_id=run_id, configuration_file=configuration_file, tag=tag
|
|
882
|
+
execution_context=context.ExecutionContext.JOB,
|
|
929
883
|
)
|
|
930
884
|
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
936
|
-
|
|
937
|
-
|
|
938
|
-
|
|
939
|
-
|
|
940
|
-
|
|
941
|
-
|
|
942
|
-
module_name = re.sub(r"\b.py\b", "", relative_to_root.replace("/", "."))
|
|
943
|
-
module_to_call = f"{module_name}.{caller_stack.function}"
|
|
944
|
-
|
|
945
|
-
run_context.job_definition_file = f"{module_to_call}.py"
|
|
946
|
-
|
|
947
|
-
job = self.get_task()
|
|
948
|
-
catalog_settings = self.return_catalog_settings()
|
|
885
|
+
configurations = {
|
|
886
|
+
"job_definition_file": self.get_caller(),
|
|
887
|
+
"parameters_file": parameters_file,
|
|
888
|
+
"tag": tag,
|
|
889
|
+
"run_id": job_id,
|
|
890
|
+
"execution_mode": context.ExecutionMode.PYTHON,
|
|
891
|
+
"configuration_file": configuration_file,
|
|
892
|
+
"job": self.get_task(),
|
|
893
|
+
"catalog_settings": self.return_catalog_settings(),
|
|
894
|
+
**service_configurations.services,
|
|
895
|
+
}
|
|
949
896
|
|
|
950
|
-
|
|
951
|
-
|
|
952
|
-
finally:
|
|
953
|
-
run_context.executor.add_task_log_to_catalog("job")
|
|
897
|
+
run_context = context.JobContext.model_validate(configurations)
|
|
898
|
+
run_context.catalog_store_copy = self.return_bool_catalog_store_copy()
|
|
954
899
|
|
|
955
|
-
|
|
956
|
-
"Executing the job from the user. We are still in the caller's compute environment"
|
|
957
|
-
)
|
|
900
|
+
assert isinstance(run_context.job_executor, BaseJobExecutor)
|
|
958
901
|
|
|
959
|
-
|
|
960
|
-
return run_context.run_log_store.get_run_log_by_id(
|
|
961
|
-
run_id=run_context.run_id
|
|
962
|
-
)
|
|
902
|
+
run_context.execute()
|
|
963
903
|
|
|
964
904
|
|
|
965
905
|
class PythonJob(BaseJob):
|
|
@@ -983,25 +923,6 @@ class PythonJob(BaseJob):
|
|
|
983
923
|
return task.create_node().executable
|
|
984
924
|
|
|
985
925
|
|
|
986
|
-
class TorchJob(BaseJob):
|
|
987
|
-
entrypoint: str = Field(default="torch.distributed.run", frozen=True)
|
|
988
|
-
args_to_torchrun: dict[str, str | bool | int | float] = Field(
|
|
989
|
-
default_factory=dict
|
|
990
|
-
) # For example
|
|
991
|
-
# {"nproc_per_node": 2, "nnodes": 1,}
|
|
992
|
-
|
|
993
|
-
script_to_call: str # For example train/script.py
|
|
994
|
-
|
|
995
|
-
def get_task(self) -> RunnableTask:
|
|
996
|
-
# Piggy bank on existing tasks as a hack
|
|
997
|
-
task = TorchTask(
|
|
998
|
-
name="dummy",
|
|
999
|
-
terminate_with_success=True,
|
|
1000
|
-
**self.model_dump(exclude_defaults=True, exclude_none=True),
|
|
1001
|
-
)
|
|
1002
|
-
return task.create_node().executable
|
|
1003
|
-
|
|
1004
|
-
|
|
1005
926
|
class NotebookJob(BaseJob):
|
|
1006
927
|
notebook: str = Field(serialization_alias="command")
|
|
1007
928
|
optional_ploomber_args: Optional[Dict[str, Any]] = Field(
|