runnable 0.10.0__py3-none-any.whl → 0.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,14 +1,19 @@
1
1
  import copy
2
- import json
3
2
  import logging
4
3
  import os
5
4
  from abc import abstractmethod
6
- from typing import Dict, List, Optional
7
-
8
- from rich import print
9
-
10
- from runnable import context, defaults, exceptions, integration, parameters, utils
11
- from runnable.datastore import DataCatalog, JsonParameter, StepLog
5
+ from typing import Dict, List, Optional, cast
6
+
7
+ from runnable import (
8
+ console,
9
+ context,
10
+ defaults,
11
+ exceptions,
12
+ integration,
13
+ parameters,
14
+ utils,
15
+ )
16
+ from runnable.datastore import DataCatalog, JsonParameter, RunLog, StepLog
12
17
  from runnable.defaults import TypeMapVariable
13
18
  from runnable.executor import BaseExecutor
14
19
  from runnable.extensions.nodes import TaskNode
@@ -58,6 +63,7 @@ class GenericExecutor(BaseExecutor):
58
63
 
59
64
  # Update these with some from the environment variables
60
65
  params.update(parameters.get_user_set_parameters())
66
+ logger.debug(f"parameters as seen by executor: {params}")
61
67
  return params
62
68
 
63
69
  def _set_up_run_log(self, exists_ok=False):
@@ -69,7 +75,7 @@ class GenericExecutor(BaseExecutor):
69
75
  try:
70
76
  attempt_run_log = self._context.run_log_store.get_run_log_by_id(run_id=self._context.run_id, full=False)
71
77
 
72
- logger.warning(f"The run log by id: {self._context.run_id} already exists")
78
+ logger.warning(f"The run log by id: {self._context.run_id} already exists, is this designed?")
73
79
  raise exceptions.RunLogExistsError(
74
80
  f"The run log by id: {self._context.run_id} already exists and is {attempt_run_log.status}"
75
81
  )
@@ -94,6 +100,7 @@ class GenericExecutor(BaseExecutor):
94
100
 
95
101
  # Update run_config
96
102
  run_config = utils.get_run_config()
103
+ logger.debug(f"run_config as seen by executor: {run_config}")
97
104
  self._context.run_log_store.set_run_config(run_id=self._context.run_id, run_config=run_config)
98
105
 
99
106
  def prepare_for_graph_execution(self):
@@ -116,9 +123,6 @@ class GenericExecutor(BaseExecutor):
116
123
  integration.validate(self, self._context.secrets_handler)
117
124
  integration.configure_for_traversal(self, self._context.secrets_handler)
118
125
 
119
- integration.validate(self, self._context.experiment_tracker)
120
- integration.configure_for_traversal(self, self._context.experiment_tracker)
121
-
122
126
  self._set_up_run_log()
123
127
 
124
128
  def prepare_for_node_execution(self):
@@ -138,9 +142,6 @@ class GenericExecutor(BaseExecutor):
138
142
  integration.validate(self, self._context.secrets_handler)
139
143
  integration.configure_for_execution(self, self._context.secrets_handler)
140
144
 
141
- integration.validate(self, self._context.experiment_tracker)
142
- integration.configure_for_execution(self, self._context.experiment_tracker)
143
-
144
145
  def _sync_catalog(self, stage: str, synced_catalogs=None) -> Optional[List[DataCatalog]]:
