tracdap-runtime 0.6.4__py3-none-any.whl → 0.6.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. tracdap/rt/_exec/context.py +556 -36
  2. tracdap/rt/_exec/dev_mode.py +320 -198
  3. tracdap/rt/_exec/engine.py +331 -62
  4. tracdap/rt/_exec/functions.py +151 -22
  5. tracdap/rt/_exec/graph.py +47 -13
  6. tracdap/rt/_exec/graph_builder.py +383 -175
  7. tracdap/rt/_exec/runtime.py +7 -5
  8. tracdap/rt/_impl/config_parser.py +11 -4
  9. tracdap/rt/_impl/data.py +329 -152
  10. tracdap/rt/_impl/ext/__init__.py +13 -0
  11. tracdap/rt/_impl/ext/sql.py +116 -0
  12. tracdap/rt/_impl/ext/storage.py +57 -0
  13. tracdap/rt/_impl/grpc/tracdap/metadata/job_pb2.py +82 -30
  14. tracdap/rt/_impl/grpc/tracdap/metadata/job_pb2.pyi +155 -2
  15. tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.py +12 -10
  16. tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.pyi +14 -2
  17. tracdap/rt/_impl/grpc/tracdap/metadata/resource_pb2.py +29 -0
  18. tracdap/rt/_impl/grpc/tracdap/metadata/resource_pb2.pyi +16 -0
  19. tracdap/rt/_impl/models.py +8 -0
  20. tracdap/rt/_impl/static_api.py +29 -0
  21. tracdap/rt/_impl/storage.py +39 -27
  22. tracdap/rt/_impl/util.py +10 -0
  23. tracdap/rt/_impl/validation.py +140 -18
  24. tracdap/rt/_plugins/repo_git.py +1 -1
  25. tracdap/rt/_plugins/storage_sql.py +417 -0
  26. tracdap/rt/_plugins/storage_sql_dialects.py +117 -0
  27. tracdap/rt/_version.py +1 -1
  28. tracdap/rt/api/experimental.py +267 -0
  29. tracdap/rt/api/hook.py +14 -0
  30. tracdap/rt/api/model_api.py +48 -6
  31. tracdap/rt/config/__init__.py +2 -2
  32. tracdap/rt/config/common.py +6 -0
  33. tracdap/rt/metadata/__init__.py +29 -20
  34. tracdap/rt/metadata/job.py +99 -0
  35. tracdap/rt/metadata/model.py +18 -0
  36. tracdap/rt/metadata/resource.py +24 -0
  37. {tracdap_runtime-0.6.4.dist-info → tracdap_runtime-0.6.6.dist-info}/METADATA +5 -1
  38. {tracdap_runtime-0.6.4.dist-info → tracdap_runtime-0.6.6.dist-info}/RECORD +41 -32
  39. {tracdap_runtime-0.6.4.dist-info → tracdap_runtime-0.6.6.dist-info}/WHEEL +1 -1
  40. {tracdap_runtime-0.6.4.dist-info → tracdap_runtime-0.6.6.dist-info}/LICENSE +0 -0
  41. {tracdap_runtime-0.6.4.dist-info → tracdap_runtime-0.6.6.dist-info}/top_level.txt +0 -0
@@ -12,8 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from __future__ import annotations
16
-
17
15
  import copy as cp
18
16
  import dataclasses as dc
19
17
  import enum
@@ -41,8 +39,9 @@ class _EngineNode:
41
39
  """
42
40
 
43
41
  node: _graph.Node
44
- dependencies: tp.Dict[NodeId, _graph.DependencyType]
45
42
  function: tp.Optional[_func.NodeFunction] = None
43
+
44
+ dependencies: tp.Dict[NodeId, _graph.DependencyType] = dc.field(default_factory=dict)
46
45
  complete: bool = False
47
46
  result: tp.Optional[tp.Any] = None
48
47
  error: tp.Optional[str] = None
@@ -59,21 +58,35 @@ class _EngineContext:
59
58
  Represents the state of an execution graph being processed by the TRAC engine
60
59
  """
61
60
 
61
+ engine_id: _actors.ActorId
62
+ job_key: str
63
+ root_id: NodeId
64
+
62
65
  nodes: tp.Dict[NodeId, _EngineNode]
63
66
  pending_nodes: tp.Set[NodeId] = dc.field(default_factory=set)
64
67
  active_nodes: tp.Set[NodeId] = dc.field(default_factory=set)
