runnable 0.35.0-py3-none-any.whl → 0.36.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. extensions/job_executor/__init__.py +3 -4
  2. extensions/job_executor/emulate.py +106 -0
  3. extensions/job_executor/k8s.py +8 -8
  4. extensions/job_executor/local_container.py +13 -14
  5. extensions/nodes/__init__.py +0 -0
  6. extensions/nodes/conditional.py +7 -5
  7. extensions/nodes/fail.py +72 -0
  8. extensions/nodes/map.py +350 -0
  9. extensions/nodes/parallel.py +159 -0
  10. extensions/nodes/stub.py +89 -0
  11. extensions/nodes/success.py +72 -0
  12. extensions/nodes/task.py +92 -0
  13. extensions/pipeline_executor/__init__.py +24 -26
  14. extensions/pipeline_executor/argo.py +18 -15
  15. extensions/pipeline_executor/emulate.py +112 -0
  16. extensions/pipeline_executor/local.py +4 -4
  17. extensions/pipeline_executor/local_container.py +19 -79
  18. extensions/pipeline_executor/mocked.py +4 -4
  19. extensions/pipeline_executor/retry.py +6 -10
  20. extensions/tasks/torch.py +1 -1
  21. runnable/__init__.py +0 -8
  22. runnable/catalog.py +1 -21
  23. runnable/cli.py +0 -59
  24. runnable/context.py +519 -28
  25. runnable/datastore.py +51 -54
  26. runnable/defaults.py +12 -34
  27. runnable/entrypoints.py +82 -440
  28. runnable/exceptions.py +35 -34
  29. runnable/executor.py +13 -20
  30. runnable/names.py +1 -1
  31. runnable/nodes.py +16 -15
  32. runnable/parameters.py +2 -2
  33. runnable/sdk.py +66 -163
  34. runnable/tasks.py +62 -21
  35. runnable/utils.py +6 -268
  36. {runnable-0.35.0.dist-info → runnable-0.36.0.dist-info}/METADATA +1 -1
  37. runnable-0.36.0.dist-info/RECORD +74 -0
  38. {runnable-0.35.0.dist-info → runnable-0.36.0.dist-info}/entry_points.txt +8 -7
  39. extensions/nodes/nodes.py +0 -778
  40. runnable-0.35.0.dist-info/RECORD +0 -66
  41. {runnable-0.35.0.dist-info → runnable-0.36.0.dist-info}/WHEEL +0 -0
  42. {runnable-0.35.0.dist-info → runnable-0.36.0.dist-info}/licenses/LICENSE +0 -0
runnable/context.py CHANGED
@@ -1,45 +1,536 @@
-from typing import Any, Dict, List, Optional
+import hashlib
+import importlib
+import json
+import logging
+import os
+import sys
+from datetime import datetime
+from enum import Enum
+from functools import cached_property, partial
+from typing import Annotated, Any, Callable, Dict, Optional
 
-from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny
-from rich.progress import Progress
+from pydantic import (
+    BaseModel,
+    BeforeValidator,
+    ConfigDict,
+    Field,
+    computed_field,
+    field_validator,
+)
+from rich.progress import (
+    BarColumn,
+    Progress,
+    SpinnerColumn,
+    TextColumn,
+    TimeElapsedColumn,
+)
+from rich.table import Column
+from stevedore import driver
 
+from runnable import console, defaults, exceptions, names, utils
 from runnable.catalog import BaseCatalog
 from runnable.datastore import BaseRunLogStore
-from runnable.executor import BaseExecutor
-from runnable.graph import Graph
+from runnable.executor import BaseJobExecutor, BasePipelineExecutor
+from runnable.graph import Graph, create_graph
+from runnable.nodes import BaseNode
 from runnable.pickler import BasePickler
 from runnable.secrets import BaseSecrets
 from runnable.tasks import BaseTaskType
 
+logger = logging.getLogger(defaults.LOGGER_NAME)
 
-class Context(BaseModel):
-    executor: SerializeAsAny[BaseExecutor]
-    run_log_store: SerializeAsAny[BaseRunLogStore]
-    secrets_handler: SerializeAsAny[BaseSecrets]
-    catalog_handler: SerializeAsAny[BaseCatalog]
-    pickler: SerializeAsAny[BasePickler]
-    progress: SerializeAsAny[Optional[Progress]] = Field(default=None, exclude=True)
 
-    model_config = ConfigDict(arbitrary_types_allowed=True)
+def get_pipeline_spec_from_yaml(pipeline_file: str) -> Graph:
+    """
+    Reads the pipeline file from a YAML file and sets the pipeline spec in the run context
+    """
+    pipeline_config = utils.load_yaml(pipeline_file)
+    logger.info("The input pipeline:")
+    logger.info(json.dumps(pipeline_config, indent=4))
 
-    pipeline_file: Optional[str] = ""
-    job_definition_file: Optional[str] = ""
-    parameters_file: Optional[str] = ""
-    configuration_file: Optional[str] = ""
-    from_sdk: bool = False
+    dag_config = pipeline_config["dag"]
 
-    run_id: str = ""
-    object_serialisation: bool = True
-    return_objects: Dict[str, Any] = {}
+    dag = create_graph(dag_config)
+    return dag
 
-    tag: str = ""
-    variables: Dict[str, str] = {}
 
-    dag: Optional[Graph] = None
-    dag_hash: str = ""
+def get_pipeline_spec_from_python(python_module: str) -> Graph:
+    # Call the SDK to get the dag
+    # Import the module and call the function to get the dag
+    module_file = python_module.rstrip(".py")
+    module, func = utils.get_module_and_attr_names(module_file)
+    sys.path.insert(0, os.getcwd())  # Need to add the current directory to path
+    imported_module = importlib.import_module(module)
 
-    job: Optional[BaseTaskType] = None
-    job_catalog_settings: Optional[List[str]] = []
+    dag = getattr(imported_module, func)().return_dag()
 
+    return dag
 
-run_context = None  # type: Context # type: ignore
+
+def get_job_spec_from_python(job_file: str) -> BaseTaskType:
+    """
+    Reads the job file from a Python file and sets the job spec in the run context
+    """
+    # Import the module and call the function to get the job
+    module_file = job_file.rstrip(".py")
+    module, func = utils.get_module_and_attr_names(module_file)
+    sys.path.insert(0, os.getcwd())  # Need to add the current directory to path
+    imported_module = importlib.import_module(module)
+
+    job = getattr(imported_module, func)().get_task()
+
+    return job
+
+
+def get_service_by_name(namespace: str, service_config: dict[str, Any], _) -> Any:  # noqa: ANN401, ANN001
+    """Get the service by name."""
+    service_config = service_config.copy()
+
+    kind = service_config.pop("type", None)
+
+    if "config" in service_config:
+        service_config = service_config.get("config", {})
+
+    logger.debug(
+        f"Trying to get a service of {namespace} with config: {service_config}"
+    )
+    try:
+        mgr = driver.DriverManager(
+            namespace=namespace,  # eg: reader
+            name=kind,  # eg: csv, pdf
+            invoke_on_load=True,
+            invoke_kwds={**service_config},
+        )
+        return mgr.driver
+    except Exception as _e:
+        raise Exception(
+            f"Could not find the service of kind: {kind} in namespace:{namespace} with config: {service_config}"
+        ) from _e
+
+
+def get_service(service: str) -> Callable:
+    """Get the service by name.
+
+    Args:
+        service (str): service name.
+
+    Returns:
+        Callable: callable function of service.
+    """
+    return partial(get_service_by_name, service)
+
+
+InstantiatedCatalog = Annotated[BaseCatalog, BeforeValidator(get_service("catalog"))]
+InstantiatedSecrets = Annotated[BaseSecrets, BeforeValidator(get_service("secrets"))]
+InstantiatedPickler = Annotated[BasePickler, BeforeValidator(get_service("pickler"))]
+InstantiatedRunLogStore = Annotated[
+    BaseRunLogStore, BeforeValidator(get_service("run_log_store"))
+]
+InstantiatedPipelineExecutor = Annotated[
+    BasePipelineExecutor, BeforeValidator(get_service("pipeline_executor"))
+]
+InstantiatedJobExecutor = Annotated[
+    BaseJobExecutor, BeforeValidator(get_service("job_executor"))
+]
+
+
+class ExecutionMode(str, Enum):
+    YAML = "yaml"
+    PYTHON = "python"
+
+
+class ExecutionContext(str, Enum):
+    PIPELINE = "pipeline"
+    JOB = "job"
+
+
+class ServiceConfigurations(BaseModel):
+    configuration_file: Optional[str] = Field(
+        default=None, exclude=True, description="Path to the configuration file."
+    )
+    execution_context: ExecutionContext = ExecutionContext.PIPELINE
+    variables: dict[str, str] = Field(
+        default_factory=utils.gather_variables,
+        exclude=True,
+        description="Variables to be used.",
+    )
+
+    @field_validator("configuration_file", mode="before")
+    @classmethod
+    def override_configuration_file(cls, configuration_file: str | None) -> str | None:
+        """Determine the configuration file to use, following the order of precedence."""
+        # 1. Environment variable
+        env_config = os.environ.get(defaults.RUNNABLE_CONFIGURATION_FILE)
+        if env_config:
+            return env_config
+
+        # 2. User-provided at runtime
+        if configuration_file:
+            return configuration_file
+
+        # 3. Default user config file
+        if utils.does_file_exist(defaults.USER_CONFIG_FILE):
+            return defaults.USER_CONFIG_FILE
+
+        # 4. No config file
+        return None
+
+    @computed_field  # type: ignore
+    @property
+    def services(self) -> dict[str, Any]:
+        """Get the effective services"""
+        _services = defaults.DEFAULT_SERVICES.copy()
+
+        if not self.configuration_file:
+            return _services
+
+        # Load the configuration file
+        templated_config = utils.load_yaml(self.configuration_file)
+        config = utils.apply_variables(templated_config, self.variables)
+
+        for key, value in config.items():
+            _services[key.replace("-", "_")] = value
+
+        if self.execution_context == ExecutionContext.JOB:
+            _services.pop("pipeline_executor", None)
+        elif self.execution_context == ExecutionContext.PIPELINE:
+            _services.pop("job_executor", None)
+        else:
+            raise ValueError(
+                f"Invalid execution context: {self.execution_context}. Must be 'pipeline' or 'job'."
+            )
+
+        return _services
+
+
+class RunnableContext(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, loc_by_alias=True)
+
+    execution_mode: ExecutionMode = ExecutionMode.PYTHON
+
+    parameters_file: Optional[str] = Field(
+        default=None, exclude=True, description="Path to the parameters file."
+    )
+    configuration_file: Optional[str] = Field(
+        default=None, exclude=True, description="Path to the configuration file."
+    )
+    variables: dict[str, str] = Field(
+        default_factory=utils.gather_variables,
+        exclude=True,
+        description="Variables to be used.",
+    )
+    run_id: str = ""  # Should be annotated to generate one if not provided
+    tag: Optional[str] = Field(default=None, description="Tag to be used for the run.")
+
+    # TODO: Verify the design
+    object_serialisation: bool = (
+        True  # Should be validated against executor being local
+    )
+    return_objects: Dict[
+        str, Any
+    ] = {}  # Should be validated against executor being local, should this be here?
+
+    @field_validator("parameters_file", mode="before")
+    @classmethod
+    def override_parameters_file(cls, parameters_file: str) -> str:
+        """Override the parameters file if provided."""
+        if os.environ.get(defaults.RUNNABLE_PARAMETERS_FILE, None):
+            return os.environ.get(defaults.RUNNABLE_PARAMETERS_FILE, parameters_file)
+        return parameters_file
+
+    @field_validator("configuration_file", mode="before")
+    @classmethod
+    def override_configuration_file(cls, configuration_file: str) -> str:
+        """Override the configuration file if provided."""
+        return os.environ.get(defaults.RUNNABLE_CONFIGURATION_FILE, configuration_file)
+
+    @field_validator("run_id", mode="before")
+    @classmethod
+    def generate_run_id(cls, run_id: str) -> str:
+        """Generate a run id if not provided."""
+        if not run_id:
+            run_id = os.environ.get(defaults.ENV_RUN_ID, "")
+
+        # If both are not given, generate one
+        if not run_id:
+            now = datetime.now()
+            run_id = f"{names.get_random_name()}-{now.hour:02}{now.minute:02}"
+
+        return run_id
+
+    def model_post_init(self, __context: Any) -> None:
+        os.environ[defaults.ENV_RUN_ID] = self.run_id
+
+        if self.configuration_file:
+            os.environ[defaults.RUNNABLE_CONFIGURATION_FILE] = self.configuration_file
+        if self.tag:
+            os.environ[defaults.RUNNABLE_RUN_TAG] = self.tag
+
+        global run_context
+        if not run_context:
+            run_context = self  # type: ignore
+
+        global progress
+        progress = Progress(
+            SpinnerColumn(spinner_name="runner"),
+            TextColumn(
+                "[progress.description]{task.description}", table_column=Column(ratio=2)
+            ),
+            BarColumn(table_column=Column(ratio=1), style="dark_orange"),
+            TimeElapsedColumn(table_column=Column(ratio=1)),
+            console=console,
+            expand=True,
+            auto_refresh=False,
+        )
+
+    def execute(self):
+        "Execute the pipeline or the job"
+        raise NotImplementedError
+
+
+class PipelineContext(RunnableContext):
+    pipeline_executor: InstantiatedPipelineExecutor
+    catalog: InstantiatedCatalog
+    secrets: InstantiatedSecrets
+    pickler: InstantiatedPickler
+    run_log_store: InstantiatedRunLogStore
+
+    pipeline_definition_file: str
+
+    @computed_field  # type: ignore
+    @cached_property
+    def dag(self) -> Graph | None:
+        """Get the dag."""
+        if self.execution_mode == ExecutionMode.YAML:
+            return get_pipeline_spec_from_yaml(self.pipeline_definition_file)
+        elif self.execution_mode == ExecutionMode.PYTHON:
+            return get_pipeline_spec_from_python(self.pipeline_definition_file)
+        else:
+            raise ValueError(
+                f"Invalid execution mode: {self.execution_mode}. Must be 'yaml' or 'python'."
+            )
+
+    @computed_field  # type: ignore
+    @cached_property
+    def dag_hash(self) -> str:
+        dag = self.dag
+        if not dag:
+            return ""
+        dag_str = json.dumps(dag.model_dump(), sort_keys=True, ensure_ascii=True)
+        return hashlib.sha1(dag_str.encode("utf-8")).hexdigest()
+
+    def get_node_callable_command(
+        self,
+        node: BaseNode,
+        map_variable: defaults.MapVariableType = None,
+        over_write_run_id: str = "",
+        log_level: str = "",
+    ) -> str:
+        run_id = self.run_id
+
+        if over_write_run_id:
+            run_id = over_write_run_id
+
+        log_level = log_level or logging.getLevelName(logger.getEffectiveLevel())
+
+        action = (
+            f"runnable execute-single-node {run_id} "
+            f"{self.pipeline_definition_file} "
+            f"{node._command_friendly_name()} "
+            f"--log-level {log_level} "
+        )
+
+        # yaml is the default mode
+        if self.execution_mode == ExecutionMode.PYTHON:
+            action = action + "--mode python "
+
+        if map_variable:
+            action = action + f"--map-variable '{json.dumps(map_variable)}' "
+
+        if self.configuration_file:
+            action = action + f"--config {self.configuration_file} "
+
+        if self.parameters_file:
+            action = action + f"--parameters-file {self.parameters_file} "
+
+        if self.tag:
+            action = action + f"--tag {self.tag}"
+
+        console.log(
+            f"Generated command for node {node._command_friendly_name()}: {action}"
+        )
+
+        return action
+
+    def get_fan_command(
+        self,
+        node: BaseNode,
+        mode: str,
+        run_id: str,
+        map_variable: defaults.MapVariableType = None,
+        log_level: str = "",
+    ) -> str:
+        """
+        Return the fan "in or out" command for this pipeline context.
+
+        Args:
+            node (BaseNode): The composite node that we are fanning in or out
+            mode (str): "in" or "out"
+            map_variable (dict, optional): If the node is a map, we have the map variable. Defaults to None.
+            log_level (str, optional): Log level. Defaults to "".
+
+        Returns:
+            str: The fan in or out command
+        """
+        log_level = log_level or logging.getLevelName(logger.getEffectiveLevel())
+        action = (
+            f"runnable fan {run_id} "
+            f"{node._command_friendly_name()} "
+            f"{self.pipeline_definition_file} "
+            f"{mode} "
+            f"--log-level {log_level}"
+        )
+        if self.configuration_file:
+            action += f" --config-file {self.configuration_file}"
+        if self.parameters_file:
+            action += f" --parameters-file {self.parameters_file}"
+        if map_variable:
+            action += f" --map-variable '{json.dumps(map_variable)}'"
+        if self.execution_mode == ExecutionMode.PYTHON:
+            action += " --mode python"
+        if self.tag:
+            action += f" --tag {self.tag}"
+
+        console.log(
+            f"Generated command for fan {mode} for node {node._command_friendly_name()}: {action}"
+        )
+        return action
+
+    def execute(self):
+        assert self.dag is not None
+
+        console.print("Working with context:")
+        console.print(run_context)
+        console.rule(style="[dark orange]")
+
+        # Prepare for graph execution
+        if self.pipeline_executor._should_setup_run_log_at_traversal:
+            self.pipeline_executor._set_up_run_log(exists_ok=False)
+
+        pipeline_execution_task = progress.add_task(
+            "[dark_orange] Starting execution .. ", total=1
+        )
+
+        try:
+            progress.start()
+            self.pipeline_executor.execute_graph(dag=self.dag)
+
+            if not self.pipeline_executor._should_setup_run_log_at_traversal:
+                # non local executors just traverse the graph and do nothing
+                return {}
+
+            run_log = run_context.run_log_store.get_run_log_by_id(
+                run_id=run_context.run_id, full=False
+            )
+
+            if run_log.status == defaults.SUCCESS:
+                progress.update(
+                    pipeline_execution_task,
+                    description="[green] Success",
+                    completed=True,
+                )
+            else:
+                progress.update(
+                    pipeline_execution_task,
+                    description="[red] Failed",
+                    completed=True,
+                )
+                raise exceptions.ExecutionFailedError(run_context.run_id)
+        except Exception as e:  # noqa: E722
+            console.print(e, style=defaults.error_style)
+            progress.update(
+                pipeline_execution_task,
+                description="[red] Errored execution",
+                completed=True,
+            )
+            raise
+        finally:
+            progress.stop()
+
+        if self.pipeline_executor._should_setup_run_log_at_traversal:
+            return run_context.run_log_store.get_run_log_by_id(
+                run_id=run_context.run_id
+            )
+
+
+class JobContext(RunnableContext):
+    job_executor: InstantiatedJobExecutor
+    catalog: InstantiatedCatalog
+    secrets: InstantiatedSecrets
+    pickler: InstantiatedPickler
+    run_log_store: InstantiatedRunLogStore
+
+    job_definition_file: str
+    catalog_settings: Optional[list[str]] = Field(
+        default=None,
+        description="Catalog settings to be used for the job.",
+    )
+
+    @computed_field  # type: ignore
+    @cached_property
+    def job(self) -> BaseTaskType:
+        """Get the job."""
+        return get_job_spec_from_python(self.job_definition_file)
+
+    def get_job_callable_command(
+        self,
+        over_write_run_id: str = "",
+    ):
+        run_id = self.run_id
+
+        if over_write_run_id:
+            run_id = over_write_run_id
+
+        log_level = logging.getLevelName(logger.getEffectiveLevel())
+
+        action = (
+            f"runnable execute-job {self.job_definition_file} {run_id} "
+            f" --log-level {log_level}"
+        )
+
+        if self.configuration_file:
+            action = action + f" --config {self.configuration_file}"
+
+        if self.parameters_file:
+            action = action + f" --parameters {self.parameters_file}"
+
+        if self.tag:
+            action = action + f" --tag {self.tag}"
+
+        return action
+
+    def execute(self):
+        console.print("Working with context:")
+        console.print(run_context)
+        console.rule(style="[dark orange]")
+
+        try:
+            self.job_executor.submit_job(
+                self.job, catalog_settings=self.catalog_settings
+            )
+        finally:
+            self.job_executor.add_task_log_to_catalog("job")
+            progress.stop()
+
+        logger.info(
+            "Executing the job from the user. We are still in the caller's compute environment"
+        )
+
+        if self.job_executor._should_setup_run_log_at_traversal:
+            return run_context.run_log_store.get_run_log_by_id(
+                run_id=run_context.run_id
+            )
+
+
+run_context: PipelineContext | JobContext = None  # type: ignore
+progress: Progress = None  # type: ignore
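
Note on the rewrite above: the former monolithic Context model is replaced by PipelineContext and JobContext, whose service fields accept plain configuration dicts and are instantiated through stevedore plugins by the BeforeValidator-annotated types. A minimal usage sketch follows; the plugin "type" names mentioned in the comments and the pipeline path are illustrative assumptions, not confirmed by this diff:

# Illustrative sketch only, not the package's documented API. Assumes
# defaults.DEFAULT_SERVICES provides entries for pipeline_executor, catalog,
# secrets, pickler, and run_log_store; concrete plugin names are hypothetical.
from runnable.context import PipelineContext, ServiceConfigurations

# Resolve the effective services: defaults merged with an optional YAML config
# file; the field validator also honours RUNNABLE_CONFIGURATION_FILE from the
# environment. With execution_context=PIPELINE, the job_executor entry is dropped.
services = ServiceConfigurations(configuration_file=None).services

# Each service value is a dict with a "type" key; the BeforeValidator routes it
# through stevedore's DriverManager in the matching namespace (e.g. "catalog").
context = PipelineContext(
    pipeline_definition_file="examples/pipeline.py",  # hypothetical path
    execution_mode="python",  # or "yaml" for a YAML pipeline definition
    **services,  # e.g. pipeline_executor={"type": "local"}, catalog={"type": "file-system"}
)

run_log = context.execute()  # builds context.dag lazily, then traverses it

The same pattern applies to JobContext, with job_executor in place of pipeline_executor and the DAG replaced by a single task loaded via get_job_spec_from_python.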