145
146
  """
146
147
  1). Identify the catalog settings by over-riding node settings with the global settings.
@@ -166,6 +167,7 @@ class GenericExecutor(BaseExecutor):
166
167
  "Catalog service only accepts get/put possible actions as part of node execution."
167
168
  f"Sync catalog of the executor: {self.service_name} asks for {stage} which is not accepted"
168
169
  )
170
+ logger.exception(msg)
169
171
  raise Exception(msg)
170
172
 
171
173
  try:
@@ -183,10 +185,14 @@ class GenericExecutor(BaseExecutor):
183
185
  data_catalogs = []
184
186
  for name_pattern in node_catalog_settings.get(stage) or []:
185
187
  if stage == "get":
188
+ get_catalog_progress = self._context.progress.add_task(f"Getting from catalog {name_pattern}", total=1)
186
189
  data_catalog = self._context.catalog_handler.get(
187
190
  name=name_pattern, run_id=self._context.run_id, compute_data_folder=compute_data_folder
188
191
  )
192
+ self._context.progress.update(get_catalog_progress, completed=True, visible=False, refresh=True)
193
+
189
194
  elif stage == "put":
195
+ put_catalog_progress = self._context.progress.add_task(f"Putting in catalog {name_pattern}", total=1)
190
196
  data_catalog = self._context.catalog_handler.put(
191
197
  name=name_pattern,
192
198
  run_id=self._context.run_id,
@@ -194,7 +200,9 @@ class GenericExecutor(BaseExecutor):
194
200
  synced_catalogs=synced_catalogs,
195
201
  )
196
202
 
197
- logger.info(f"Added data catalog: {data_catalog} to step log")
203
+ self._context.progress.update(put_catalog_progress, completed=True, visible=False)
204
+
205
+ logger.debug(f"Added data catalog: {data_catalog} to step log")
198
206
  data_catalogs.extend(data_catalog)
199
207
 
200
208
  return data_catalogs
@@ -256,6 +264,7 @@ class GenericExecutor(BaseExecutor):
256
264
  self._context_node = node
257
265
 
258
266
  data_catalogs_get: Optional[List[DataCatalog]] = self._sync_catalog(stage="get")
267
+ logger.debug(f"data_catalogs_get: {data_catalogs_get}")
259
268
 
260
269
  step_log = node.execute(
261
270
  map_variable=map_variable,
@@ -263,11 +272,16 @@ class GenericExecutor(BaseExecutor):
263
272
  mock=mock,
264
273
  **kwargs,
265
274
  )
275
+
266
276
  data_catalogs_put: Optional[List[DataCatalog]] = self._sync_catalog(stage="put")
277
+ logger.debug(f"data_catalogs_put: {data_catalogs_put}")
267
278
 
268
279
  step_log.add_data_catalogs(data_catalogs_get or [])
269
280
  step_log.add_data_catalogs(data_catalogs_put or [])
270
281
 
282
+ console.print(f"Summary of the step: {step_log.internal_name}")
283
+ console.print(step_log.get_summary(), style=defaults.info_style)
284
+
271
285
  self._context_node = None # type: ignore
272
286
 
273
287
  self._context.run_log_store.add_step_log(step_log, self._context.run_id)
@@ -318,6 +332,8 @@ class GenericExecutor(BaseExecutor):
318
332
 
319
333
  self._context.run_log_store.add_step_log(step_log, self._context.run_id)
320
334
 
335
+ logger.info(f"Executing node: {node.get_summary()}")
336
+
321
337
  # Add the step log to the database as per the situation.
322
338
  # If its a terminal node, complete it now
323
339
  if node.node_type in ["success", "fail"]:
@@ -329,7 +345,8 @@ class GenericExecutor(BaseExecutor):
329
345
  node.execute_as_graph(map_variable=map_variable, **kwargs)
330
346
  return
331
347
 
332
- # Executor specific way to trigger a job
348
+ task_name = node._resolve_map_placeholders(node.internal_name, map_variable)
349
+ console.print(f":runner: Executing the node {task_name} ... ", style="bold color(208)")
333
350
  self.trigger_job(node=node, map_variable=map_variable, **kwargs)
334
351
 
335
352
  def trigger_job(self, node: BaseNode, map_variable: TypeMapVariable = None, **kwargs):
@@ -405,31 +422,72 @@ class GenericExecutor(BaseExecutor):
405
422
  previous_node = None
406
423
  logger.info(f"Running the execution with {current_node}")
407
424
 
425
+ branch_execution_task = None
426
+ branch_task_name: str = ""
427
+ if dag.internal_branch_name:
428
+ branch_task_name = BaseNode._resolve_map_placeholders(
429
+ dag.internal_branch_name or "Graph",
430
+ map_variable,
431
+ )
432
+ branch_execution_task = self._context.progress.add_task(
433
+ f"[dark_orange]Executing {branch_task_name}",
434
+ total=1,
435
+ )
436
+
408
437
  while True:
409
438
  working_on = dag.get_node_by_name(current_node)
439
+ task_name = working_on._resolve_map_placeholders(working_on.internal_name, map_variable)
410
440
 
411
441
  if previous_node == current_node:
412
442
  raise Exception("Potentially running in a infinite loop")
413
443
 
414
444
  previous_node = current_node
415
445
 
416
- logger.info(f"Creating execution log for {working_on}")
417
- self.execute_from_graph(working_on, map_variable=map_variable, **kwargs)
446
+ logger.debug(f"Creating execution log for {working_on}")
418
447
 
419
- status, next_node_name = self._get_status_and_next_node_name(
420
- current_node=working_on, dag=dag, map_variable=map_variable
421
- )
448
+ depth = " " * ((task_name.count(".")) or 1 - 1)
422
449
 
423
- if status == defaults.TRIGGERED:
424
- # Some nodes go into triggered state and self traverse
425
- logger.info(f"Triggered the job to execute the node {current_node}")
426
- break
450
+ task_execution = self._context.progress.add_task(f"{depth}Executing {task_name}", total=1)
451
+
452
+ try:
453
+ self.execute_from_graph(working_on, map_variable=map_variable, **kwargs)
454
+ status, next_node_name = self._get_status_and_next_node_name(
455
+ current_node=working_on, dag=dag, map_variable=map_variable
456
+ )
457
+
458
+ if status == defaults.SUCCESS:
459
+ self._context.progress.update(
460
+ task_execution,
461
+ description=f"{depth}[green] {task_name} Completed",
462
+ completed=True,
463
+ overflow="fold",
464
+ )
465
+ else:
466
+ self._context.progress.update(
467
+ task_execution, description=f"{depth}[red] {task_name} Failed", completed=True
468
+ ) # type ignore
469
+ except Exception as e: # noqa: E722
470
+ self._context.progress.update(
471
+ task_execution,
472
+ description=f"{depth}[red] {task_name} Errored",
473
+ completed=True,
474
+ )
475
+ console.print(e, style=defaults.error_style)
476
+ logger.exception(e)
477
+ raise
478
+
479
+ console.rule(style="[dark orange]")
427
480
 
428
481
  if working_on.node_type in ["success", "fail"]:
429
482
  break
430
483
 
431
484
  current_node = next_node_name
432
485
 
486
+ if branch_execution_task:
487
+ self._context.progress.update(
488
+ branch_execution_task, description=f"[green3] {branch_task_name} completed", completed=True
489
+ )
490
+
433
491
  run_log = self._context.run_log_store.get_branch_log(
434
492
  working_on._get_branch_log_name(map_variable), self._context.run_id
435
493
  )
@@ -440,10 +498,10 @@ class GenericExecutor(BaseExecutor):
440
498
 
441
499
  logger.info(f"Finished execution of the {branch} with status {run_log.status}")
442
500
 
443
- # get the final run log
444
- if branch == "graph":
445
- run_log = self._context.run_log_store.get_run_log_by_id(run_id=self._context.run_id, full=True)
446
- print(json.dumps(run_log.model_dump(), indent=4))
501
+ if dag == self._context.dag:
502
+ run_log = cast(RunLog, run_log)
503
+ console.print("Completed Execution, Summary:", style="bold color(208)")
504
+ console.print(run_log.get_summary(), style=defaults.info_style)
447
505
 
448
506
  def send_return_code(self, stage="traversal"):
449
507
  """