65
68
  succeeded_nodes: tp.Set[NodeId] = dc.field(default_factory=set)
66
69
  failed_nodes: tp.Set[NodeId] = dc.field(default_factory=set)
67
70
 
71
+ def with_updates(
72
+ self, nodes,
73
+ pending_nodes, active_nodes,
74
+ succeeded_nodes, failed_nodes) -> "_EngineContext":
75
+
76
+ return _EngineContext(
77
+ self.engine_id, self.job_key, self.root_id, nodes,
78
+ pending_nodes, active_nodes, succeeded_nodes, failed_nodes)
79
+
68
80
 
69
81
  @dc.dataclass
70
82
  class _JobState:
71
83
 
72
84
  job_id: _meta.TagHeader
73
- job_config: _cfg.JobConfig
74
-
75
85
  actor_id: _actors.ActorId = None
76
86
 
87
+ monitors: tp.List[_actors.ActorId] = dc.field(default_factory=list)
88
+
89
+ job_config: _cfg.JobConfig = None
77
90
  job_result: _cfg.JobResult = None
78
91
  job_error: Exception = None
79
92
 
@@ -156,14 +169,35 @@ class TracEngine(_actors.Actor):
156
169
 
157
170
  self._log.info(f"Job submitted: [{job_key}]")
158
171
 
159
- job_processor = JobProcessor(job_key, job_config, result_spec,self._models, self._storage)
172
+ job_processor = JobProcessor(self._models, self._storage, job_key, job_config, result_spec, graph_spec=None)
160
173
  job_actor_id = self.actors().spawn(job_processor)
161
174
 
162
- job_state = _JobState(job_config.jobId, job_config)
175
+ job_monitor_success = lambda ctx, key, result: self._notify_callback(key, result, None)
176
+ job_monitor_failure = lambda ctx, key, error: self._notify_callback(key, None, error)
177
+ job_monitor = JobMonitor(job_key, job_monitor_success, job_monitor_failure)
178
+ job_monitor_id = self.actors().spawn(job_monitor)
179
+
180
+ job_state = _JobState(job_config.jobId)
163
181
  job_state.actor_id = job_actor_id
182
+ job_state.monitors.append(job_monitor_id)
183
+ job_state.job_config = job_config
164
184
 
165
185
  self._jobs[job_key] = job_state
166
186
 
187
+ @_actors.Message
188
+ def submit_child_job(self, child_id: _meta.TagHeader, child_graph: _graph.Graph, monitor_id: _actors.ActorId):
189
+
190
+ child_key = _util.object_key(child_id)
191
+
192
+ child_processor = JobProcessor(self._models, self._storage, child_key, None, None, graph_spec=child_graph) # noqa
193
+ child_actor_id = self.actors().spawn(child_processor)
194
+
195
+ child_state = _JobState(child_id)
196
+ child_state.actor_id = child_actor_id
197
+ child_state.monitors.append(monitor_id)
198
+
199
+ self._jobs[child_key] = child_state
200
+
167
201
  @_actors.Message
168
202
  def get_job_list(self):
169
203
 
@@ -186,11 +220,13 @@ class TracEngine(_actors.Actor):
186
220
 
187
221
  self._log.info(f"Recording job as successful: {job_key}")
188
222
 
189
- self._jobs[job_key].job_result = job_result
190
- self._finalize_job(job_key)
223
+ job_state = self._jobs[job_key]
224
+ job_state.job_result = job_result
191
225
 
192
- if self._notify_callback is not None:
193
- self._notify_callback(job_key, job_result, None)
226
+ for monitor_id in job_state.monitors:
227
+ self.actors().send(monitor_id, "job_succeeded", job_result)
228
+
229
+ self._finalize_job(job_key)
194
230
 
195
231
  @_actors.Message
196
232
  def job_failed(self, job_key: str, error: Exception):
@@ -202,11 +238,13 @@ class TracEngine(_actors.Actor):
202
238
 
203
239
  self._log.error(f"Recording job as failed: {job_key}")
204
240
 
205
- self._jobs[job_key].job_error = error
206
- self._finalize_job(job_key)
241
+ job_state = self._jobs[job_key]
242
+ job_state.job_error = error
207
243
 
208
- if self._notify_callback is not None:
209
- self._notify_callback(job_key, None, error)
244
+ for monitor_id in job_state.monitors:
245
+ self.actors().send(monitor_id, "job_failed", error)
246
+
247
+ self._finalize_job(job_key)
210
248
 
