runnable 0.9.1__py3-none-any.whl → 0.11.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,14 +1,19 @@
1
1
  import copy
2
- import json
3
2
  import logging
4
3
  import os
5
4
  from abc import abstractmethod
6
- from typing import Dict, List, Optional
7
-
8
- from rich import print
9
-
10
- from runnable import context, defaults, exceptions, integration, parameters, utils
11
- from runnable.datastore import DataCatalog, JsonParameter, StepLog
5
+ from typing import Dict, List, Optional, cast
6
+
7
+ from runnable import (
8
+ console,
9
+ context,
10
+ defaults,
11
+ exceptions,
12
+ integration,
13
+ parameters,
14
+ utils,
15
+ )
16
+ from runnable.datastore import DataCatalog, JsonParameter, RunLog, StepLog
12
17
  from runnable.defaults import TypeMapVariable
13
18
  from runnable.executor import BaseExecutor
14
19
  from runnable.extensions.nodes import TaskNode
@@ -58,6 +63,7 @@ class GenericExecutor(BaseExecutor):
58
63
 
59
64
  # Update these with some from the environment variables
60
65
  params.update(parameters.get_user_set_parameters())
66
+ logger.debug(f"parameters as seen by executor: {params}")
61
67
  return params
62
68
 
63
69
  def _set_up_run_log(self, exists_ok=False):
@@ -69,7 +75,7 @@ class GenericExecutor(BaseExecutor):
69
75
  try:
70
76
  attempt_run_log = self._context.run_log_store.get_run_log_by_id(run_id=self._context.run_id, full=False)
71
77
 
72
- logger.warning(f"The run log by id: {self._context.run_id} already exists")
78
+ logger.warning(f"The run log by id: {self._context.run_id} already exists, is this designed?")
73
79
  raise exceptions.RunLogExistsError(
74
80
  f"The run log by id: {self._context.run_id} already exists and is {attempt_run_log.status}"
75
81
  )
@@ -94,6 +100,7 @@ class GenericExecutor(BaseExecutor):
94
100
 
95
101
  # Update run_config
96
102
  run_config = utils.get_run_config()
103
+ logger.debug(f"run_config as seen by executor: {run_config}")
97
104
  self._context.run_log_store.set_run_config(run_id=self._context.run_id, run_config=run_config)
98
105
 
99
106
  def prepare_for_graph_execution(self):
@@ -116,9 +123,6 @@ class GenericExecutor(BaseExecutor):
116
123
  integration.validate(self, self._context.secrets_handler)
117
124
  integration.configure_for_traversal(self, self._context.secrets_handler)
118
125
 
119
- integration.validate(self, self._context.experiment_tracker)
120
- integration.configure_for_traversal(self, self._context.experiment_tracker)
121
-
122
126
  self._set_up_run_log()
123
127
 
124
128
  def prepare_for_node_execution(self):
@@ -138,9 +142,6 @@ class GenericExecutor(BaseExecutor):
138
142
  integration.validate(self, self._context.secrets_handler)
139
143
  integration.configure_for_execution(self, self._context.secrets_handler)
140
144
 
141
- integration.validate(self, self._context.experiment_tracker)
142
- integration.configure_for_execution(self, self._context.experiment_tracker)
143
-
144
145
  def _sync_catalog(self, stage: str, synced_catalogs=None) -> Optional[List[DataCatalog]]:
145
146
  """
146
147
  1). Identify the catalog settings by over-riding node settings with the global settings.
@@ -166,6 +167,7 @@ class GenericExecutor(BaseExecutor):
166
167
  "Catalog service only accepts get/put possible actions as part of node execution."
167
168
  f"Sync catalog of the executor: {self.service_name} asks for {stage} which is not accepted"
168
169
  )
170
+ logger.exception(msg)
169
171
  raise Exception(msg)
170
172
 
171
173
  try:
@@ -183,10 +185,14 @@ class GenericExecutor(BaseExecutor):
183
185
  data_catalogs = []
184
186
  for name_pattern in node_catalog_settings.get(stage) or []:
185
187
  if stage == "get":
188
+ get_catalog_progress = self._context.progress.add_task(f"Getting from catalog {name_pattern}", total=1)
186
189
  data_catalog = self._context.catalog_handler.get(
187
190
  name=name_pattern, run_id=self._context.run_id, compute_data_folder=compute_data_folder
188
191
  )
192
+ self._context.progress.update(get_catalog_progress, completed=True, visible=False, refresh=True)
193
+
189
194
  elif stage == "put":
195
+ put_catalog_progress = self._context.progress.add_task(f"Putting in catalog {name_pattern}", total=1)
190
196
  data_catalog = self._context.catalog_handler.put(
191
197
  name=name_pattern,
192
198
  run_id=self._context.run_id,
@@ -194,7 +200,9 @@ class GenericExecutor(BaseExecutor):
194
200
  synced_catalogs=synced_catalogs,
195
201
  )
196
202
 
197
- logger.info(f"Added data catalog: {data_catalog} to step log")
203
+ self._context.progress.update(put_catalog_progress, completed=True, visible=False)
204
+
205
+ logger.debug(f"Added data catalog: {data_catalog} to step log")
198
206
  data_catalogs.extend(data_catalog)
199
207
 
200
208
  return data_catalogs
@@ -256,6 +264,7 @@ class GenericExecutor(BaseExecutor):
256
264
  self._context_node = node
257
265
 
258
266
  data_catalogs_get: Optional[List[DataCatalog]] = self._sync_catalog(stage="get")
267
+ logger.debug(f"data_catalogs_get: {data_catalogs_get}")
259
268
 
260
269
  step_log = node.execute(
261
270
  map_variable=map_variable,
@@ -263,11 +272,16 @@ class GenericExecutor(BaseExecutor):
263
272
  mock=mock,
264
273
  **kwargs,
265
274
  )
275
+
266
276
  data_catalogs_put: Optional[List[DataCatalog]] = self._sync_catalog(stage="put")
277
+ logger.debug(f"data_catalogs_put: {data_catalogs_put}")
267
278
 
268
279
  step_log.add_data_catalogs(data_catalogs_get or [])
269
280
  step_log.add_data_catalogs(data_catalogs_put or [])
270
281
 
282
+ console.print(f"Summary of the step: {step_log.internal_name}")
283
+ console.print(step_log.get_summary(), style=defaults.info_style)
284
+
271
285
  self._context_node = None # type: ignore
272
286
 
273
287
  self._context.run_log_store.add_step_log(step_log, self._context.run_id)
@@ -318,6 +332,8 @@ class GenericExecutor(BaseExecutor):
318
332
 
319
333
  self._context.run_log_store.add_step_log(step_log, self._context.run_id)
320
334
 
335
+ logger.info(f"Executing node: {node.get_summary()}")
336
+
321
337
  # Add the step log to the database as per the situation.
322
338
  # If its a terminal node, complete it now
323
339
  if node.node_type in ["success", "fail"]:
@@ -329,7 +345,8 @@ class GenericExecutor(BaseExecutor):
329
345
  node.execute_as_graph(map_variable=map_variable, **kwargs)
330
346
  return
331
347
 
332
- # Executor specific way to trigger a job
348
+ task_name = node._resolve_map_placeholders(node.internal_name, map_variable)
349
+ console.print(f":runner: Executing the node {task_name} ... ", style="bold color(208)")
333
350
  self.trigger_job(node=node, map_variable=map_variable, **kwargs)
334
351
 
335
352
  def trigger_job(self, node: BaseNode, map_variable: TypeMapVariable = None, **kwargs):
@@ -405,31 +422,70 @@ class GenericExecutor(BaseExecutor):
405
422
  previous_node = None
406
423
  logger.info(f"Running the execution with {current_node}")
407
424
 
425
+ branch_execution_task = None
426
+ branch_task_name: str = ""
427
+ if dag.internal_branch_name:
428
+ branch_task_name = BaseNode._resolve_map_placeholders(
429
+ dag.internal_branch_name or "Graph",
430
+ map_variable,
431
+ )
432
+ branch_execution_task = self._context.progress.add_task(
433
+ f"[dark_orange]Executing {branch_task_name}",
434
+ total=1,
435
+ )
436
+
408
437
  while True:
409
438
  working_on = dag.get_node_by_name(current_node)
439
+ task_name = working_on._resolve_map_placeholders(working_on.internal_name, map_variable)
410
440
 
411
441
  if previous_node == current_node:
412
442
  raise Exception("Potentially running in a infinite loop")
413
443
 
414
444
  previous_node = current_node
415
445
 
416
- logger.info(f"Creating execution log for {working_on}")
417
- self.execute_from_graph(working_on, map_variable=map_variable, **kwargs)
446
+ logger.debug(f"Creating execution log for {working_on}")
418
447
 
419
- status, next_node_name = self._get_status_and_next_node_name(
420
- current_node=working_on, dag=dag, map_variable=map_variable
421
- )
448
+ depth = " " * ((task_name.count(".")) or 1 - 1)
422
449
 
423
- if status == defaults.TRIGGERED:
424
- # Some nodes go into triggered state and self traverse
425
- logger.info(f"Triggered the job to execute the node {current_node}")
426
- break
450
+ task_execution = self._context.progress.add_task(f"{depth}Executing {task_name}", total=1)
451
+
452
+ try:
453
+ self.execute_from_graph(working_on, map_variable=map_variable, **kwargs)
454
+ status, next_node_name = self._get_status_and_next_node_name(
455
+ current_node=working_on, dag=dag, map_variable=map_variable
456
+ )
457
+
458
+ if status == defaults.SUCCESS:
459
+ self._context.progress.update(
460
+ task_execution,
461
+ description=f"{depth}[green] {task_name} Completed",
462
+ completed=True,
463
+ overflow="fold",
464
+ )
465
+ else:
466
+ self._context.progress.update(
467
+ task_execution, description=f"{depth}[red] {task_name} Failed", completed=True
468
+ ) # type ignore
469
+ except Exception as e: # noqa: E722
470
+ self._context.progress.update(
471
+ task_execution,
472
+ description=f"{depth}[red] {task_name} Errored",
473
+ completed=True,
474
+ )
475
+ console.print(e, style=defaults.error_style)
476
+ logger.exception(e)
477
+ raise
427
478
 
428
479
  if working_on.node_type in ["success", "fail"]:
429
480
  break
430
481
 
431
482
  current_node = next_node_name
432
483
 
484
+ if branch_execution_task:
485
+ self._context.progress.update(
486
+ branch_execution_task, description=f"[green3] {branch_task_name} completed", completed=True
487
+ )
488
+
433
489
  run_log = self._context.run_log_store.get_branch_log(
434
490
  working_on._get_branch_log_name(map_variable), self._context.run_id
435
491
  )
@@ -440,10 +496,10 @@ class GenericExecutor(BaseExecutor):
440
496
 
441
497
  logger.info(f"Finished execution of the {branch} with status {run_log.status}")
442
498
 
443
- # get the final run log
444
- if branch == "graph":
445
- run_log = self._context.run_log_store.get_run_log_by_id(run_id=self._context.run_id, full=True)
446
- print(json.dumps(run_log.model_dump(), indent=4))
499
+ if dag == self._context.dag:
500
+ run_log = cast(RunLog, run_log)
501
+ console.print("Completed Execution, Summary:", style="bold color(208)")
502
+ console.print(run_log.get_summary(), style=defaults.info_style)
447
503
 
448
504
  def send_return_code(self, stage="traversal"):
449
505
  """
