runnable-0.34.0a1-py3-none-any.whl → runnable-1.0.0-py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.

Potentially problematic release.


This version of runnable might be problematic.

Files changed (49)
  1. extensions/catalog/any_path.py +13 -2
  2. extensions/job_executor/__init__.py +7 -5
  3. extensions/job_executor/emulate.py +106 -0
  4. extensions/job_executor/k8s.py +8 -8
  5. extensions/job_executor/local_container.py +13 -14
  6. extensions/nodes/__init__.py +0 -0
  7. extensions/nodes/conditional.py +243 -0
  8. extensions/nodes/fail.py +72 -0
  9. extensions/nodes/map.py +350 -0
  10. extensions/nodes/parallel.py +159 -0
  11. extensions/nodes/stub.py +89 -0
  12. extensions/nodes/success.py +72 -0
  13. extensions/nodes/task.py +92 -0
  14. extensions/pipeline_executor/__init__.py +27 -27
  15. extensions/pipeline_executor/argo.py +52 -46
  16. extensions/pipeline_executor/emulate.py +112 -0
  17. extensions/pipeline_executor/local.py +4 -4
  18. extensions/pipeline_executor/local_container.py +19 -79
  19. extensions/pipeline_executor/mocked.py +5 -9
  20. extensions/pipeline_executor/retry.py +6 -10
  21. runnable/__init__.py +2 -11
  22. runnable/catalog.py +6 -23
  23. runnable/cli.py +145 -48
  24. runnable/context.py +520 -28
  25. runnable/datastore.py +51 -54
  26. runnable/defaults.py +12 -34
  27. runnable/entrypoints.py +82 -440
  28. runnable/exceptions.py +35 -34
  29. runnable/executor.py +13 -20
  30. runnable/gantt.py +1141 -0
  31. runnable/graph.py +1 -1
  32. runnable/names.py +1 -1
  33. runnable/nodes.py +20 -16
  34. runnable/parameters.py +108 -51
  35. runnable/sdk.py +125 -204
  36. runnable/tasks.py +62 -85
  37. runnable/utils.py +6 -268
  38. runnable-1.0.0.dist-info/METADATA +122 -0
  39. runnable-1.0.0.dist-info/RECORD +73 -0
  40. {runnable-0.34.0a1.dist-info → runnable-1.0.0.dist-info}/entry_points.txt +9 -8
  41. extensions/nodes/nodes.py +0 -778
  42. extensions/nodes/torch.py +0 -273
  43. extensions/nodes/torch_config.py +0 -76
  44. extensions/tasks/torch.py +0 -286
  45. extensions/tasks/torch_config.py +0 -76
  46. runnable-0.34.0a1.dist-info/METADATA +0 -267
  47. runnable-0.34.0a1.dist-info/RECORD +0 -67
  48. {runnable-0.34.0a1.dist-info → runnable-1.0.0.dist-info}/WHEEL +0 -0
  49. {runnable-0.34.0a1.dist-info → runnable-1.0.0.dist-info}/licenses/LICENSE +0 -0
runnable/sdk.py CHANGED
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
+import inspect
 import logging
-import os
 import re
 from abc import ABC, abstractmethod
 from pathlib import Path
@@ -16,26 +16,17 @@ from pydantic import (
     field_validator,
     model_validator,
 )
-from rich.progress import (
-    BarColumn,
-    Progress,
-    SpinnerColumn,
-    TextColumn,
-    TimeElapsedColumn,
-)
-from rich.table import Column
 from typing_extensions import Self
 
-from extensions.nodes.nodes import (
-    FailNode,
-    MapNode,
-    ParallelNode,
-    StubNode,
-    SuccessNode,
-    TaskNode,
-)
-from runnable import console, defaults, entrypoints, exceptions, graph, utils
-from runnable.executor import BaseJobExecutor, BasePipelineExecutor
+from extensions.nodes.conditional import ConditionalNode
+from extensions.nodes.fail import FailNode
+from extensions.nodes.map import MapNode
+from extensions.nodes.parallel import ParallelNode
+from extensions.nodes.stub import StubNode
+from extensions.nodes.success import SuccessNode
+from extensions.nodes.task import TaskNode
+from runnable import defaults, graph
+from runnable.executor import BaseJobExecutor
 from runnable.nodes import TraversalNode
 from runnable.tasks import BaseTaskType as RunnableTask
 from runnable.tasks import TaskReturns
@@ -49,7 +40,7 @@ StepType = Union[
     "ShellTask",
     "Parallel",
     "Map",
-    "TorchTask",
+    "Conditional",
 ]
 
 
@@ -69,6 +60,7 @@ class Catalog(BaseModel):
     Attributes:
         get (List[str]): List of glob patterns to get from central catalog to the compute data folder.
         put (List[str]): List of glob patterns to put into central catalog from the compute data folder.
+        store_copy (bool): Whether to store a copy of the data in the central catalog.
 
     Examples:
         >>> from runnable import Catalog
@@ -83,6 +75,7 @@ class Catalog(BaseModel):
     # compute_data_folder: str = Field(default="", alias="compute_data_folder")
     get: List[str] = Field(default_factory=list, alias="get")
     put: List[str] = Field(default_factory=list, alias="put")
+    store_copy: bool = Field(default=True, alias="store_copy")
 
 
 class BaseTraversal(ABC, BaseModel):
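
Note: store_copy defaults to True, so existing pipelines keep the old behavior of persisting artifacts into the central catalog; judging by the diff, setting it to False skips storing the copy. A minimal sketch of the new flag from the SDK (the glob patterns are illustrative):

    from runnable import Catalog

    # Hypothetical patterns; store_copy=False skips persisting a copy
    # of the matched files into the central catalog
    catalog = Catalog(get=["data/*.csv"], put=["output/*.parquet"], store_copy=False)
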
@@ -193,6 +186,9 @@ class BaseTask(BaseTraversal):
             "This method should be implemented in the child class"
         )
 
+    def as_pipeline(self) -> "Pipeline":
+        return Pipeline(steps=[self], name=self.internal_name)  # type: ignore
+
 
 class PythonTask(BaseTask):
     """
@@ -282,26 +278,6 @@ class PythonTask(BaseTask):
         return node.executable
 
 
-class TorchTask(BaseTask):
-    entrypoint: str = Field(
-        alias="entrypoint", default="torch.distributed.run", frozen=True
-    )
-    args_to_torchrun: Dict[str, Any] = Field(
-        default_factory=dict, alias="args_to_torchrun"
-    )
-
-    script_to_call: str
-
-    @computed_field
-    def command_type(self) -> str:
-        return "torch"
-
-    def create_job(self) -> RunnableTask:
-        self.terminate_with_success = True
-        node = self.create_node()
-        return node.executable
-
-
 class NotebookTask(BaseTask):
     """
     An execution node of the pipeline of notebook.
@@ -481,6 +457,9 @@ class Stub(BaseTraversal):
 
         return StubNode.parse_from_config(self.model_dump(exclude_none=True))
 
+    def as_pipeline(self) -> "Pipeline":
+        return Pipeline(steps=[self])
+
 
 class Parallel(BaseTraversal):
     """
@@ -520,6 +499,53 @@ class Parallel(BaseTraversal):
         return node
 
 
+class Conditional(BaseTraversal):
+    branches: Dict[str, "Pipeline"]
+    parameter: str  # the name of the parameter should be isalnum
+
+    @field_validator("parameter")
+    @classmethod
+    def validate_parameter(cls, parameter: str) -> str:
+        if not parameter.isalnum():
+            raise AssertionError(
+                "The parameter name should be alphanumeric and not empty"
+            )
+        return parameter
+
+    @field_validator("branches")
+    @classmethod
+    def validate_branches(
+        cls, branches: Dict[str, "Pipeline"]
+    ) -> Dict[str, "Pipeline"]:
+        for branch_name in branches.keys():
+            if not branch_name.isalnum():
+                raise ValueError(f"Branch '{branch_name}' must be alphanumeric.")
+        return branches
+
+    @computed_field  # type: ignore
+    @property
+    def graph_branches(self) -> Dict[str, graph.Graph]:
+        return {
+            name: pipeline._dag.model_copy() for name, pipeline in self.branches.items()
+        }
+
+    def create_node(self) -> ConditionalNode:
+        if not self.next_node:
+            if not (self.terminate_with_failure or self.terminate_with_success):
+                raise AssertionError(
+                    "A node not being terminated must have a user defined next node"
+                )
+
+        node = ConditionalNode(
+            name=self.name,
+            branches=self.graph_branches,
+            internal_name="",
+            next_node=self.next_node,
+            parameter=self.parameter,
+        )
+        return node
+
+
 class Map(BaseTraversal):
     """
     A node that iterates over a list of items and executes a pipeline for each item.
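
Note: Conditional is the new step type behind the "Conditional" entry in StepType. It routes execution to whichever branch pipeline's key matches the runtime value of parameter, and both the parameter name and the branch names must be alphanumeric. A hedged usage sketch, assuming Conditional is exported at the top level like the other step types (step and branch names are illustrative):

    from runnable import Conditional, Pipeline, Stub

    route = Conditional(
        name="route",
        parameter="mode",  # the runtime value of "mode" selects a branch below
        branches={
            "train": Stub(name="train", terminate_with_success=True).as_pipeline(),
            "infer": Stub(name="infer", terminate_with_success=True).as_pipeline(),
        },
        terminate_with_success=True,
    )

    pipeline = Pipeline(steps=[route])
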
@@ -543,7 +569,6 @@ class Map(BaseTraversal):
     iterate_on: str
     iterate_as: str
     reducer: Optional[str] = Field(default=None, alias="reducer")
-    overrides: Dict[str, Any] = Field(default_factory=dict)
 
     @computed_field  # type: ignore
     @property
@@ -564,7 +589,6 @@ class Map(BaseTraversal):
             next_node=self.next_node,
             iterate_on=self.iterate_on,
             iterate_as=self.iterate_as,
-            overrides=self.overrides,
             reducer=self.reducer,
         )
 
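
Note: with overrides removed, Map is configured purely by iterate_on/iterate_as plus an optional reducer. A hedged sketch of the surviving surface (names are illustrative, and the branch field is assumed to carry the per-item pipeline as in earlier releases):

    from runnable import Map, Stub

    # Runs the branch pipeline once per item of the "chunks" parameter,
    # binding each item to "chunk"
    process = Map(
        name="process",
        iterate_on="chunks",
        iterate_as="chunk",
        branch=Stub(name="work", terminate_with_success=True).as_pipeline(),
        terminate_with_success=True,
    )
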
@@ -735,6 +759,15 @@ class Pipeline(BaseModel):
                 return False
         return True
 
+    def get_caller(self) -> str:
+        caller_stack = inspect.stack()[2]
+        relative_to_root = str(Path(caller_stack.filename).relative_to(Path.cwd()))
+
+        module_name = re.sub(r"\b.py\b", "", relative_to_root.replace("/", "."))
+        module_to_call = f"{module_name}.{caller_stack.function}"
+
+        return module_to_call
+
 
     def execute(
         self,
@@ -752,106 +785,31 @@ class Pipeline(BaseModel):
             # Immediately return as this call is only for getting the pipeline definition
             return {}
 
-        logger.setLevel(log_level)
-
-        run_id = utils.generate_run_id(run_id=run_id)
+        from runnable import context
 
-        parameters_file = os.environ.get("RUNNABLE_PARAMETERS_FILE", parameters_file)
-
-        tag = os.environ.get("RUNNABLE_tag", tag)
+        logger.setLevel(log_level)
 
-        configuration_file = os.environ.get(
-            "RUNNABLE_CONFIGURATION_FILE", configuration_file
-        )
-        run_context = entrypoints.prepare_configurations(
+        service_configurations = context.ServiceConfigurations(
             configuration_file=configuration_file,
-            run_id=run_id,
-            tag=tag,
-            parameters_file=parameters_file,
+            execution_context=context.ExecutionContext.PIPELINE,
         )
 
-        assert isinstance(run_context.executor, BasePipelineExecutor)
-
-        utils.set_runnable_environment_variables(
-            run_id=run_id, configuration_file=configuration_file, tag=tag
-        )
-
-        dag_definition = self._dag.model_dump(by_alias=True, exclude_none=True)
-        run_context.from_sdk = True
-        run_context.dag = graph.create_graph(dag_definition)
-
-        console.print("Working with context:")
-        console.print(run_context)
-        console.rule(style="[dark orange]")
-
-        if not run_context.executor._is_local:
-            # We are not working with executor that does not work in local environment
-            import inspect
-
-            caller_stack = inspect.stack()[1]
-            relative_to_root = str(Path(caller_stack.filename).relative_to(Path.cwd()))
-
-            module_name = re.sub(r"\b.py\b", "", relative_to_root.replace("/", "."))
-            module_to_call = f"{module_name}.{caller_stack.function}"
-
-            run_context.pipeline_file = f"{module_to_call}.py"
-            run_context.from_sdk = True
-
-        # Prepare for graph execution
-        run_context.executor._set_up_run_log(exists_ok=False)
-
-        with Progress(
-            SpinnerColumn(spinner_name="runner"),
-            TextColumn(
-                "[progress.description]{task.description}", table_column=Column(ratio=2)
-            ),
-            BarColumn(table_column=Column(ratio=1), style="dark_orange"),
-            TimeElapsedColumn(table_column=Column(ratio=1)),
-            console=console,
-            expand=True,
-        ) as progress:
-            pipeline_execution_task = progress.add_task(
-                "[dark_orange] Starting execution .. ", total=1
-            )
-            try:
-                run_context.progress = progress
+        configurations = {
+            "pipeline_definition_file": self.get_caller(),
+            "parameters_file": parameters_file,
+            "tag": tag,
+            "run_id": run_id,
+            "execution_mode": context.ExecutionMode.PYTHON,
+            "configuration_file": configuration_file,
+            **service_configurations.services,
+        }
 
-                run_context.executor.execute_graph(dag=run_context.dag)
+        run_context = context.PipelineContext.model_validate(configurations)
+        context.run_context = run_context
 
-                if not run_context.executor._is_local:
-                    # non local executors just traverse the graph and do nothing
-                    return {}
+        assert isinstance(run_context, context.PipelineContext)
 
-                run_log = run_context.run_log_store.get_run_log_by_id(
-                    run_id=run_context.run_id, full=False
-                )
-
-                if run_log.status == defaults.SUCCESS:
-                    progress.update(
-                        pipeline_execution_task,
-                        description="[green] Success",
-                        completed=True,
-                    )
-                else:
-                    progress.update(
-                        pipeline_execution_task,
-                        description="[red] Failed",
-                        completed=True,
-                    )
-                    raise exceptions.ExecutionFailedError(run_context.run_id)
-            except Exception as e:  # noqa: E722
-                console.print(e, style=defaults.error_style)
-                progress.update(
-                    pipeline_execution_task,
-                    description="[red] Errored execution",
-                    completed=True,
-                )
-                raise
-
-        if run_context.executor._is_local:
-            return run_context.run_log_store.get_run_log_by_id(
-                run_id=run_context.run_id
-            )
+        run_context.execute()
 
 
 class BaseJob(BaseModel):
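
Note: the orchestration that previously lived inline here (rich progress rendering, run-log bookkeeping, local-versus-remote branching) now sits behind context.PipelineContext.execute(); the SDK only assembles a configuration mapping and validates it into the context. The user-facing call surface appears unchanged; a minimal sketch:

    from runnable import Stub

    pipeline = Stub(name="hello", terminate_with_success=True).as_pipeline()

    # Same entry point as before; configuration and parameters files are optional
    pipeline.execute()
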
@@ -875,11 +833,25 @@ class BaseJob(BaseModel):
     def get_task(self) -> RunnableTask:
         raise NotImplementedError
 
+    def get_caller(self) -> str:
+        caller_stack = inspect.stack()[2]
+        relative_to_root = str(Path(caller_stack.filename).relative_to(Path.cwd()))
+
+        module_name = re.sub(r"\b.py\b", "", relative_to_root.replace("/", "."))
+        module_to_call = f"{module_name}.{caller_stack.function}"
+
+        return module_to_call
+
     def return_catalog_settings(self) -> Optional[List[str]]:
         if self.catalog is None:
             return []
         return self.catalog.put
 
+    def return_bool_catalog_store_copy(self) -> bool:
+        if self.catalog is None:
+            return True
+        return self.catalog.store_copy
+
     def _is_called_for_definition(self) -> bool:
         """
         If the run context is set, we are coming in only to get the pipeline definition.
@@ -901,65 +873,33 @@ class BaseJob(BaseModel):
         if self._is_called_for_definition():
             # Immediately return as this call is only for getting the job definition
             return {}
-        logger.setLevel(log_level)
-
-        run_id = utils.generate_run_id(run_id=job_id)
-
-        parameters_file = os.environ.get("RUNNABLE_PARAMETERS_FILE", parameters_file)
-
-        tag = os.environ.get("RUNNABLE_tag", tag)
+        from runnable import context
 
-        configuration_file = os.environ.get(
-            "RUNNABLE_CONFIGURATION_FILE", configuration_file
-        )
+        logger.setLevel(log_level)
 
-        run_context = entrypoints.prepare_configurations(
+        service_configurations = context.ServiceConfigurations(
             configuration_file=configuration_file,
-            run_id=run_id,
-            tag=tag,
-            parameters_file=parameters_file,
-            is_job=True,
-        )
-
-        assert isinstance(run_context.executor, BaseJobExecutor)
-        run_context.from_sdk = True
-
-        utils.set_runnable_environment_variables(
-            run_id=run_id, configuration_file=configuration_file, tag=tag
+            execution_context=context.ExecutionContext.JOB,
         )
 
-        console.print("Working with context:")
-        console.print(run_context)
-        console.rule(style="[dark orange]")
-
-        if not run_context.executor._is_local:
-            # We are not working with executor that does not work in local environment
-            import inspect
-
-            caller_stack = inspect.stack()[1]
-            relative_to_root = str(Path(caller_stack.filename).relative_to(Path.cwd()))
-
-            module_name = re.sub(r"\b.py\b", "", relative_to_root.replace("/", "."))
-            module_to_call = f"{module_name}.{caller_stack.function}"
-
-            run_context.job_definition_file = f"{module_to_call}.py"
-
-        job = self.get_task()
-        catalog_settings = self.return_catalog_settings()
+        configurations = {
+            "job_definition_file": self.get_caller(),
+            "parameters_file": parameters_file,
+            "tag": tag,
+            "run_id": job_id,
+            "execution_mode": context.ExecutionMode.PYTHON,
+            "configuration_file": configuration_file,
+            "job": self.get_task(),
+            "catalog_settings": self.return_catalog_settings(),
+            **service_configurations.services,
+        }
 
-        try:
-            run_context.executor.submit_job(job, catalog_settings=catalog_settings)
-        finally:
-            run_context.executor.add_task_log_to_catalog("job")
+        run_context = context.JobContext.model_validate(configurations)
+        run_context.catalog_store_copy = self.return_bool_catalog_store_copy()
 
-        logger.info(
-            "Executing the job from the user. We are still in the caller's compute environment"
-        )
+        assert isinstance(run_context.job_executor, BaseJobExecutor)
 
-        if run_context.executor._is_local:
-            return run_context.run_log_store.get_run_log_by_id(
-                run_id=run_context.run_id
-            )
+        run_context.execute()
 
 
 class PythonJob(BaseJob):
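
Note: jobs follow the same context-driven flow, with the job callable and catalog settings passed into context.JobContext and the catalog's store_copy flag threaded through separately. A hedged sketch (the function and patterns are illustrative):

    from runnable import Catalog, PythonJob

    def train():
        print("training...")

    job = PythonJob(
        function=train,
        catalog=Catalog(put=["model/*.bin"], store_copy=False),
    )
    job.execute()
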
@@ -983,25 +923,6 @@ class PythonJob(BaseJob):
         return task.create_node().executable
 
 
-class TorchJob(BaseJob):
-    entrypoint: str = Field(default="torch.distributed.run", frozen=True)
-    args_to_torchrun: dict[str, str | bool | int | float] = Field(
-        default_factory=dict
-    )  # For example
-    # {"nproc_per_node": 2, "nnodes": 1,}
-
-    script_to_call: str  # For example train/script.py
-
-    def get_task(self) -> RunnableTask:
-        # Piggy bank on existing tasks as a hack
-        task = TorchTask(
-            name="dummy",
-            terminate_with_success=True,
-            **self.model_dump(exclude_defaults=True, exclude_none=True),
-        )
-        return task.create_node().executable
-
-
 class NotebookJob(BaseJob):
     notebook: str = Field(serialization_alias="command")
     optional_ploomber_args: Optional[Dict[str, Any]] = Field(