211
249
  def _finalize_job(self, job_key: str):
212
250
 
@@ -216,10 +254,17 @@ class TracEngine(_actors.Actor):
216
254
  # For now each instance of the runtime only processes one job so no need to worry
217
255
 
218
256
  job_state = self._jobs.get(job_key)
219
- job_actor_id = job_state.actor_id if job_state is not None else None
220
257
 
221
- if job_actor_id is not None:
222
- self.actors().stop(job_actor_id)
258
+ # Stop any monitors that were created directly by the engine
259
+ # (Other actors are responsible for stopping their own monitors)
260
+ while job_state.monitors:
261
+ monitor_id = job_state.monitors.pop()
262
+ monitor_parent = monitor_id[:monitor_id.rfind('/')]
263
+ if self.actors().id == monitor_parent:
264
+ self.actors().stop(monitor_id)
265
+
266
+ if job_state.actor_id is not None:
267
+ self.actors().stop(job_state.actor_id )
223
268
  job_state.actor_id = None
224
269
 
225
270
  def _get_job_info(self, job_key: str, details: bool = False) -> tp.Optional[_cfg.JobResult]:
@@ -253,6 +298,35 @@ class TracEngine(_actors.Actor):
253
298
  return job_result
254
299
 
255
300
 
301
+ class JobMonitor(_actors.Actor):
302
+
303
+ def __init__(
304
+ self, job_key: str,
305
+ success_func: tp.Callable[[_actors.ActorContext, str, _cfg.JobResult], None],
306
+ failure_func: tp.Callable[[_actors.ActorContext, str, Exception], None]):
307
+
308
+ super().__init__()
309
+ self._job_key = job_key
310
+ self._success_func = success_func
311
+ self._failure_func = failure_func
312
+ self._signal_sent = False
313
+
314
+ @_actors.Message
315
+ def job_succeeded(self, job_result: _cfg.JobResult):
316
+ self._success_func(self.actors(), self._job_key, job_result)
317
+ self._signal_sent = True
318
+
319
+ @_actors.Message
320
+ def job_failed(self, error: Exception):
321
+ self._failure_func(self.actors(), self._job_key, error)
322
+ self._signal_sent = True
323
+
324
+ def on_stop(self):
325
+ if not self._signal_sent:
326
+ error = _ex.ETracInternal(f"No result was received for job [{self._job_key}]")
327
+ self._failure_func(self.actors(), self._job_key, error)
328
+
329
+
256
330
  class JobProcessor(_actors.Actor):
257
331
 
258
332
  """
@@ -261,25 +335,32 @@ class JobProcessor(_actors.Actor):
261
335
  """
262
336
 
263
337
  def __init__(
264
- self, job_key, job_config: _cfg.JobConfig,
265
- result_spec: _graph.JobResultSpec,
266
- models: _models.ModelLoader,
267
- storage: _storage.StorageManager):
338
+ self, models: _models.ModelLoader, storage: _storage.StorageManager,
339
+ job_key: str, job_config: _cfg.JobConfig, result_spec: _graph.JobResultSpec,
340
+ graph_spec: tp.Optional[_graph.Graph]):
268
341
 
269
342
  super().__init__()
270
343
  self.job_key = job_key
271
344
  self.job_config = job_config
272
345
  self.result_spec = result_spec
346
+ self.graph_spec = graph_spec
273
347
  self._models = models
274
348
  self._storage = storage
349
+ self._resolver = _func.FunctionResolver(models, storage)
275
350
  self._log = _util.logger_for_object(self)
276
351
 
277
352
  def on_start(self):
353
+
278
354
  self._log.info(f"Starting job [{self.job_key}]")
279
355
  self._models.create_scope(self.job_key)
280
- self.actors().spawn(GraphBuilder(self.job_config, self.result_spec, self._models, self._storage))
356
+
357
+ if self.graph_spec is not None:
358
+ self.actors().send(self.actors().id, "build_graph_succeeded", self.graph_spec)
359
+ else:
360
+ self.actors().spawn(GraphBuilder(self.job_config, self.result_spec))
281
361
 
282
362
  def on_stop(self):
363
+
283
364
  self._log.info(f"Cleaning up job [{self.job_key}]")
284
365
  self._models.destroy_scope(self.job_key)
285
366
 