@@ -7,7 +7,14 @@ from abc import ABC, abstractmethod
7
7
  from collections import OrderedDict
8
8
  from typing import Any, Dict, List, Optional, Union, cast
9
9
 
10
- from pydantic import BaseModel, ConfigDict, Field, computed_field, field_serializer, field_validator
10
+ from pydantic import (
11
+ BaseModel,
12
+ ConfigDict,
13
+ Field,
14
+ computed_field,
15
+ field_serializer,
16
+ field_validator,
17
+ )
11
18
  from pydantic.functional_serializers import PlainSerializer
12
19
  from ruamel.yaml import YAML
13
20
  from typing_extensions import Annotated
@@ -773,9 +780,6 @@ class ArgoExecutor(GenericExecutor):
773
780
  integration.validate(self, self._context.secrets_handler)
774
781
  integration.configure_for_traversal(self, self._context.secrets_handler)
775
782
 
776
- integration.validate(self, self._context.experiment_tracker)
777
- integration.configure_for_traversal(self, self._context.experiment_tracker)
778
-
779
783
  def prepare_for_node_execution(self):
780
784
  """
781
785
  Perform any modifications to the services prior to execution of the node.
@@ -35,6 +35,7 @@ class LocalExecutor(GenericExecutor):
35
35
  node (BaseNode): [description]
36
36
  map_variable (str, optional): [description]. Defaults to ''.
37
37
  """