@@ -7,7 +7,14 @@ from abc import ABC, abstractmethod
7
7
  from collections import OrderedDict
8
8
  from typing import Any, Dict, List, Optional, Union, cast
9
9
 
10
- from pydantic import BaseModel, ConfigDict, Field, computed_field, field_serializer, field_validator
10
+ from pydantic import (
11
+ BaseModel,
12
+ ConfigDict,
13
+ Field,
14
+ computed_field,
15
+ field_serializer,
16
+ field_validator,
17
+ )
11
18
  from pydantic.functional_serializers import PlainSerializer
12
19
  from ruamel.yaml import YAML
13
20
  from typing_extensions import Annotated
@@ -773,9 +780,6 @@ class ArgoExecutor(GenericExecutor):
773
780
  integration.validate(self, self._context.secrets_handler)
774
781
  integration.configure_for_traversal(self, self._context.secrets_handler)
775
782
 
776
- integration.validate(self, self._context.experiment_tracker)
777
- integration.configure_for_traversal(self, self._context.experiment_tracker)
778
-
779
783
  def prepare_for_node_execution(self):
780
784
  """
781
785
  Perform any modifications to the services prior to execution of the node.
@@ -35,6 +35,7 @@ class LocalExecutor(GenericExecutor):
35
35
  node (BaseNode): [description]
36
36
  map_variable (str, optional): [description]. Defaults to ''.
37
37
  """