@@ -304,9 +385,26 @@ class JobProcessor(_actors.Actor):
304
385
  return super().on_signal(signal)
305
386
 
306
387
  @_actors.Message
307
- def job_graph(self, graph: _EngineContext, root_id: NodeId):
308
- self.actors().spawn(GraphProcessor(graph, root_id))
309
- self.actors().stop(self.actors().sender)
388
+ def build_graph_succeeded(self, graph_spec: _graph.Graph):
389
+
390
+ # Build a new engine context graph from the graph spec
391
+ engine_id = self.actors().parent
392
+ nodes = dict((node_id, _EngineNode(node)) for node_id, node in graph_spec.nodes.items())
393
+ graph = _EngineContext(engine_id, self.job_key, graph_spec.root_id, nodes)
394
+
395
+ # Add all the nodes as pending nodes to start
396
+ graph.pending_nodes.update(graph.nodes.keys())
397
+
398
+ self.actors().spawn(FunctionResolver(self._resolver, graph))
399
+ if self.actors().sender != self.actors().id and self.actors().sender != self.actors().parent:
400
+ self.actors().stop(self.actors().sender)
401
+
402
+ @_actors.Message
403
+ def resolve_functions_succeeded(self, graph: _EngineContext):
404
+
405
+ self.actors().spawn(GraphProcessor(graph, self._resolver))
406
+ if self.actors().sender != self.actors().id and self.actors().sender != self.actors().parent:
407
+ self.actors().stop(self.actors().sender)
310
408
 
311
409
  @_actors.Message
312
410
  def job_succeeded(self, job_result: _cfg.JobResult):
@@ -324,45 +422,54 @@ class JobProcessor(_actors.Actor):
324
422
  class GraphBuilder(_actors.Actor):
325
423
 
326
424
  """
327
- GraphBuilder is a worker (actors.Worker) responsible for building the execution graph for a job
328
- The logic for graph building is provided in graph_builder.py
425
+ GraphBuilder is a worker (actor) to wrap the GraphBuilder logic from graph_builder.py
329
426
  """
330
427
 
331
- def __init__(
332
- self, job_config: _cfg.JobConfig,
333
- result_spec: _graph.JobResultSpec,
334
- models: _models.ModelLoader,
335
- storage: _storage.StorageManager):
336
-
428
+ def __init__(self, job_config: _cfg.JobConfig, result_spec: _graph.JobResultSpec):
337
429
  super().__init__()
338
430
  self.job_config = job_config
339
431
  self.result_spec = result_spec
340
- self.graph: tp.Optional[_EngineContext] = None
341
-
342
- self._resolver = _func.FunctionResolver(models, storage)
343
432
  self._log = _util.logger_for_object(self)
344
433
 
345
434
  def on_start(self):
435
+ self.build_graph(self, self.job_config)
436
+
437
+ @_actors.Message
438
+ def build_graph(self, job_config: _cfg.JobConfig):
346
439
 
347
440
  self._log.info("Building execution graph")
348
441
 
349
442
  # TODO: Get sys config, or find a way to pass storage settings
350
- graph_data = _graph.GraphBuilder.build_job(self.job_config, self.result_spec)
351
- graph_nodes = {node_id: _EngineNode(node, {}) for node_id, node in graph_data.nodes.items()}
352
- graph = _EngineContext(graph_nodes, pending_nodes=set(graph_nodes.keys()))
443
+ graph_builder = _graph.GraphBuilder(job_config, self.result_spec)
444
+ graph_spec = graph_builder.build_job(job_config.job)
353
445
 
354
- self._log.info("Resolving graph nodes to executable code")
446
+ self.actors().reply("build_graph_succeeded", graph_spec)
355
447
 
356
- for node_id, node in graph.nodes.items():
357
- node.function = self._resolver.resolve_node(node.node)
358
448
 
449
+ class FunctionResolver(_actors.Actor):
450
+
451
+ """
452
+ FunctionResolver is a worker (actor) to wrap the FunctionResolver logic in functions.py
453
+ """
454
+
455
+ def __init__(self, resolver: _func.FunctionResolver, graph: _EngineContext):
456
+ super().__init__()
359
457
  self.graph = graph
360
- self.actors().send_parent("job_graph", self.graph, graph_data.root_id)
458
+ self._resolver = resolver
459
+ self._log = _util.logger_for_object(self)
460
+
461
+ def on_start(self):
462
+ self.resolve_functions(self, self.graph)
361
463
 