38
+
38
39
  self.prepare_for_node_execution()
39
40
  self.execute_node(node=node, map_variable=map_variable, **kwargs)
40
41
 
@@ -5,13 +5,19 @@ import sys
5
5
  from collections import OrderedDict
6
6
  from copy import deepcopy
7
7
  from datetime import datetime
8
- from typing import Any, Dict, Optional, cast
9
-
10
- from pydantic import ConfigDict, Field, ValidationInfo, field_serializer, field_validator
8
+ from typing import Any, Dict, List, Optional, Tuple, Union, cast
9
+
10
+ from pydantic import (
11
+ ConfigDict,
12
+ Field,
13
+ ValidationInfo,
14
+ field_serializer,
15
+ field_validator,
16
+ )
11
17
  from typing_extensions import Annotated
12
18
 
13
19
  from runnable import datastore, defaults, utils
14
- from runnable.datastore import JsonParameter, ObjectParameter, StepLog
20
+ from runnable.datastore import JsonParameter, MetricParameter, ObjectParameter, StepLog
15
21
  from runnable.defaults import TypeMapVariable
16
22
  from runnable.graph import Graph, create_graph
17
23
  from runnable.nodes import CompositeNode, ExecutableNode, TerminalNode
@@ -45,6 +51,16 @@ class TaskNode(ExecutableNode):
45
51
  executable = create_task(task_config)
46
52
  return cls(executable=executable, **node_config, **task_config)
47
53
 