38
+
38
39
  self.prepare_for_node_execution()
39
40
  self.execute_node(node=node, map_variable=map_variable, **kwargs)
40
41
 
@@ -5,13 +5,19 @@ import sys
5
5
  from collections import OrderedDict
6
6
  from copy import deepcopy
7
7
  from datetime import datetime
8
- from typing import Any, Dict, Optional, cast
9
-
10
- from pydantic import ConfigDict, Field, ValidationInfo, field_serializer, field_validator
8
+ from typing import Any, Dict, List, Optional, Tuple, Union, cast
9
+
10
+ from pydantic import (
11
+ ConfigDict,
12
+ Field,
13
+ ValidationInfo,
14
+ field_serializer,
15
+ field_validator,
16
+ )
11
17
  from typing_extensions import Annotated
12
18
 
13
19
  from runnable import datastore, defaults, utils
14
- from runnable.datastore import JsonParameter, ObjectParameter, StepLog
20
+ from runnable.datastore import JsonParameter, MetricParameter, ObjectParameter, StepLog
15
21
  from runnable.defaults import TypeMapVariable
16
22
  from runnable.graph import Graph, create_graph
17
23
  from runnable.nodes import CompositeNode, ExecutableNode, TerminalNode
@@ -45,6 +51,16 @@ class TaskNode(ExecutableNode):
45
51
  executable = create_task(task_config)
46
52
  return cls(executable=executable, **node_config, **task_config)
47
53
 