362
464
  @_actors.Message
363
- def get_execution_graph(self):
465
+ def resolve_functions(self, graph: _EngineContext):
466
+
467
+ self._log.info("Resolving graph nodes to executable code")
468
+
469
+ for node_id, node in graph.nodes.items():
470
+ node.function = self._resolver.resolve_node(node.node)
364
471
 
365
- self.actors().send(self.actors().sender, "job_graph", self.graph)
472
+ self.actors().reply("resolve_functions_succeeded", graph)
366
473
 
367
474
 
368
475
  class GraphProcessor(_actors.Actor):
@@ -378,11 +485,12 @@ class GraphProcessor(_actors.Actor):
378
485
  Once all running nodes are stopped, an error is reported to the parent
379
486
  """
380
487
 
381
- def __init__(self, graph: _EngineContext, root_id: NodeId):
488
+ def __init__(self, graph: _EngineContext, resolver: _func.FunctionResolver):
382
489
  super().__init__()
383
490
  self.graph = graph
384
- self.root_id = root_id
491
+ self.root_id_ = graph.root_id
385
492
  self.processors: tp.Dict[NodeId, _actors.ActorId] = dict()
493
+ self._resolver = resolver
386
494
  self._log = _util.logger_for_object(self)
387
495
 
388
496
  def on_start(self):
@@ -428,12 +536,14 @@ class GraphProcessor(_actors.Actor):
428
536
  # Model and data nodes map to different thread pools in the actors engine
429
537
  # There is scope for a much more sophisticated approach, with prioritized scheduling
430
538
 
431
- if isinstance(node.node, _graph.RunModelNode) or isinstance(node.node, _graph.ImportModelNode):
432
- processor = ModelNodeProcessor(processed_graph, node_id, node)
539
+ if isinstance(node.node, _graph.ChildJobNode):
540
+ processor = ChildJobNodeProcessor(processed_graph, node)
541
+ elif isinstance(node.node, _graph.RunModelNode) or isinstance(node.node, _graph.ImportModelNode):
542
+ processor = ModelNodeProcessor(processed_graph, node)
433
543
  elif isinstance(node.node, _graph.LoadDataNode) or isinstance(node.node, _graph.SaveDataNode):
434
- processor = DataNodeProcessor(processed_graph, node_id, node)
544
+ processor = DataNodeProcessor(processed_graph, node)
435
545
  else:
436
- processor = NodeProcessor(processed_graph, node_id, node)
546
+ processor = NodeProcessor(processed_graph, node)
437
547
 
438
548
  # New nodes can be launched with the updated graph
439
549
  # Anything that was pruned is not needed by the new node
@@ -463,6 +573,62 @@ class GraphProcessor(_actors.Actor):
463
573
  # Job may have completed due to error propagation
464
574
  self.check_job_status(do_submit=False)
465
575
 
576
+ @_actors.Message
577
+ def update_graph(
578
+ self, requestor_id: NodeId,
579
+ new_nodes: tp.Dict[NodeId, _graph.Node],
580
+ new_deps: tp.Dict[NodeId, tp.List[_graph.Dependency]]):
581
+
582
+ new_graph = cp.copy(self.graph)
583
+ new_graph.nodes = cp.copy(new_graph.nodes)
584
+
585
+ # Attempt to insert a duplicate node is always an error
586
+ node_collision = list(filter(lambda nid: nid in self.graph.nodes, new_nodes))
587
+
588
+ # Only allow adding deps to pending nodes for now (adding deps to active nodes will require more work)
589
+ dep_collision = list(filter(lambda nid: nid not in self.graph.pending_nodes, new_deps))
590
+
591
+ dep_invalid = list(filter(
592
+ lambda dds: any(filter(lambda dd: dd.node_id not in new_nodes, dds)),
593
+ new_deps.values()))
594
+
595
+ if any(node_collision) or any(dep_collision) or any(dep_invalid):
596
+
597
+ self._log.error(f"Node collision during graph update (requested by {requestor_id})")
598
+ self._log.error(f"Duplicate node IDs: {node_collision or 'None'}")
599
+ self._log.error(f"Dependency updates for dead nodes: {dep_collision or 'None'}")
600
+ self._log.error(f"Dependencies added for existing nodes: {dep_invalid or 'None'}")
601
+
602
+ # Set an error on the node, and wait for it to complete normally
603
+ # The error will be picked up when the result is recorded
604
+ # If dependencies are added for an active node, more signalling will be needed
605
+ requestor = cp.copy(new_graph.nodes[requestor_id])
606
+ requestor.error = _ex.ETracInternal("Node collision during graph update")
607
+ new_graph.nodes[requestor_id] = requestor
608
+
609
+ return
610
+
611
+ new_graph.pending_nodes = cp.copy(new_graph.pending_nodes)
612
+
613
+ for node_id, node in new_nodes.items():
614
+ GraphLogger.log_node_add(node)
615
+ node_func = self._resolver.resolve_node(node)
616
+ new_node = _EngineNode(node, node_func)
617
+ new_graph.nodes[node_id] = new_node
618
+ new_graph.pending_nodes.add(node_id)
619
+
620
+ for node_id, deps in new_deps.items():
621
+ engine_node = cp.copy(new_graph.nodes[node_id])
622
+ engine_node.dependencies = cp.copy(engine_node.dependencies)
623
+ for dep in deps:
624
+ GraphLogger.log_dependency_add(node_id, dep.node_id)
625
+ engine_node.dependencies[dep.node_id] = dep.dependency_type
626
+ new_graph.nodes[node_id] = engine_node
627
+
628
+ self.graph = new_graph
629
+
630
+ self.actors().send(self.actors().id, "submit_viable_nodes")
631
+
466
632
  @classmethod
467
633
  def _is_required_node(cls, node: _EngineNode, graph: _EngineContext):
468
634
 
@@ -570,9 +736,10 @@ class GraphProcessor(_actors.Actor):
570
736
  for node_id in list(filter(lambda n: n.namespace == context_pop, nodes)):
571
737
  nodes.pop(node_id)
572
738
 
573
- graph = _EngineContext(nodes, pending_nodes, active_nodes, succeeded_nodes, failed_nodes)
739
+ self.graph = self.graph.with_updates(
740
+ nodes, pending_nodes, active_nodes,
741
+ succeeded_nodes, failed_nodes)
574
742
 
575
- self.graph = graph
576
743
  self.check_job_status()
577
744
 
578
745
  def check_job_status(self, do_submit=True):
@@ -602,7 +769,7 @@ class GraphProcessor(_actors.Actor):
602
769
  self.actors().send_parent("job_failed", _ex.EModelExec("Job suffered multiple errors", errors))
603
770
 
604
771
  else:
605
- job_result = self.graph.nodes[self.root_id].result
772
+ job_result = self.graph.nodes[self.graph.root_id].result
606
773
  self.actors().send_parent("job_succeeded", job_result)
607
774
 
608
775
 
@@ -614,11 +781,12 @@ class NodeProcessor(_actors.Actor):
614
781
 
615
782
  __NONE_TYPE = type(None)
616
783
 
617
- def __init__(self, graph: _EngineContext, node_id: NodeId, node: _EngineNode):
784
+ def __init__(self, graph: _EngineContext, node: _EngineNode):
618
785
  super().__init__()
619
786
  self.graph = graph
620
- self.node_id = node_id
621
787
  self.node = node
788
+ self.node_id = node.node.id
789
+
622
790
 
623
791
  def on_start(self):
624
792
 
@@ -654,8 +822,15 @@ class NodeProcessor(_actors.Actor):
654
822
 
655
823
  NodeLogger.log_node_start(self.node)
656
824
 
825
+ # Context contains only node states available when the context is set up
657
826
  ctx = NodeContextImpl(self.graph.nodes)
658
- result = self.node.function(ctx)
827
+
828
+ # Callback remains valid because it only lives inside the call stack for this message
829
+ callback = NodeCallbackImpl(self.actors(), self.node_id)
830
+
831
+ # Execute the node function
832
+ result = self.node.function(ctx, callback)
833
+
659
834
  self._check_result_type(result)
660
835
 
661
836
  NodeLogger.log_node_succeeded(self.node)
@@ -720,14 +895,90 @@ class NodeProcessor(_actors.Actor):
720
895
 
721
896
  class ModelNodeProcessor(NodeProcessor):
722
897
 
723
- def __init__(self, graph: _EngineContext, node_id: NodeId, node: _EngineNode):
724
- super().__init__(graph, node_id, node)
898
+ def __init__(self, graph: _EngineContext, node: _EngineNode):
899
+ super().__init__(graph, node)
725
900
 
726
901
 
727
902
  class DataNodeProcessor(NodeProcessor):
728
903
 
729
- def __init__(self, graph: _EngineContext, node_id: NodeId, node: _EngineNode):
730
- super().__init__(graph, node_id, node)
904
+ def __init__(self, graph: _EngineContext, node: _EngineNode):
905
+ super().__init__(graph, node)
906
+
907
+
908
+ class ChildJobNodeProcessor(NodeProcessor):
909
+
910
+ def __init__(self, graph: _EngineContext, node: _EngineNode):
911
+ super().__init__(graph, node)
912
+
913
+ @_actors.Message
914
+ def evaluate_node(self):
915
+
916
+ NodeLogger.log_node_start(self.node)
917
+
918
+ job_id = self.node.node.job_id # noqa
919
+ job_key = _util.object_key(job_id)
920
+
921
+ node_id = self.actors().id
922
+
923
+ def success_callback(ctx, _, result):
924
+ ctx.send(node_id, "child_job_succeeded", result)
925
+
926
+ def failure_callback(ctx, _, error):
927
+ ctx.send(node_id, "child_job_failed", error)
928
+
929
+ monitor = JobMonitor(job_key, success_callback, failure_callback)
930
+ monitor_id = self.actors().spawn(monitor)
931
+
932
+ graph_spec: _graph.Graph = self.node.node.graph # noqa
933
+
934
+ self.actors().send(self.graph.engine_id, "submit_child_job", job_id, graph_spec, monitor_id)
935
+
936
+ @_actors.Message
937
+ def child_job_succeeded(self, job_result: _cfg.JobResult):
938
+
939
+ self._check_result_type(job_result)
940
+
941
+ NodeLogger.log_node_succeeded(self.node)
942
+
943
+ self.actors().send_parent("node_succeeded", self.node_id, job_result)
944
+
945
+ @_actors.Message
946
+ def child_job_failed(self, job_error: Exception):
947
+
948
+ NodeLogger.log_node_failed(self.node, job_error)
949
+
950
+ self.actors().send_parent("node_failed", self.node_id, job_error)
951
+
952
+
953
+ class GraphLogger:
954
+
955
+ """
956
+ Log the activity of the GraphProcessor
957
+ """
958
+
959
+ _log = _util.logger_for_class(GraphProcessor)
960
+
961
+ @classmethod
962
+ def log_node_add(cls, node: _graph.Node):
963
+
964
+ node_name = node.id.name
965
+ namespace = node.id.namespace
966
+
967
+ cls._log.info(f"ADD {cls._func_type(node)} [{node_name}] / {namespace}")
968
+
969
+ @classmethod
970
+ def log_dependency_add(cls, node_id: NodeId, dep_id: NodeId):
971
+
972
+ if node_id.namespace == dep_id.namespace:
973
+ cls._log.info(f"ADD DEPENDENCY [{node_id.name}] -> [{dep_id.name}] / {node_id.namespace}")
974
+ else:
975
+ cls._log.info(f"ADD DEPENDENCY [{node_id.name}] / {node_id.namespace} -> [{dep_id.name}] / {dep_id.namespace}")
976
+
977
+ @classmethod
978
+ def _func_type(cls, node: _graph.Node):
979
+
980
+ func_type = type(node)
981
+ return func_type.__name__[:-4]
731
982
 
732
983
 
733
984
  class NodeLogger:
@@ -912,3 +1163,21 @@ class NodeContextImpl(_func.NodeContext):
912
1163
  for node_id, node in self.__nodes.items():
913
1164
  if node.complete and not node.error:
914
1165
  yield node_id, node.result
1166
+
1167
+
1168
+ class NodeCallbackImpl(_func.NodeCallback):
1169
+
1170
+ """
1171
+ Callback impl is passed to node functions so they can call into the engine
1172
+ It is only valid as long as the node function runs inside the call stack of a single message
1173
+ """
1174
+
1175
+ def __init__(self, actor_ctx: _actors.ActorContext, node_id: NodeId):
1176
+ self.__actor_ctx = actor_ctx
1177
+ self.__node_id = node_id
1178
+
1179
+ def send_graph_updates(
1180
+ self, new_nodes: tp.Dict[NodeId, _graph.Node],
1181
+ new_deps: tp.Dict[NodeId, tp.List[_graph.Dependency]]):
1182
+
1183
+ self.__actor_ctx.send_parent("update_graph", self.__node_id, new_nodes, new_deps)