54
+ def get_summary(self) -> Dict[str, Any]:
55
+ summary = {
56
+ "name": self.name,
57
+ "type": self.node_type,
58
+ "executable": self.executable.get_summary(),
59
+ "catalog": self._get_catalog_settings(),
60
+ }
61
+
62
+ return summary
63
+
48
64
  def execute(
49
65
  self,
50
66
  mock=False,
@@ -63,9 +79,8 @@ class TaskNode(ExecutableNode):
63
79
  Returns:
64
80
  StepAttempt: The attempt object
65
81
  """
66
- print("Executing task:", self._context.executor._context_node)
67
-
68
82
  step_log = self._context.run_log_store.get_step_log(self._get_step_log_name(map_variable), self._context.run_id)
83
+
69
84
  if not mock:
70
85
  # Do not run if we are mocking the execution, could be useful for caching and dry runs
71
86
  attempt_log = self.executable.execute_command(map_variable=map_variable)
@@ -78,6 +93,9 @@ class TaskNode(ExecutableNode):
78
93
  attempt_number=attempt_number,
79
94
  )
80
95
 
96
+ logger.debug(f"attempt_log: {attempt_log}")
97
+ logger.info(f"Step {self.name} completed with status: {attempt_log.status}")
98
+
81
99
  step_log.status = attempt_log.status
82
100
 
83
101
  step_log.attempts.append(attempt_log)
@@ -96,6 +114,14 @@ class FailNode(TerminalNode):
96
114
  def parse_from_config(cls, config: Dict[str, Any]) -> "FailNode":
97
115
  return cast("FailNode", super().parse_from_config(config))
98
116
 
117
+ def get_summary(self) -> Dict[str, Any]:
118
+ summary = {
119
+ "name": self.name,
120
+ "type": self.node_type,
121
+ }
122
+
123
+ return summary
124
+
99
125
  def execute(
100
126
  self,
101
127
  mock=False,
@@ -118,7 +144,7 @@ class FailNode(TerminalNode):
118
144
  step_log = self._context.run_log_store.get_step_log(self._get_step_log_name(map_variable), self._context.run_id)
119
145
 
120
146
  attempt_log = datastore.StepAttempt(
121
- status=defaults.FAIL,
147
+ status=defaults.SUCCESS,
122
148
  start_time=str(datetime.now()),
123
149
  end_time=str(datetime.now()),
124
150
  attempt_number=attempt_number,
@@ -148,6 +174,14 @@ class SuccessNode(TerminalNode):
148
174
  def parse_from_config(cls, config: Dict[str, Any]) -> "SuccessNode":
149
175
  return cast("SuccessNode", super().parse_from_config(config))
150
176
 
177
+ def get_summary(self) -> Dict[str, Any]:
178
+ summary = {
179
+ "name": self.name,
180
+ "type": self.node_type,
181
+ }
182
+
183
+ return summary
184
+
151
185
  def execute(
152
186
  self,
153
187
  mock=False,
@@ -207,6 +241,15 @@ class ParallelNode(CompositeNode):
207
241
  branches: Dict[str, Graph]
208
242
  is_composite: bool = Field(default=True, exclude=True)
209
243
 
244
+ def get_summary(self) -> Dict[str, Any]:
245
+ summary = {
246
+ "name": self.name,
247
+ "type": self.node_type,
248
+ "branches": [branch.get_summary() for branch in self.branches.values()],
249
+ }
250
+
251
+ return summary
252
+
210
253
  @field_serializer("branches")
211
254
  def ser_branches(self, branches: Dict[str, Graph]) -> Dict[str, Graph]:
212
255
  ret: Dict[str, Graph] = {}
@@ -296,6 +339,7 @@ class ParallelNode(CompositeNode):
296
339
  executor (BaseExecutor): The executor class as defined by the config
297
340
  map_variable (dict, optional): If the node is part of a map. Defaults to None.
298
341
  """
342
+ effective_internal_name = self._resolve_map_placeholders(self.internal_name, map_variable=map_variable)
299
343
  step_success_bool = True
300
344
  for internal_branch_name, _ in self.branches.items():
301
345
  effective_branch_name = self._resolve_map_placeholders(internal_branch_name, map_variable=map_variable)
@@ -304,7 +348,7 @@ class ParallelNode(CompositeNode):
304
348
  step_success_bool = False
305
349
 
306
350
  # Collate all the results and update the status of the step
307
- effective_internal_name = self._resolve_map_placeholders(self.internal_name, map_variable=map_variable)
351
+
308
352
  step_log = self._context.run_log_store.get_step_log(effective_internal_name, self._context.run_id)
309
353
 
310
354
  if step_success_bool: #  If none failed
@@ -335,10 +379,24 @@ class MapNode(CompositeNode):
335
379
  node_type: str = Field(default="map", serialization_alias="type")
336
380
  iterate_on: str
337
381
  iterate_as: str
382
+ iterate_index: bool = Field(default=False) # TODO: Need to design this
338
383
  reducer: Optional[str] = Field(default=None)
339
384
  branch: Graph
340
385
  is_composite: bool = True
341
386
 
387
+ def get_summary(self) -> Dict[str, Any]:
388
+ summary = {
389
+ "name": self.name,
390
+ "type": self.node_type,
391
+ "branch": self.branch.get_summary(),
392
+ "iterate_on": self.iterate_on,
393
+ "iterate_as": self.iterate_as,
394
+ "iterate_index": self.iterate_index,
395
+ "reducer": self.reducer,
396
+ }
397
+
398
+ return summary
399
+
342
400
  def get_reducer_function(self):
343
401
  if not self.reducer:
344
402
  return lambda *x: list(x) # returns a list of the args
@@ -375,12 +433,12 @@ class MapNode(CompositeNode):
375
433
 
376
434
  @property
377
435
  def branch_returns(self):
378
- branch_returns = []
436
+ branch_returns: List[Tuple[str, Union[ObjectParameter, MetricParameter, JsonParameter]]] = []
379
437
  for _, node in self.branch.nodes.items():
380
438
  if isinstance(node, TaskNode):
381
439
  for task_return in node.executable.returns:
382
440
  if task_return.kind == "json":
383
- branch_returns.append((task_return.name, JsonParameter(kind="json", value=None, reduced=False)))
441
+ branch_returns.append((task_return.name, JsonParameter(kind="json", value="", reduced=False)))
384
442
  elif task_return.kind == "object":
385
443
  branch_returns.append(
386
444
  (
@@ -390,7 +448,11 @@ class MapNode(CompositeNode):
390
448
  value="Will be reduced",
391
449
  reduced=False,
392
450
  ),
393
- ) # type: ignore
451
+ )
452
+ )
453
+ elif task_return.kind == "metric":
454
+ branch_returns.append(
455
+ (task_return.name, MetricParameter(kind="metric", value="", reduced=False))
394
456
  )
395
457
  else:
396
458
  raise Exception("kind should be either json or object")
@@ -513,6 +575,7 @@ class MapNode(CompositeNode):
513
575
  iterate_on = params[self.iterate_on].get_value()