54
+ def get_summary(self) -> Dict[str, Any]:
55
+ summary = {
56
+ "name": self.name,
57
+ "type": self.node_type,
58
+ "executable": self.executable.get_summary(),
59
+ "catalog": self._get_catalog_settings(),
60
+ }
61
+
62
+ return summary
63
+
48
64
  def execute(
49
65
  self,
50
66
  mock=False,
@@ -63,9 +79,8 @@ class TaskNode(ExecutableNode):
63
79
  Returns:
64
80
  StepAttempt: The attempt object
65
81
  """
66
- print("Executing task:", self._context.executor._context_node)
67
-
68
82
  step_log = self._context.run_log_store.get_step_log(self._get_step_log_name(map_variable), self._context.run_id)
83
+
69
84
  if not mock:
70
85
  # Do not run if we are mocking the execution, could be useful for caching and dry runs
71
86
  attempt_log = self.executable.execute_command(map_variable=map_variable)
@@ -78,6 +93,9 @@ class TaskNode(ExecutableNode):
78
93
  attempt_number=attempt_number,
79
94
  )
80
95
 
96
+ logger.debug(f"attempt_log: {attempt_log}")
97
+ logger.info(f"Step {self.name} completed with status: {attempt_log.status}")
98
+
81
99
  step_log.status = attempt_log.status
82
100
 
83
101
  step_log.attempts.append(attempt_log)
@@ -96,6 +114,14 @@ class FailNode(TerminalNode):
96
114
  def parse_from_config(cls, config: Dict[str, Any]) -> "FailNode":
97
115
  return cast("FailNode", super().parse_from_config(config))
98
116
 
117
+ def get_summary(self) -> Dict[str, Any]:
118
+ summary = {
119
+ "name": self.name,
120
+ "type": self.node_type,
121
+ }
122
+
123
+ return summary
124
+
99
125
  def execute(
100
126
  self,
101
127
  mock=False,
@@ -118,7 +144,7 @@ class FailNode(TerminalNode):
118
144
  step_log = self._context.run_log_store.get_step_log(self._get_step_log_name(map_variable), self._context.run_id)
119
145
 
120
146
  attempt_log = datastore.StepAttempt(
121
- status=defaults.FAIL,
147
+ status=defaults.SUCCESS,
122
148
  start_time=str(datetime.now()),
123
149
  end_time=str(datetime.now()),
124
150
  attempt_number=attempt_number,
@@ -148,6 +174,14 @@ class SuccessNode(TerminalNode):
148
174
  def parse_from_config(cls, config: Dict[str, Any]) -> "SuccessNode":
149
175
  return cast("SuccessNode", super().parse_from_config(config))
150
176
 
177
+ def get_summary(self) -> Dict[str, Any]:
178
+ summary = {
179
+ "name": self.name,
180
+ "type": self.node_type,
181
+ }
182
+
183
+ return summary
184
+
151
185
  def execute(
152
186
  self,
153
187
  mock=False,
@@ -207,6 +241,15 @@ class ParallelNode(CompositeNode):
207
241
  branches: Dict[str, Graph]
208
242
  is_composite: bool = Field(default=True, exclude=True)
209
243
 
244
+ def get_summary(self) -> Dict[str, Any]:
245
+ summary = {
246
+ "name": self.name,
247
+ "type": self.node_type,
248
+ "branches": [branch.get_summary() for branch in self.branches.values()],
249
+ }
250
+
251
+ return summary
252
+
210
253
  @field_serializer("branches")
211
254
  def ser_branches(self, branches: Dict[str, Graph]) -> Dict[str, Graph]:
212
255
  ret: Dict[str, Graph] = {}
@@ -296,6 +339,7 @@ class ParallelNode(CompositeNode):
296
339
  executor (BaseExecutor): The executor class as defined by the config
297
340
  map_variable (dict, optional): If the node is part of a map. Defaults to None.
298
341
  """
342
+ effective_internal_name = self._resolve_map_placeholders(self.internal_name, map_variable=map_variable)
299
343
  step_success_bool = True
300
344
  for internal_branch_name, _ in self.branches.items():
301
345
  effective_branch_name = self._resolve_map_placeholders(internal_branch_name, map_variable=map_variable)
@@ -304,7 +348,7 @@ class ParallelNode(CompositeNode):
304
348
  step_success_bool = False
305
349
 
306
350
  # Collate all the results and update the status of the step
307
- effective_internal_name = self._resolve_map_placeholders(self.internal_name, map_variable=map_variable)
351
+
308
352
  step_log = self._context.run_log_store.get_step_log(effective_internal_name, self._context.run_id)
309
353
 
310
354
  if step_success_bool: #  If none failed
@@ -335,10 +379,24 @@ class MapNode(CompositeNode):
335
379
  node_type: str = Field(default="map", serialization_alias="type")
336
380
  iterate_on: str
337
381
  iterate_as: str
382
+ iterate_index: bool = Field(default=False) # TODO: Need to design this
338
383
  reducer: Optional[str] = Field(default=None)
339
384
  branch: Graph
340
385
  is_composite: bool = True
341
386
 
387
+ def get_summary(self) -> Dict[str, Any]:
388
+ summary = {
389
+ "name": self.name,
390
+ "type": self.node_type,
391
+ "branch": self.branch.get_summary(),
392
+ "iterate_on": self.iterate_on,
393
+ "iterate_as": self.iterate_as,
394
+ "iterate_index": self.iterate_index,
395
+ "reducer": self.reducer,
396
+ }
397
+
398
+ return summary
399
+
342
400
  def get_reducer_function(self):
343
401
  if not self.reducer:
344
402
  return lambda *x: list(x) # returns a list of the args
@@ -375,12 +433,12 @@ class MapNode(CompositeNode):
375
433
 
376
434
  @property
377
435
  def branch_returns(self):
378
- branch_returns = []
436
+ branch_returns: List[Tuple[str, Union[ObjectParameter, MetricParameter, JsonParameter]]] = []
379
437
  for _, node in self.branch.nodes.items():
380
438
  if isinstance(node, TaskNode):
381
439
  for task_return in node.executable.returns:
382
440
  if task_return.kind == "json":
383
- branch_returns.append((task_return.name, JsonParameter(kind="json", value=None, reduced=False)))
441
+ branch_returns.append((task_return.name, JsonParameter(kind="json", value="", reduced=False)))
384
442
  elif task_return.kind == "object":
385
443
  branch_returns.append(
386
444
  (
@@ -390,7 +448,11 @@ class MapNode(CompositeNode):
390
448
  value="Will be reduced",
391
449
  reduced=False,
392
450
  ),
393
- ) # type: ignore
451
+ )
452
+ )
453
+ elif task_return.kind == "metric":
454
+ branch_returns.append(
455
+ (task_return.name, MetricParameter(kind="metric", value="", reduced=False))
394
456
  )
395
457
  else:
396
458
  raise Exception("kind should be either json or object")