514
576
  # # Find status of the branches
515
577
  step_success_bool = True
578
+ effective_internal_name = self._resolve_map_placeholders(self.internal_name, map_variable=map_variable)
516
579
 
517
580
  for iter_variable in iterate_on:
518
581
  effective_branch_name = self._resolve_map_placeholders(
@@ -523,7 +586,6 @@ class MapNode(CompositeNode):
523
586
  step_success_bool = False
524
587
 
525
588
  # Collate all the results and update the status of the step
526
- effective_internal_name = self._resolve_map_placeholders(self.internal_name, map_variable=map_variable)
527
589
  step_log = self._context.run_log_store.get_step_log(effective_internal_name, self._context.run_id)
528
590
 
529
591
  if step_success_bool: #  If none failed and nothing is waiting
@@ -580,6 +642,13 @@ class DagNode(CompositeNode):
580
642
  is_composite: bool = True
581
643
  internal_branch_name: Annotated[str, Field(validate_default=True)] = ""
582
644
 
645
+ def get_summary(self) -> Dict[str, Any]:
646
+ summary = {
647
+ "name": self.name,
648
+ "type": self.node_type,
649
+ }
650
+ return summary
651
+
583
652
  @field_validator("internal_branch_name")
584
653
  @classmethod
585
654
  def validate_internal_branch_name(cls, internal_branch_name: str, info: ValidationInfo):
@@ -711,7 +780,15 @@ class StubNode(ExecutableNode):
711
780
  """
712
781
 
713
782
  node_type: str = Field(default="stub", serialization_alias="type")
714
- model_config = ConfigDict(extra="allow")
783
+ model_config = ConfigDict(extra="ignore")
784
+
785
+ def get_summary(self) -> Dict[str, Any]:
786
+ summary = {
787
+ "name": self.name,
788
+ "type": self.node_type,
789
+ }
790
+
791
+ return summary
715
792
 
716
793
  @classmethod
717
794
  def parse_from_config(cls, config: Dict[str, Any]) -> "StubNode":
@@ -2,7 +2,7 @@ import json
2
2
  import logging
3
3
  from pathlib import Path
4
4
  from string import Template
5
- from typing import Optional, Sequence, Union
5
+ from typing import Any, Dict, Optional, Sequence, Union
6
6
 
7
7
  from runnable import defaults, utils
8
8
  from runnable.extensions.run_log_store.generic_chunked import ChunkedRunLogStore
@@ -21,6 +21,11 @@ class ChunkedFileSystemRunLogStore(ChunkedRunLogStore):
21
21
  service_name: str = "chunked-fs"
22
22
  log_folder: str = defaults.LOG_LOCATION_FOLDER
23
23
 
24
+ def get_summary(self) -> Dict[str, Any]:
25
+ summary = {"Type": self.service_name, "Location": self.log_folder}
26
+
27
+ return summary
28
+
24
29
  def get_matches(self, run_id: str, name: str, multiple_allowed: bool = False) -> Optional[Union[Sequence[T], T]]:
25
30
  """
26
31
  Get contents of files matching the pattern name*
@@ -1,6 +1,7 @@
1
1
  import json
2
2
  import logging
3
3
  from pathlib import Path
4
+ from typing import Any, Dict
4
5
 
5
6
  from runnable import defaults, exceptions, utils
6
7
  from runnable.datastore import BaseRunLogStore, RunLog
@@ -37,6 +38,11 @@ class FileSystemRunLogstore(BaseRunLogStore):
37
38
  def log_folder_name(self):
38
39
  return self.log_folder
39
40
 
41
+ def get_summary(self) -> Dict[str, Any]:
42
+ summary = {"Type": self.service_name, "Location": self.log_folder}
43
+
44
+ return summary
45
+
40
46
  def write_to_folder(self, run_log: RunLog):
41
47
  """
42
48
  Write the run log to the folder
runnable/graph.py CHANGED
@@ -26,6 +26,17 @@ class Graph(BaseModel):
26
26
  internal_branch_name: str = Field(default="", exclude=True)
27
27
  nodes: SerializeAsAny[Dict[str, "BaseNode"]] = Field(default_factory=dict, serialization_alias="steps")
28
28
 
29
+ def get_summary(self) -> Dict[str, Any]:
30
+ """
31
+ Return a summary of the graph
32
+ """
33
+ return {
34
+ "name": self.name,
35
+ "description": self.description,
36
+ "start_at": self.start_at,
37
+ "nodes": [node.get_summary() for node in list(self.nodes.values())],
38
+ }
39
+
29
40
  def get_node_by_name(self, name: str) -> "BaseNode":
30
41
  """
31
42
  Return the Node object by the name
runnable/integration.py CHANGED
@@ -102,7 +102,7 @@ def get_integration_handler(executor: "BaseExecutor", service: object) -> BaseIn
102
102
  raise Exception(msg)
103
103
 
104
104
  if not integrations:
105
- logger.warning(
105
+ logger.info(
106
106
  f"Could not find an integration pattern for {executor.service_name} and {service_name} for {service_type}."
107
107
  " This implies that there is no need to change the configurations."
108
108
  )
@@ -163,7 +163,7 @@ class BufferedRunLogStore(BaseIntegration):
163
163
  "Run log generated by buffered run log store are not persisted. "
164
164
  "Re-running this run, in case of a failure, is not possible"
165
165
  )
166
- logger.warning(msg)
166
+ logger.info(msg)
167
167
 
168
168
 
169
169
  class DoNothingCatalog(BaseIntegration):
@@ -176,7 +176,7 @@ class DoNothingCatalog(BaseIntegration):
176
176
 
177
177
  def validate(self, **kwargs):
178
178
  msg = "A do-nothing catalog does not hold any data and therefore cannot pass data between nodes."
179
- logger.warning(msg)
179
+ logger.info(msg)
180
180
 
181
181
 
182
182
  class DoNothingSecrets(BaseIntegration):
@@ -189,17 +189,4 @@ class DoNothingSecrets(BaseIntegration):
189
189
 
190
190
  def validate(self, **kwargs):
191
191
  msg = "A do-nothing secrets does not hold any secrets and therefore cannot return you any secrets."
192
- logger.warning(msg)
193
-
194
-
195
- class DoNothingExperimentTracker(BaseIntegration):
196
- """
197
- Integration between any executor and do nothing experiment tracker
198
- """
199
-
200
- service_type = "experiment_tracker" # One of secret, catalog, datastore
201
- service_provider = "do-nothing" # The actual implementation of the service
202
-
203
- def validate(self, **kwargs):
204
- msg = "A do-nothing experiment tracker does nothing and therefore cannot track anything."
205
- logger.warning(msg)
192
+ logger.info(msg)
runnable/nodes.py CHANGED
@@ -362,6 +362,15 @@ class BaseNode(ABC, BaseModel):
362
362
  """
363
363
  ...
364
364
 
365
+ @abstractmethod
366
+ def get_summary(self) -> Dict[str, Any]:
367
+ """
368
+ Return the summary of the node
369
+
370
+ Returns:
371
+ Dict[str, Any]: _description_
372
+ """
373
+
365
374
 
366
375
  # --8<-- [end:docs]
367
376
  class TraversalNode(BaseNode):
runnable/parameters.py CHANGED
@@ -4,6 +4,7 @@ import logging
4
4
  import os
5
5
  from typing import Any, Dict, Type
6
6
 
7
+ import pydantic
7
8
  from pydantic import BaseModel, ConfigDict
8
9
  from typing_extensions import Callable
9
10
 
@@ -99,7 +100,7 @@ def filter_arguments_for_func(
99
100
  # default value is given in the function signature, nothing further to do.
100
101
  continue
101
102
 
102
- if issubclass(value.annotation, BaseModel):
103
+ if type(value.annotation) in [BaseModel, pydantic._internal._model_construction.ModelMetaclass]:
103
104
  # We try to cast it as a pydantic model if asked
104
105
  named_param = params[name].get_value()
105
106
 
@@ -110,6 +111,7 @@ def filter_arguments_for_func(
110
111
  bound_model = bind_args_for_pydantic_model(named_param, value.annotation)
111
112
  bound_args[name] = bound_model
112
113
  unassigned_params = unassigned_params.difference(bound_model.model_fields.keys())
114
+
113
115
  elif value.annotation in [str, int, float, bool]:
114
116
  # Cast it if its a primitive type. Ensure the type matches the annotation.
115
117
  bound_args[name] = value.annotation(params[name].get_value())