@@ -513,6 +575,7 @@ class MapNode(CompositeNode):
513
575
  iterate_on = params[self.iterate_on].get_value()
514
576
  # # Find status of the branches
515
577
  step_success_bool = True
578
+ effective_internal_name = self._resolve_map_placeholders(self.internal_name, map_variable=map_variable)
516
579
 
517
580
  for iter_variable in iterate_on:
518
581
  effective_branch_name = self._resolve_map_placeholders(
@@ -523,7 +586,6 @@ class MapNode(CompositeNode):
523
586
  step_success_bool = False
524
587
 
525
588
  # Collate all the results and update the status of the step
526
- effective_internal_name = self._resolve_map_placeholders(self.internal_name, map_variable=map_variable)
527
589
  step_log = self._context.run_log_store.get_step_log(effective_internal_name, self._context.run_id)
528
590
 
529
591
  if step_success_bool: #  If none failed and nothing is waiting
@@ -580,6 +642,13 @@ class DagNode(CompositeNode):
580
642
  is_composite: bool = True
581
643
  internal_branch_name: Annotated[str, Field(validate_default=True)] = ""
582
644
 
645
+ def get_summary(self) -> Dict[str, Any]:
646
+ summary = {
647
+ "name": self.name,
648
+ "type": self.node_type,
649
+ }
650
+ return summary
651
+
583
652
  @field_validator("internal_branch_name")
584
653
  @classmethod
585
654
  def validate_internal_branch_name(cls, internal_branch_name: str, info: ValidationInfo):
@@ -711,7 +780,15 @@ class StubNode(ExecutableNode):
711
780
  """
712
781
 
713
782
  node_type: str = Field(default="stub", serialization_alias="type")
714
- model_config = ConfigDict(extra="allow")
783
+ model_config = ConfigDict(extra="ignore")
784
+
785
+ def get_summary(self) -> Dict[str, Any]:
786
+ summary = {
787
+ "name": self.name,
788
+ "type": self.node_type,
789
+ }
790
+
791
+ return summary
715
792
 
716
793
  @classmethod
717
794
  def parse_from_config(cls, config: Dict[str, Any]) -> "StubNode":
@@ -2,7 +2,7 @@ import json
2
2
  import logging
3
3
  from pathlib import Path
4
4
  from string import Template
5
- from typing import Optional, Sequence, Union
5
+ from typing import Any, Dict, Optional, Sequence, Union
6
6
 
7
7
  from runnable import defaults, utils
8
8
  from runnable.extensions.run_log_store.generic_chunked import ChunkedRunLogStore
@@ -21,6 +21,11 @@ class ChunkedFileSystemRunLogStore(ChunkedRunLogStore):
21
21
  service_name: str = "chunked-fs"
22
22
  log_folder: str = defaults.LOG_LOCATION_FOLDER
23
23
 
24
+ def get_summary(self) -> Dict[str, Any]:
25
+ summary = {"Type": self.service_name, "Location": self.log_folder}
26
+
27
+ return summary
28
+
24
29
  def get_matches(self, run_id: str, name: str, multiple_allowed: bool = False) -> Optional[Union[Sequence[T], T]]:
25
30
  """
26
31
  Get contents of files matching the pattern name*
@@ -1,6 +1,7 @@
1
1
  import json
2
2
  import logging
3
3
  from pathlib import Path
4
+ from typing import Any, Dict
4
5
 
5
6
  from runnable import defaults, exceptions, utils
6
7
  from runnable.datastore import BaseRunLogStore, RunLog
@@ -37,6 +38,11 @@ class FileSystemRunLogstore(BaseRunLogStore):
37
38
  def log_folder_name(self):
38
39
  return self.log_folder
39
40
 
41
+ def get_summary(self) -> Dict[str, Any]:
42
+ summary = {"Type": self.service_name, "Location": self.log_folder}
43
+
44
+ return summary
45
+
40
46
  def write_to_folder(self, run_log: RunLog):
41
47
  """
42
48
  Write the run log to the folder
runnable/graph.py CHANGED
@@ -26,6 +26,17 @@ class Graph(BaseModel):
26
26
  internal_branch_name: str = Field(default="", exclude=True)
27
27
  nodes: SerializeAsAny[Dict[str, "BaseNode"]] = Field(default_factory=dict, serialization_alias="steps")
28
28
 
29
+ def get_summary(self) -> Dict[str, Any]:
30
+ """
31
+ Return a summary of the graph
32
+ """
33
+ return {
34
+ "name": self.name,
35
+ "description": self.description,
36
+ "start_at": self.start_at,
37
+ "nodes": [node.get_summary() for node in list(self.nodes.values())],
38
+ }
39
+
29
40
  def get_node_by_name(self, name: str) -> "BaseNode":
30
41
  """
31
42
  Return the Node object by the name
runnable/integration.py CHANGED
@@ -102,7 +102,7 @@ def get_integration_handler(executor: "BaseExecutor", service: object) -> BaseIn
102
102
  raise Exception(msg)
103
103
 
104
104
  if not integrations:
105
- logger.warning(
105
+ logger.info(
106
106
  f"Could not find an integration pattern for {executor.service_name} and {service_name} for {service_type}."
107
107
  " This implies that there is no need to change the configurations."
108
108
  )
@@ -163,7 +163,7 @@ class BufferedRunLogStore(BaseIntegration):
163
163
  "Run log generated by buffered run log store are not persisted. "
164
164
  "Re-running this run, in case of a failure, is not possible"
165
165
  )
166
- logger.warning(msg)
166
+ logger.info(msg)
167
167
 
168
168
 
169
169
  class DoNothingCatalog(BaseIntegration):
@@ -176,7 +176,7 @@ class DoNothingCatalog(BaseIntegration):
176
176
 
177
177
  def validate(self, **kwargs):
178
178
  msg = "A do-nothing catalog does not hold any data and therefore cannot pass data between nodes."
179
- logger.warning(msg)
179
+ logger.info(msg)
180
180
 
181
181
 
182
182
  class DoNothingSecrets(BaseIntegration):
@@ -189,17 +189,4 @@ class DoNothingSecrets(BaseIntegration):
189
189
 
190
190
  def validate(self, **kwargs):
191
191
  msg = "A do-nothing secrets does not hold any secrets and therefore cannot return you any secrets."
192
- logger.warning(msg)
193
-
194
-
195
- class DoNothingExperimentTracker(BaseIntegration):
196
- """
197
- Integration between any executor and do nothing experiment tracker
198
- """
199
-
200
- service_type = "experiment_tracker" # One of secret, catalog, datastore
201
- service_provider = "do-nothing" # The actual implementation of the service
202
-
203
- def validate(self, **kwargs):
204
- msg = "A do-nothing experiment tracker does nothing and therefore cannot track anything."
205
- logger.warning(msg)
192
+ logger.info(msg)
runnable/nodes.py CHANGED
@@ -362,6 +362,15 @@ class BaseNode(ABC, BaseModel):
362
362
  """
363
363
  ...
364
364
 
365
+ @abstractmethod
366
+ def get_summary(self) -> Dict[str, Any]:
367
+ """
368
+ Return the summary of the node
369
+
370
+ Returns:
371
+ Dict[str, Any]: _description_
372
+ """
373
+
365
374
 
366
375
  # --8<-- [end:docs]
367
376
  class TraversalNode(BaseNode):
runnable/parameters.py CHANGED
@@ -4,6 +4,7 @@ import logging
4
4
  import os
5
5
  from typing import Any, Dict, Type
6
6
 
7
+ import pydantic
7
8
  from pydantic import BaseModel, ConfigDict
8
9
  from typing_extensions import Callable
9
10
 
@@ -99,7 +100,7 @@ def filter_arguments_for_func(
99
100
  # default value is given in the function signature, nothing further to do.
100
101
  continue
101
102
 
102
- if issubclass(value.annotation, BaseModel):
103
+ if type(value.annotation) in [BaseModel, pydantic._internal._model_construction.ModelMetaclass]:
103
104
  # We try to cast it as a pydantic model if asked
104
105
  named_param = params[name].get_value()
105
106
 
@@ -110,6 +111,7 @@ def filter_arguments_for_func(
110
111
  bound_model = bind_args_for_pydantic_model(named_param, value.annotation)
111
112
  bound_args[name] = bound_model
112
113
  unassigned_params = unassigned_params.difference(bound_model.model_fields.keys())
114
+
113
115
  elif value.annotation in [str, int, float, bool]:
114
116
  # Cast it if its a primitive type. Ensure the type matches the annotation.
115
117
  bound_args[name] = value.annotation(params[name].get_value())