mlrun 1.10.0rc18__py3-none-any.whl → 1.11.0rc16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlrun/__init__.py +24 -3
- mlrun/__main__.py +0 -4
- mlrun/artifacts/dataset.py +2 -2
- mlrun/artifacts/document.py +6 -1
- mlrun/artifacts/llm_prompt.py +21 -15
- mlrun/artifacts/model.py +3 -3
- mlrun/artifacts/plots.py +1 -1
- mlrun/{model_monitoring/db/tsdb/tdengine → auth}/__init__.py +2 -3
- mlrun/auth/nuclio.py +89 -0
- mlrun/auth/providers.py +429 -0
- mlrun/auth/utils.py +415 -0
- mlrun/common/constants.py +14 -0
- mlrun/common/model_monitoring/helpers.py +123 -0
- mlrun/common/runtimes/constants.py +28 -0
- mlrun/common/schemas/__init__.py +14 -3
- mlrun/common/schemas/alert.py +2 -2
- mlrun/common/schemas/api_gateway.py +3 -0
- mlrun/common/schemas/auth.py +12 -10
- mlrun/common/schemas/client_spec.py +4 -0
- mlrun/common/schemas/constants.py +25 -0
- mlrun/common/schemas/frontend_spec.py +1 -8
- mlrun/common/schemas/function.py +34 -0
- mlrun/common/schemas/hub.py +33 -20
- mlrun/common/schemas/model_monitoring/__init__.py +2 -1
- mlrun/common/schemas/model_monitoring/constants.py +12 -15
- mlrun/common/schemas/model_monitoring/functions.py +13 -4
- mlrun/common/schemas/model_monitoring/model_endpoints.py +11 -0
- mlrun/common/schemas/pipeline.py +1 -1
- mlrun/common/schemas/secret.py +17 -2
- mlrun/common/secrets.py +95 -1
- mlrun/common/types.py +10 -10
- mlrun/config.py +69 -19
- mlrun/data_types/infer.py +2 -2
- mlrun/datastore/__init__.py +12 -5
- mlrun/datastore/azure_blob.py +162 -47
- mlrun/datastore/base.py +274 -10
- mlrun/datastore/datastore.py +7 -2
- mlrun/datastore/datastore_profile.py +84 -22
- mlrun/datastore/model_provider/huggingface_provider.py +225 -41
- mlrun/datastore/model_provider/mock_model_provider.py +87 -0
- mlrun/datastore/model_provider/model_provider.py +206 -74
- mlrun/datastore/model_provider/openai_provider.py +226 -66
- mlrun/datastore/s3.py +39 -18
- mlrun/datastore/sources.py +1 -1
- mlrun/datastore/store_resources.py +4 -4
- mlrun/datastore/storeytargets.py +17 -12
- mlrun/datastore/targets.py +1 -1
- mlrun/datastore/utils.py +25 -6
- mlrun/datastore/v3io.py +1 -1
- mlrun/db/base.py +63 -32
- mlrun/db/httpdb.py +373 -153
- mlrun/db/nopdb.py +54 -21
- mlrun/errors.py +4 -2
- mlrun/execution.py +66 -25
- mlrun/feature_store/api.py +1 -1
- mlrun/feature_store/common.py +1 -1
- mlrun/feature_store/feature_vector_utils.py +1 -1
- mlrun/feature_store/steps.py +8 -6
- mlrun/frameworks/_common/utils.py +3 -3
- mlrun/frameworks/_dl_common/loggers/logger.py +1 -1
- mlrun/frameworks/_dl_common/loggers/tensorboard_logger.py +2 -1
- mlrun/frameworks/_ml_common/loggers/mlrun_logger.py +1 -1
- mlrun/frameworks/_ml_common/utils.py +2 -1
- mlrun/frameworks/auto_mlrun/auto_mlrun.py +4 -3
- mlrun/frameworks/lgbm/mlrun_interfaces/mlrun_interface.py +2 -1
- mlrun/frameworks/onnx/dataset.py +2 -1
- mlrun/frameworks/onnx/mlrun_interface.py +2 -1
- mlrun/frameworks/pytorch/callbacks/logging_callback.py +5 -4
- mlrun/frameworks/pytorch/callbacks/mlrun_logging_callback.py +2 -1
- mlrun/frameworks/pytorch/callbacks/tensorboard_logging_callback.py +2 -1
- mlrun/frameworks/pytorch/utils.py +2 -1
- mlrun/frameworks/sklearn/metric.py +2 -1
- mlrun/frameworks/tf_keras/callbacks/logging_callback.py +5 -4
- mlrun/frameworks/tf_keras/callbacks/mlrun_logging_callback.py +2 -1
- mlrun/frameworks/tf_keras/callbacks/tensorboard_logging_callback.py +2 -1
- mlrun/hub/__init__.py +52 -0
- mlrun/hub/base.py +142 -0
- mlrun/hub/module.py +172 -0
- mlrun/hub/step.py +113 -0
- mlrun/k8s_utils.py +105 -16
- mlrun/launcher/base.py +15 -7
- mlrun/launcher/local.py +4 -1
- mlrun/model.py +14 -4
- mlrun/model_monitoring/__init__.py +0 -1
- mlrun/model_monitoring/api.py +65 -28
- mlrun/model_monitoring/applications/__init__.py +1 -1
- mlrun/model_monitoring/applications/base.py +299 -128
- mlrun/model_monitoring/applications/context.py +2 -4
- mlrun/model_monitoring/controller.py +132 -58
- mlrun/model_monitoring/db/_schedules.py +38 -29
- mlrun/model_monitoring/db/_stats.py +6 -16
- mlrun/model_monitoring/db/tsdb/__init__.py +9 -7
- mlrun/model_monitoring/db/tsdb/base.py +29 -9
- mlrun/model_monitoring/db/tsdb/preaggregate.py +234 -0
- mlrun/model_monitoring/db/tsdb/stream_graph_steps.py +63 -0
- mlrun/model_monitoring/db/tsdb/timescaledb/queries/timescaledb_metrics_queries.py +414 -0
- mlrun/model_monitoring/db/tsdb/timescaledb/queries/timescaledb_predictions_queries.py +376 -0
- mlrun/model_monitoring/db/tsdb/timescaledb/queries/timescaledb_results_queries.py +590 -0
- mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_connection.py +434 -0
- mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_connector.py +541 -0
- mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_operations.py +808 -0
- mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_schema.py +502 -0
- mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_stream.py +163 -0
- mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_stream_graph_steps.py +60 -0
- mlrun/model_monitoring/db/tsdb/timescaledb/utils/timescaledb_dataframe_processor.py +141 -0
- mlrun/model_monitoring/db/tsdb/timescaledb/utils/timescaledb_query_builder.py +585 -0
- mlrun/model_monitoring/db/tsdb/timescaledb/writer_graph_steps.py +73 -0
- mlrun/model_monitoring/db/tsdb/v3io/stream_graph_steps.py +20 -9
- mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +235 -51
- mlrun/model_monitoring/features_drift_table.py +2 -1
- mlrun/model_monitoring/helpers.py +30 -6
- mlrun/model_monitoring/stream_processing.py +34 -28
- mlrun/model_monitoring/writer.py +224 -4
- mlrun/package/__init__.py +2 -1
- mlrun/platforms/__init__.py +0 -43
- mlrun/platforms/iguazio.py +8 -4
- mlrun/projects/operations.py +17 -11
- mlrun/projects/pipelines.py +2 -2
- mlrun/projects/project.py +187 -123
- mlrun/run.py +95 -21
- mlrun/runtimes/__init__.py +2 -186
- mlrun/runtimes/base.py +103 -25
- mlrun/runtimes/constants.py +225 -0
- mlrun/runtimes/daskjob.py +5 -2
- mlrun/runtimes/databricks_job/databricks_runtime.py +2 -1
- mlrun/runtimes/local.py +5 -2
- mlrun/runtimes/mounts.py +20 -2
- mlrun/runtimes/nuclio/__init__.py +12 -7
- mlrun/runtimes/nuclio/api_gateway.py +36 -6
- mlrun/runtimes/nuclio/application/application.py +339 -40
- mlrun/runtimes/nuclio/function.py +222 -72
- mlrun/runtimes/nuclio/serving.py +132 -42
- mlrun/runtimes/pod.py +213 -21
- mlrun/runtimes/utils.py +49 -9
- mlrun/secrets.py +99 -14
- mlrun/serving/__init__.py +2 -0
- mlrun/serving/remote.py +84 -11
- mlrun/serving/routers.py +26 -44
- mlrun/serving/server.py +138 -51
- mlrun/serving/serving_wrapper.py +6 -2
- mlrun/serving/states.py +997 -283
- mlrun/serving/steps.py +62 -0
- mlrun/serving/system_steps.py +149 -95
- mlrun/serving/v2_serving.py +9 -10
- mlrun/track/trackers/mlflow_tracker.py +29 -31
- mlrun/utils/helpers.py +292 -94
- mlrun/utils/http.py +9 -2
- mlrun/utils/notifications/notification/base.py +18 -0
- mlrun/utils/notifications/notification/git.py +3 -5
- mlrun/utils/notifications/notification/mail.py +39 -16
- mlrun/utils/notifications/notification/slack.py +2 -4
- mlrun/utils/notifications/notification/webhook.py +2 -5
- mlrun/utils/notifications/notification_pusher.py +3 -3
- mlrun/utils/version/version.json +2 -2
- mlrun/utils/version/version.py +3 -4
- {mlrun-1.10.0rc18.dist-info → mlrun-1.11.0rc16.dist-info}/METADATA +63 -74
- {mlrun-1.10.0rc18.dist-info → mlrun-1.11.0rc16.dist-info}/RECORD +161 -143
- mlrun/api/schemas/__init__.py +0 -259
- mlrun/db/auth_utils.py +0 -152
- mlrun/model_monitoring/db/tsdb/tdengine/schemas.py +0 -344
- mlrun/model_monitoring/db/tsdb/tdengine/stream_graph_steps.py +0 -75
- mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connection.py +0 -281
- mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +0 -1266
- {mlrun-1.10.0rc18.dist-info → mlrun-1.11.0rc16.dist-info}/WHEEL +0 -0
- {mlrun-1.10.0rc18.dist-info → mlrun-1.11.0rc16.dist-info}/entry_points.txt +0 -0
- {mlrun-1.10.0rc18.dist-info → mlrun-1.11.0rc16.dist-info}/licenses/LICENSE +0 -0
- {mlrun-1.10.0rc18.dist-info → mlrun-1.11.0rc16.dist-info}/top_level.txt +0 -0
mlrun/serving/states.py
CHANGED
@@ -24,12 +24,15 @@ import inspect
 import os
 import pathlib
 import traceback
+import warnings
 from abc import ABC
+from collections.abc import Collection
 from copy import copy, deepcopy
 from inspect import getfullargspec, signature
 from typing import Any, Optional, Union, cast
 
 import storey.utils
+from deprecated import deprecated
 from storey import ParallelExecutionMechanisms
 
 import mlrun
@@ -38,17 +41,21 @@ import mlrun.common.schemas as schemas
 from mlrun.artifacts.llm_prompt import LLMPromptArtifact, PlaceholderDefaultDict
 from mlrun.artifacts.model import ModelArtifact
 from mlrun.datastore.datastore_profile import (
-
+    DatastoreProfileKafkaStream,
     DatastoreProfileKafkaTarget,
     DatastoreProfileV3io,
     datastore_profile_read,
 )
-from mlrun.datastore.model_provider.model_provider import
+from mlrun.datastore.model_provider.model_provider import (
+    InvokeResponseFormat,
+    ModelProvider,
+    UsageResponseKeys,
+)
 from mlrun.datastore.storeytargets import KafkaStoreyTarget, StreamStoreyTarget
-from mlrun.utils import get_data_from_path, logger, split_path
+from mlrun.utils import get_data_from_path, logger, set_data_by_path, split_path
 
 from ..config import config
-from ..datastore import get_stream_pusher
+from ..datastore import _DummyStream, get_stream_pusher
 from ..datastore.utils import (
     get_kafka_brokers_from_dict,
     parse_kafka_url,
@@ -85,25 +92,6 @@ class StepKinds:
     model_runner = "model_runner"
 
 
-_task_step_fields = [
-    "kind",
-    "class_name",
-    "class_args",
-    "handler",
-    "skip_context",
-    "after",
-    "function",
-    "comment",
-    "shape",
-    "full_event",
-    "on_error",
-    "responder",
-    "input_path",
-    "result_path",
-    "model_endpoint_creation_strategy",
-    "endpoint_type",
-]
-
 _default_fields_to_strip_from_step = [
     "model_endpoint_creation_strategy",
     "endpoint_type",
@@ -129,7 +117,14 @@ def new_remote_endpoint(
 class BaseStep(ModelObj):
     kind = "BaseStep"
     default_shape = "ellipse"
-    _dict_fields = [
+    _dict_fields = [
+        "kind",
+        "comment",
+        "after",
+        "on_error",
+        "max_iterations",
+        "cycle_from",
+    ]
     _default_fields_to_strip = _default_fields_to_strip_from_step
 
     def __init__(
@@ -137,6 +132,7 @@ class BaseStep(ModelObj):
         name: Optional[str] = None,
         after: Optional[list] = None,
         shape: Optional[str] = None,
+        max_iterations: Optional[int] = None,
     ):
         self.name = name
         self._parent = None
@@ -150,6 +146,8 @@ class BaseStep(ModelObj):
         self.model_endpoint_creation_strategy = (
             schemas.ModelEndpointCreationStrategy.SKIP
         )
+        self._max_iterations = max_iterations
+        self.cycle_from = []
 
     def get_shape(self):
         """graphviz shape"""
@@ -343,6 +341,8 @@ class BaseStep(ModelObj):
         model_endpoint_creation_strategy: Optional[
             schemas.ModelEndpointCreationStrategy
         ] = None,
+        cycle_to: Optional[list[str]] = None,
+        max_iterations: Optional[int] = None,
         **class_args,
     ):
         """add a step right after this step and return the new step
@@ -372,21 +372,17 @@ class BaseStep(ModelObj):
                           to event["y"] resulting in {"x": 5, "y": <result>}
         :param model_endpoint_creation_strategy: Strategy for creating or updating the model endpoint:
 
-            * **overwrite**:
-
-              1. If model endpoints with the same name exist, delete the `latest` one.
-              2. Create a new model endpoint entry and set it as `latest`.
+            * **overwrite**: If model endpoints with the same name exist, delete the `latest` one;
+              create a new model endpoint entry and set it as `latest`.
 
-            * **inplace** (default):
+            * **inplace** (default): If model endpoints with the same name exist, update the `latest`
+              entry; otherwise, create a new entry.
 
-
-
-
-            * **archive**:
-
-              1. If model endpoints with the same name exist, preserve them.
-              2. Create a new model endpoint with the same name and set it to `latest`.
+            * **archive**: If model endpoints with the same name exist, preserve them;
+              create a new model endpoint with the same name and set it to `latest`.
 
+        :param cycle_to: list of step names to create a cycle to (for cyclic graphs)
+        :param max_iterations: maximum number of iterations for this step in case of a cycle graph
         :param class_args: class init arguments
         """
         if hasattr(self, "steps"):
@@ -421,8 +417,39 @@ class BaseStep(ModelObj):
             # check that its not the root, todo: in future may gave nested flows
             step.after_step(self.name)
         parent._last_added = step
+        step.cycle_to(cycle_to or [])
+        step._max_iterations = max_iterations
         return step
 
+    def cycle_to(self, step_names: Union[str, list[str]]):
+        """create a cycle in the graph to the specified step names
+
+        example:
+        in the below example, a cycle is created from 'step3' to 'step1':
+            graph.to('step1')\
+                .to('step2')\
+                .to('step3')\
+                .cycle_to(['step1'])  # creates a cycle from step3 to step1
+
+        :param step_names: list of step names to create a cycle to (for cyclic graphs)
+        """
+        root = self._extract_root_step()
+        if not isinstance(root, RootFlowStep):
+            raise GraphError("cycle_to() can only be called on a step within a graph")
+        if not root.allow_cyclic and step_names:
+            raise GraphError("cyclic graphs are not allowed, enable allow_cyclic")
+        step_names = [step_names] if isinstance(step_names, str) else step_names
+
+        for step_name in step_names:
+            if step_name not in root:
+                raise GraphError(
+                    f"step {step_name} doesnt exist in the graph under {self._parent.fullname}"
+                )
+            root[step_name].after_step(self.name, append=True)
+            root[step_name].cycle_from.append(self.name)
+
+        return self
+
     def set_flow(
         self,
         steps: list[Union[str, StepToDict, dict[str, Any]]],
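The `cycle_to` API above enables cyclic serving graphs: `max_iterations` bounds the loop per step, and the root flow step must opt in via `allow_cyclic`. A minimal usage sketch in Python, assuming a serving function whose `steps.py` defines `plan`, `act`, and `reflect` handlers, and assuming `allow_cyclic` can be set directly on the root step (all of these names are illustrative, not taken from this diff):

    import mlrun

    # illustrative serving function; "steps.py" is assumed to define the handlers
    fn = mlrun.code_to_function("agent-loop", kind="serving", filename="steps.py")
    graph = fn.set_topology("flow", engine="async")
    graph.allow_cyclic = True  # cycle_to() raises GraphError unless this is enabled

    # step3 feeds back into step1; max_iterations bounds the loop on step3
    graph.to(name="step1", handler="plan")\
        .to(name="step2", handler="act")\
        .to(name="step3", handler="reflect", max_iterations=5)\
        .cycle_to(["step1"])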
@@ -517,7 +544,9 @@ class BaseStep(ModelObj):
 
         root = self._extract_root_step()
 
-        if not isinstance(root, RootFlowStep)
+        if not isinstance(root, RootFlowStep) or (
+            isinstance(root, RootFlowStep) and root.engine != "async"
+        ):
             raise GraphError(
                 "ModelRunnerStep can be added to 'Flow' topology graph only"
             )
@@ -541,8 +570,8 @@ class BaseStep(ModelObj):
         # Update model endpoints names in the root step
         root.update_model_endpoints_names(step_model_endpoints_names)
 
-    @staticmethod
     def _verify_shared_models(
+        self,
         root: "RootFlowStep",
         step: "ModelRunnerStep",
         step_model_endpoints_names: list[str],
@@ -571,35 +600,41 @@ class BaseStep(ModelObj):
         prefix, _ = mlrun.datastore.parse_store_uri(model_artifact_uri)
         # if the model artifact is a prompt, we need to get the model URI
         # to ensure that the shared runnable name is correct
+        llm_artifact_uri = None
         if prefix == mlrun.utils.StorePrefix.LLMPrompt:
             llm_artifact, _ = mlrun.store_manager.get_store_artifact(
                 model_artifact_uri
             )
+            llm_artifact_uri = llm_artifact.uri
             model_artifact_uri = mlrun.utils.remove_tag_from_artifact_uri(
                 llm_artifact.spec.parent_uri
             )
-        actual_shared_name =
-            model_artifact_uri
+        actual_shared_name, shared_model_class, shared_model_params = (
+            root.get_shared_model_by_artifact_uri(model_artifact_uri)
         )
 
-        if not
-
-
-
-
-
-
-
-        ]["shared_runnable_name"] = actual_shared_name
-            shared_models.append(actual_shared_name)
+        if not actual_shared_name:
+            raise GraphError(
+                f"Can't find shared model named {shared_runnable_name}"
+            )
+        elif not shared_runnable_name:
+            step.class_args[schemas.ModelRunnerStepData.MODELS][name][
+                schemas.ModelsData.MODEL_PARAMETERS.value
+            ]["shared_runnable_name"] = actual_shared_name
         elif actual_shared_name != shared_runnable_name:
             raise GraphError(
                 f"Model endpoint {name} shared runnable name mismatch: "
                 f"expected {actual_shared_name}, got {shared_runnable_name}"
             )
-
-
-
+        shared_models.append(actual_shared_name)
+        self._edit_proxy_model_data(
+            step,
+            name,
+            actual_shared_name,
+            shared_model_params,
+            shared_model_class,
+            llm_artifact_uri or model_artifact_uri,
+        )
         undefined_shared_models = list(
             set(shared_models) - set(root.shared_models.keys())
         )
@@ -608,12 +643,71 @@ class BaseStep(ModelObj):
             f"The following shared models are not defined in the graph: {undefined_shared_models}."
         )
 
+    @staticmethod
+    def _edit_proxy_model_data(
+        step: "ModelRunnerStep",
+        name: str,
+        actual_shared_name: str,
+        shared_model_params: dict,
+        shared_model_class: Any,
+        artifact: Union[ModelArtifact, LLMPromptArtifact, str],
+    ):
+        monitoring_data = step.class_args.setdefault(
+            schemas.ModelRunnerStepData.MONITORING_DATA, {}
+        )
+
+        # edit monitoring data according to the shared model parameters
+        monitoring_data[name][schemas.MonitoringData.INPUT_PATH] = shared_model_params[
+            "input_path"
+        ]
+        monitoring_data[name][schemas.MonitoringData.RESULT_PATH] = shared_model_params[
+            "result_path"
+        ]
+        monitoring_data[name][schemas.MonitoringData.INPUTS] = shared_model_params[
+            "inputs"
+        ]
+        monitoring_data[name][schemas.MonitoringData.OUTPUTS] = shared_model_params[
+            "outputs"
+        ]
+        monitoring_data[name][schemas.MonitoringData.MODEL_CLASS] = (
+            shared_model_class
+            if isinstance(shared_model_class, str)
+            else shared_model_class.__class__.__name__
+        )
+        if actual_shared_name and actual_shared_name not in step._shared_proxy_mapping:
+            step._shared_proxy_mapping[actual_shared_name] = {
+                name: artifact.uri
+                if isinstance(artifact, ModelArtifact | LLMPromptArtifact)
+                else artifact
+            }
+        elif actual_shared_name:
+            step._shared_proxy_mapping[actual_shared_name].update(
+                {
+                    name: artifact.uri
+                    if isinstance(artifact, ModelArtifact | LLMPromptArtifact)
+                    else artifact
+                }
+            )
+
 
 class TaskStep(BaseStep):
     """task execution step, runs a class or handler"""
 
     kind = "task"
-    _dict_fields =
+    _dict_fields = BaseStep._dict_fields + [
+        "class_name",
+        "class_args",
+        "handler",
+        "skip_context",
+        "function",
+        "shape",
+        "full_event",
+        "responder",
+        "input_path",
+        "result_path",
+        "model_endpoint_creation_strategy",
+        "endpoint_type",
+    ]
     _default_class = ""
 
     def __init__(
@@ -639,6 +733,7 @@ class TaskStep(BaseStep):
         self.handler = handler
         self.function = function
         self._handler = None
+        self._outlets_selector = None
         self._object = None
         self._async_object = None
         self.skip_context = None
@@ -706,6 +801,8 @@ class TaskStep(BaseStep):
             handler = "do"
         if handler:
             self._handler = getattr(self._object, handler, None)
+        if hasattr(self._object, "select_outlets"):
+            self._outlets_selector = self._object.select_outlets
 
         self._set_error_handler()
         if mode != "skip":
@@ -879,7 +976,7 @@ class ErrorStep(TaskStep):
     """error execution step, runs a class or handler"""
 
     kind = "error_step"
-    _dict_fields =
+    _dict_fields = TaskStep._dict_fields + ["before", "base_step"]
     _default_class = ""
 
     def __init__(
@@ -916,7 +1013,7 @@ class RouterStep(TaskStep):
 
     kind = "router"
     default_shape = "doubleoctagon"
-    _dict_fields =
+    _dict_fields = TaskStep._dict_fields + ["routes", "name"]
     _default_class = "mlrun.serving.ModelRouter"
 
     def __init__(
@@ -983,20 +1080,14 @@ class RouterStep(TaskStep):
         :param function: function this step should run in
         :param creation_strategy: Strategy for creating or updating the model endpoint:
 
-            * **overwrite**:
-
-              1. If model endpoints with the same name exist, delete the `latest` one.
-              2. Create a new model endpoint entry and set it as `latest`.
-
-            * **inplace** (default):
+            * **overwrite**: If model endpoints with the same name exist, delete the `latest` one;
+              create a new model endpoint entry and set it as `latest`.
 
-
-
+            * **inplace** (default): If model endpoints with the same name exist, update the `latest`
+              entry;otherwise, create a new entry.
 
-            * **archive**:
-
-              1. If model endpoints with the same name exist, preserve them.
-              2. Create a new model endpoint with the same name and set it to `latest`.
+            * **archive**: If model endpoints with the same name exist, preserve them;
+              create a new model endpoint with the same name and set it to `latest`.
 
         """
         if len(self.routes.keys()) >= MAX_MODELS_PER_ROUTER and key not in self.routes:
@@ -1090,6 +1181,7 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
         "artifact_uri",
         "shared_runnable_name",
         "shared_proxy_mapping",
+        "execution_mechanism",
     ]
     kind = "model"
 
@@ -1111,6 +1203,8 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
         self.invocation_artifact: Optional[LLMPromptArtifact] = None
         self.model_artifact: Optional[ModelArtifact] = None
        self.model_provider: Optional[ModelProvider] = None
+        self._artifact_were_loaded = False
+        self._execution_mechanism = None
 
     def __init_subclass__(cls):
         super().__init_subclass__()
@@ -1130,13 +1224,33 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
             raise_missing_schema_exception=False,
         )
 
-
-
-
-
-
+        # Check if the relevant predict method is implemented when trying to initialize the model
+        if self._execution_mechanism == storey.ParallelExecutionMechanisms.asyncio:
+            if self.__class__.predict_async is Model.predict_async:
+                raise mlrun.errors.ModelRunnerError(
+                    {
+                        self.name: f"is running with {self._execution_mechanism} "
+                        f"execution_mechanism but predict_async() is not implemented"
+                    }
+                )
         else:
-            self.
+            if self.__class__.predict is Model.predict:
+                raise mlrun.errors.ModelRunnerError(
+                    {
+                        self.name: f"is running with {self._execution_mechanism} execution_mechanism but predict() "
+                        f"is not implemented"
+                    }
+                )
+
+    def _load_artifacts(self) -> None:
+        if not self._artifact_were_loaded:
+            artifact = self._get_artifact_object()
+            if isinstance(artifact, LLMPromptArtifact):
+                self.invocation_artifact = artifact
+                self.model_artifact = self.invocation_artifact.model_artifact
+            else:
+                self.model_artifact = artifact
+            self._artifact_were_loaded = True
 
     def _get_artifact_object(
         self, proxy_uri: Optional[str] = None
@@ -1144,7 +1258,9 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
         uri = proxy_uri or self.artifact_uri
         if uri:
             if mlrun.datastore.is_store_uri(uri):
-                artifact, _ = mlrun.store_manager.get_store_artifact(
+                artifact, _ = mlrun.store_manager.get_store_artifact(
+                    uri, allow_empty_resources=True
+                )
                 return artifact
             else:
                 raise ValueError(
@@ -1158,13 +1274,15 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
 
     def predict(self, body: Any, **kwargs) -> Any:
         """Override to implement prediction logic. If the logic requires asyncio, override predict_async() instead."""
-
+        raise NotImplementedError("predict() method not implemented")
 
     async def predict_async(self, body: Any, **kwargs) -> Any:
         """Override to implement prediction logic if the logic requires asyncio."""
-
+        raise NotImplementedError("predict_async() method not implemented")
 
     def run(self, body: Any, path: str, origin_name: Optional[str] = None) -> Any:
+        if isinstance(body, list):
+            body = self.format_batch(body)
         return self.predict(body)
 
     async def run_async(
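Under the new base-class contract above, `predict()` and `predict_async()` raise `NotImplementedError` instead of silently returning, and `run()` routes list bodies through the `format_batch` hook defined in the next hunk. A minimal sketch of a subclass written against that contract (the class name and payload shape are illustrative, not from this diff):

    from typing import Any

    import mlrun.serving


    class EchoModel(mlrun.serving.Model):
        @staticmethod
        def format_batch(body: Any):
            # run() calls this hook when the event body is a list
            return {"instances": body}

        def predict(self, body: Any, **kwargs) -> Any:
            # required override: the base predict() now raises NotImplementedError
            body["outputs"] = body.get("instances", [])
            return body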
@@ -1203,28 +1321,117 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
             return model_file, extra_dataitems
         return None, None
 
+    @staticmethod
+    def format_batch(body: Any):
+        return body
+
 
 class LLModel(Model):
+    """
+    A model wrapper for handling LLM (Large Language Model) prompt-based inference.
+
+    This class extends the base `Model` to provide specialized handling for
+    `LLMPromptArtifact` objects, enabling both synchronous and asynchronous
+    invocation of language models.
+
+    **Model Invocation**:
+
+    - The execution of enriched prompts is delegated to the `model_provider`
+      configured for the model (e.g., **Hugging Face** or **OpenAI**).
+    - The `model_provider` is responsible for sending the prompt to the correct
+      backend API and returning the generated output.
+    - Users can override the `predict` and `predict_async` methods to customize
+      the behavior of the model invocation.
+
+    **Prompt Enrichment Overview**:
+
+    - If an `LLMPromptArtifact` is found, load its prompt template and fill in
+      placeholders using values from the request body.
+    - If the artifact is not an `LLMPromptArtifact`, skip formatting and attempt
+      to retrieve `messages` directly from the request body using the input path.
+
+    **Simplified Example**:
+
+    Input body::
+
+        {"city": "Paris", "days": 3}
+
+    Prompt template in artifact::
+
+        [
+            {"role": "system", "content": "You are a travel planning assistant."},
+            {"role": "user", "content": "Create a {{days}}-day itinerary for {{city}}."},
+        ]
+
+    Result after enrichment::
+
+        [
+            {"role": "system", "content": "You are a travel planning assistant."},
+            {"role": "user", "content": "Create a 3-day itinerary for Paris."},
+        ]
+
+    :param name: Name of the model.
+    :param input_path: Path in the request body where input data is located.
+    :param result_path: Path in the response body where model outputs and the statistics
+                        will be stored.
+    """
+
+    _dict_fields = Model._dict_fields + ["result_path", "input_path"]
+
     def __init__(
-        self,
+        self,
+        name: str,
+        input_path: Optional[Union[str, list[str]]] = None,
+        result_path: Optional[Union[str, list[str]]] = None,
+        **kwargs,
     ):
         super().__init__(name, **kwargs)
         self._input_path = split_path(input_path)
+        self._result_path = split_path(result_path)
+        logger.info(
+            "LLModel initialized",
+            model_name=name,
+            input_path=input_path,
+            result_path=result_path,
+        )
 
     def predict(
         self,
         body: Any,
         messages: Optional[list[dict]] = None,
-
+        invocation_config: Optional[dict] = None,
         **kwargs,
     ) -> Any:
+        llm_prompt_artifact = kwargs.get("llm_prompt_artifact")
         if isinstance(
-
+            llm_prompt_artifact, mlrun.artifacts.LLMPromptArtifact
         ) and isinstance(self.model_provider, ModelProvider):
-
+            logger.debug(
+                "Invoking model provider",
+                model_name=self.name,
+                messages=messages,
+                invocation_config=invocation_config,
+            )
+            response_with_stats = self.model_provider.invoke(
                 messages=messages,
-
-                **(
+                invoke_response_format=InvokeResponseFormat.USAGE,
+                **(invocation_config or {}),
+            )
+            set_data_by_path(
+                path=self._result_path, data=body, value=response_with_stats
+            )
+            logger.debug(
+                "LLModel prediction completed",
+                model_name=self.name,
+                answer=response_with_stats.get("answer"),
+                usage=response_with_stats.get("usage"),
+            )
+        else:
+            logger.warning(
+                "LLModel invocation artifact or model provider not set, skipping prediction",
+                model_name=self.name,
+                invocation_artifact_type=type(llm_prompt_artifact).__name__,
+                model_provider_type=type(self.model_provider).__name__,
             )
         return body
 
@@ -1232,61 +1439,130 @@ class LLModel(Model):
         self,
         body: Any,
         messages: Optional[list[dict]] = None,
-
+        invocation_config: Optional[dict] = None,
         **kwargs,
     ) -> Any:
+        llm_prompt_artifact = kwargs.get("llm_prompt_artifact")
         if isinstance(
-
+            llm_prompt_artifact, mlrun.artifacts.LLMPromptArtifact
         ) and isinstance(self.model_provider, ModelProvider):
-
+            logger.debug(
+                "Async invoking model provider",
+                model_name=self.name,
                 messages=messages,
-
-
+                invocation_config=invocation_config,
+            )
+            response_with_stats = await self.model_provider.async_invoke(
+                messages=messages,
+                invoke_response_format=InvokeResponseFormat.USAGE,
+                **(invocation_config or {}),
+            )
+            set_data_by_path(
+                path=self._result_path, data=body, value=response_with_stats
+            )
+            logger.debug(
+                "LLModel async prediction completed",
+                model_name=self.name,
+                answer=response_with_stats.get("answer"),
+                usage=response_with_stats.get("usage"),
+            )
+        else:
+            logger.warning(
+                "LLModel invocation artifact or model provider not set, skipping async prediction",
+                model_name=self.name,
+                invocation_artifact_type=type(llm_prompt_artifact).__name__,
+                model_provider_type=type(self.model_provider).__name__,
             )
         return body
 
+    def init(self):
+        super().init()
+
+        if not self.model_provider:
+            if self._execution_mechanism != storey.ParallelExecutionMechanisms.asyncio:
+                unchanged_predict = self.__class__.predict is LLModel.predict
+                predict_function_name = "predict"
+            else:
+                unchanged_predict = (
+                    self.__class__.predict_async is LLModel.predict_async
+                )
+                predict_function_name = "predict_async"
+            if unchanged_predict:
+                raise mlrun.errors.MLRunRuntimeError(
+                    f"Model provider could not be determined for model '{self.name}',"
+                    f" and the {predict_function_name} function was not overridden."
+                )
+
     def run(self, body: Any, path: str, origin_name: Optional[str] = None) -> Any:
-
+        llm_prompt_artifact = self._get_invocation_artifact(origin_name)
+        messages, invocation_config = self.enrich_prompt(
+            body, origin_name, llm_prompt_artifact
+        )
+        logger.info(
+            "Calling LLModel predict",
+            model_name=self.name,
+            model_endpoint_name=origin_name,
+            messages_len=len(messages) if messages else 0,
+        )
         return self.predict(
-            body,
+            body,
+            messages=messages,
+            invocation_config=invocation_config,
+            llm_prompt_artifact=llm_prompt_artifact,
         )
 
     async def run_async(
         self, body: Any, path: str, origin_name: Optional[str] = None
     ) -> Any:
-
+        llm_prompt_artifact = self._get_invocation_artifact(origin_name)
+        messages, invocation_config = self.enrich_prompt(
+            body, origin_name, llm_prompt_artifact
+        )
+        logger.info(
+            "Calling LLModel async predict",
+            model_name=self.name,
+            model_endpoint_name=origin_name,
+            messages_len=len(messages) if messages else 0,
+        )
         return await self.predict_async(
-            body,
+            body,
+            messages=messages,
+            invocation_config=invocation_config,
+            llm_prompt_artifact=llm_prompt_artifact,
        )
 
     def enrich_prompt(
-        self,
+        self,
+        body: dict,
+        origin_name: str,
+        llm_prompt_artifact: Optional[LLMPromptArtifact] = None,
     ) -> Union[tuple[list[dict], dict], tuple[None, None]]:
-
-
-
-
-
-
-        llm_prompt_artifact = (
-            self.invocation_artifact or self._get_artifact_object()
-        )
-        if not (
+        logger.info(
+            "Enriching prompt",
+            model_name=self.name,
+            model_endpoint_name=origin_name,
+        )
+        if not llm_prompt_artifact or not (
             llm_prompt_artifact and isinstance(llm_prompt_artifact, LLMPromptArtifact)
         ):
             logger.warning(
-                "
+                "LLModel must be provided with LLMPromptArtifact",
+                model_name=self.name,
+                artifact_type=type(llm_prompt_artifact).__name__,
                 llm_prompt_artifact=llm_prompt_artifact,
             )
-
-
-
+            prompt_legend, prompt_template, invocation_config = {}, [], {}
+        else:
+            prompt_legend = llm_prompt_artifact.spec.prompt_legend
+            prompt_template = deepcopy(llm_prompt_artifact.read_prompt())
+            invocation_config = llm_prompt_artifact.spec.invocation_config
         input_data = copy(get_data_from_path(self._input_path, body))
-        if isinstance(input_data, dict):
+        if isinstance(input_data, dict) and prompt_template:
             kwargs = (
                 {
                     place_holder: input_data.get(body_map["field"])
                    for place_holder, body_map in prompt_legend.items()
+                    if input_data.get(body_map["field"])
                 }
                 if prompt_legend
                 else {}
@@ -1298,23 +1574,124 @@ class LLModel(Model):
                         message["content"] = message["content"].format(**input_data)
                     except KeyError as e:
                         logger.warning(
-                            "Input data
-
+                            "Input data missing placeholder, content stays unformatted",
+                            model_name=self.name,
+                            key_error=mlrun.errors.err_to_str(e),
                         )
                         message["content"] = message["content"].format_map(
                             default_place_holders
                         )
+        elif isinstance(input_data, dict) and not prompt_template:
+            # If there is no prompt template, we assume the input data is already in the correct format.
+            logger.debug("Attempting to retrieve messages from the request body.")
+            prompt_template = input_data.get("messages", [])
         else:
             logger.warning(
-
-
+                "Expected input data to be a dict, prompt template stays unformatted",
+                model_name=self.name,
+                input_data_type=type(input_data).__name__,
             )
-        return prompt_template,
+        return prompt_template, invocation_config
+
+    def _get_invocation_artifact(
+        self, origin_name: Optional[str] = None
+    ) -> Union[LLMPromptArtifact, None]:
+        """
+        Get the LLMPromptArtifact object for this model.
+
+        :param proxy_uri: Optional; URI to the proxy artifact.
+        :return: LLMPromptArtifact object or None if not found.
+        """
+        if origin_name and self.shared_proxy_mapping:
+            llm_prompt_artifact = self.shared_proxy_mapping.get(origin_name)
+            if isinstance(llm_prompt_artifact, str):
+                llm_prompt_artifact = self._get_artifact_object(llm_prompt_artifact)
+                self.shared_proxy_mapping[origin_name] = llm_prompt_artifact
+        elif self._artifact_were_loaded:
+            llm_prompt_artifact = self.invocation_artifact
+        else:
+            self._load_artifacts()
+            llm_prompt_artifact = self.invocation_artifact
+        return llm_prompt_artifact
+
+
+class ModelRunnerSelector(ModelObj):
+    """
+    Strategy for controlling model selection and output routing in ModelRunnerStep.
+
+    Subclass this to implement custom logic for agent workflows:
+    - `select_models()`: Called BEFORE execution to choose which models run
+    - `select_outlets()`: Called AFTER execution to route output to downstream steps
+
+    Return `None` from either method to use default behavior (all models / all outlets).
+
+    Example::
+
+        class ToolSelector(ModelRunnerSelector):
+            def select_outlets(self, event):
+                tool = event.get("tool_call")
+                return [tool] if tool else ["final"]
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__()
+
+    def __init_subclass__(cls):
+        super().__init_subclass__()
+        cls._dict_fields = list(
+            set(cls._dict_fields)
+            | set(inspect.signature(cls.__init__).parameters.keys())
+        )
+        cls._dict_fields.remove("self")
+
+    def select_models(
+        self,
+        event: Any,
+        available_models: list[Model],
+    ) -> Optional[Union[list[str], list[Model]]]:
+        """
+        Called before model execution.
+
+        :param event: The full event
+        :param available_models: List of available models
+
+        Returns the models to execute (by name or Model objects).
+        """
+        return None
+
+    def select_outlets(
+        self,
+        event: Any,
+    ) -> Optional[list[str]]:
+        """
+        Called after model execution.
+
+        :param event: The event body after model execution
+        :return: Returns the downstream outlets to route the event to.
+        """
+        return None
 
 
-
+# TODO: Remove in 1.13.0
+@deprecated(
+    version="1.11.0",
+    reason="ModelSelector will be removed in 1.13.0, use ModelRunnerSelector instead",
+    category=FutureWarning,
+)
 class ModelSelector(ModelObj):
     """Used to select which models to run on each event."""
 
+    def __init__(self, **kwargs):
+        super().__init__()
+
+    def __init_subclass__(cls):
+        super().__init_subclass__()
+        cls._dict_fields = list(
+            set(cls._dict_fields)
+            | set(inspect.signature(cls.__init__).parameters.keys())
+        )
+        cls._dict_fields.remove("self")
+
     def select(
         self, event, available_models: list[Model]
     ) -> Union[list[str], list[Model]]:
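`ModelRunnerSelector`, added above, supersedes the deprecated `ModelSelector` and splits the hook in two: `select_models()` runs before execution and `select_outlets()` after it, with `None` meaning default behavior. A sketch of a custom selector, assuming a dict event body carrying `quality` and `confidence` fields and a downstream outlet named "review" (all illustrative assumptions):

    from typing import Any, Optional, Union

    from mlrun.serving.states import Model, ModelRunnerSelector


    class CheapFirstSelector(ModelRunnerSelector):
        def select_models(
            self, event: Any, available_models: list[Model]
        ) -> Optional[Union[list[str], list[Model]]]:
            # run all models only when the caller asks for high quality;
            # returning None falls back to the default (run everything)
            if isinstance(event, dict) and event.get("quality") == "high":
                return None
            return ["small-model"]  # illustrative model name

        def select_outlets(self, event: Any) -> Optional[list[str]]:
            # route low-confidence results to a hypothetical "review" outlet
            if isinstance(event, dict) and event.get("confidence", 1.0) < 0.5:
                return ["review"]
            return None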
@@ -1332,16 +1709,22 @@ class ModelRunner(storey.ParallelExecution):
     """
     Runs multiple Models on each event. See ModelRunnerStep.
 
-    :param
-
+    :param model_runner_selector: ModelSelector instance whose select() method will be used to select models
+                                  to run on each event. Optional. If not passed, all models will be run.
     """
 
     def __init__(
-        self,
+        self,
+        *args,
+        context,
+        model_runner_selector: Optional[ModelRunnerSelector] = None,
+        raise_exception: bool = True,
+        **kwargs,
     ):
         super().__init__(*args, **kwargs)
-        self.
+        self.model_runner_selector = model_runner_selector or ModelRunnerSelector()
         self.context = context
+        self._raise_exception = raise_exception
 
     def preprocess_event(self, event):
         if not hasattr(event, "_metadata"):
@@ -1354,7 +1737,31 @@ class ModelRunner(storey.ParallelExecution):
 
     def select_runnables(self, event):
         models = cast(list[Model], self.runnables)
-        return self.
+        return self.model_runner_selector.select_models(event, models)
+
+    def select_outlets(self, event) -> Optional[Collection[str]]:
+        sys_outlets = [f"{self.name}_error_raise"]
+        if "background_task_status_step" in self._name_to_outlet:
+            sys_outlets.append("background_task_status_step")
+        if self._raise_exception and self._is_error(event):
+            return sys_outlets
+        user_outlets = self.model_runner_selector.select_outlets(event)
+        if user_outlets:
+            return (
+                user_outlets if isinstance(user_outlets, list) else [user_outlets]
+            ) + sys_outlets
+        return None
+
+    def _is_error(self, event: dict) -> bool:
+        if len(self.runnables) == 1:
+            if isinstance(event, dict):
+                return event.get("error") is not None
+        else:
+            for model in event:
+                body_by_model = event.get(model)
+                if isinstance(body_by_model, dict) and "error" in body_by_model:
+                    return True
+        return False
 
 
 class MonitoredStep(ABC, TaskStep, StepToDict):
@@ -1406,34 +1813,122 @@ class ModelRunnerStep(MonitoredStep):
|
|
|
1406
1813
|
model_runner_step.add_model(..., model_class=MyModel(name="my_model"))
|
|
1407
1814
|
graph.to(model_runner_step)
|
|
1408
1815
|
|
|
1409
|
-
|
|
1410
|
-
|
|
1411
|
-
|
|
1412
|
-
|
|
1816
|
+
Note when ModelRunnerStep is used in a graph, MLRun automatically imports
|
|
1817
|
+
the default language model class (LLModel) during function deployment.
|
|
1818
|
+
|
|
1819
|
+
Note ModelRunnerStep can only be added to a graph that has the flow topology and running with async engine.
|
|
1413
1820
|
|
|
1414
|
-
|
|
1415
|
-
|
|
1416
|
-
|
|
1417
|
-
added
|
|
1821
|
+
Note see configure_pool_resource method documentation for default number of max threads and max processes.
|
|
1822
|
+
|
|
1823
|
+
:raise ModelRunnerError: when a model raises an error the ModelRunnerStep will handle it, collect errors and
|
|
1824
|
+
outputs from added models. If raise_exception is True will raise ModelRunnerError. Else
|
|
1825
|
+
will add the error msg as part of the event body mapped by model name if more than
|
|
1826
|
+
one model was added to the ModelRunnerStep
|
|
1418
1827
|
"""
|
|
1419
1828
|
|
|
1420
1829
|
kind = "model_runner"
|
|
1421
|
-
_dict_fields = MonitoredStep._dict_fields + [
|
|
1830
|
+
_dict_fields = MonitoredStep._dict_fields + [
|
|
1831
|
+
"_shared_proxy_mapping",
|
|
1832
|
+
"max_processes",
|
|
1833
|
+
"max_threads",
|
|
1834
|
+
"pool_factor",
|
|
1835
|
+
]
|
|
1422
1836
|
|
|
1423
1837
|
def __init__(
|
|
1424
1838
|
self,
|
|
1425
1839
|
*args,
|
|
1426
1840
|
name: Optional[str] = None,
|
|
1841
|
+
model_runner_selector: Optional[Union[str, ModelRunnerSelector]] = None,
|
|
1842
|
+
model_runner_selector_parameters: Optional[dict] = None,
|
|
1427
1843
|
model_selector: Optional[Union[str, ModelSelector]] = None,
|
|
1844
|
+
model_selector_parameters: Optional[dict] = None,
|
|
1428
1845
|
raise_exception: bool = True,
|
|
1429
1846
|
**kwargs,
|
|
1430
1847
|
):
|
|
1848
|
+
"""
|
|
1849
|
+
|
|
1850
|
+
:param name: The name of the ModelRunnerStep.
|
|
1851
|
+
:param model_runner_selector: ModelRunnerSelector instance whose select_models()
|
|
1852
|
+
and select_outlets() methods will be used
|
|
1853
|
+
to select models to run on each event and outlets to
|
|
1854
|
+
route the event to.
|
|
1855
|
+
:param model_runner_selector_parameters: Parameters for the model_runner_selector, if model_runner_selector
|
|
1856
|
+
is the class name we will use this param when
|
|
1857
|
+
initializing the selector.
|
|
1858
|
+
:param model_selector: (Deprecated)
|
|
1859
|
+
:param model_selector_parameters: (Deprecated)
|
|
1860
|
+
:param raise_exception: Determines whether to raise ModelRunnerError when one or more models
|
|
1861
|
+
raise an error during execution.
|
|
1862
|
+
If False, errors will be added to the event body.
|
|
1863
|
+
"""
|
|
1864
|
+
self.max_processes = None
|
|
1865
|
+
self.max_threads = None
|
|
1866
|
+
self.pool_factor = None
|
|
1867
|
+
|
|
1868
|
+
if (model_selector or model_selector_parameters) and (
|
|
1869
|
+
model_runner_selector or model_runner_selector_parameters
|
|
1870
|
+
):
|
|
1871
|
+
raise GraphError(
|
|
1872
|
+
"Cannot provide both `model_selector`/`model_selector_parameters` "
|
|
1873
|
+
"and `model_runner_selector`/`model_runner_selector_parameters`. "
|
|
1874
|
+
"Please use only the latter pair."
|
|
1875
|
+
)
|
|
1876
|
+
if model_selector or model_selector_parameters:
|
|
1877
|
+
warnings.warn(
|
|
1878
|
+
"`model_selector` and `model_selector_parameters` are deprecated, "
|
|
1879
|
+
"please use `model_runner_selector` and `model_runner_selector_parameters` instead.",
|
|
1880
|
+
# TODO: Remove this in 1.13.0
|
|
1881
|
+
FutureWarning,
|
|
1882
|
+
)
|
|
1883
|
+
if isinstance(model_selector, ModelSelector) and model_selector_parameters:
|
|
1884
|
+
raise mlrun.errors.MLRunInvalidArgumentError(
|
|
1885
|
+
"Cannot provide a model_selector object as argument to `model_selector` and also provide "
|
|
1886
|
+
"`model_selector_parameters`."
|
|
1887
|
+
)
|
|
1888
|
+
if model_selector:
|
|
1889
|
+
model_selector_parameters = model_selector_parameters or (
|
|
1890
|
+
model_selector.to_dict()
|
|
1891
|
+
if isinstance(model_selector, ModelSelector)
|
|
1892
|
+
else {}
|
|
1893
|
+
)
|
|
1894
|
+
model_selector = (
|
|
1895
|
+
model_selector
|
|
1896
|
+
if isinstance(model_selector, str)
|
|
1897
|
+
else model_selector.__class__.__name__
|
|
1898
|
+
)
|
|
1899
|
+
else:
|
|
1900
|
+
if (
|
|
1901
|
+
isinstance(model_runner_selector, ModelRunnerSelector)
|
|
1902
|
+
and model_runner_selector_parameters
|
|
1903
|
+
):
|
|
1904
|
+
raise mlrun.errors.MLRunInvalidArgumentError(
|
|
1905
|
+
"Cannot provide a model_runner_selector object as argument to `model_runner_selector` "
|
|
1906
|
+
"and also provide `model_runner_selector_parameters`."
|
|
1907
|
+
)
|
|
1908
|
+
if model_runner_selector:
|
|
1909
|
+
model_runner_selector_parameters = model_runner_selector_parameters or (
|
|
1910
|
+
model_runner_selector.to_dict()
|
|
1911
|
+
if isinstance(model_runner_selector, ModelRunnerSelector)
|
|
1912
|
+
else {}
|
|
1913
|
+
)
|
|
1914
|
+
model_runner_selector = (
|
|
1915
|
+
model_runner_selector
|
|
1916
|
+
if isinstance(model_runner_selector, str)
|
|
1917
|
+
else model_runner_selector.__class__.__name__
|
|
1918
|
+
)
|
|
1919
|
+
|
|
1431
1920
|
super().__init__(
|
|
1432
1921
|
*args,
|
|
1433
1922
|
name=name,
|
|
1434
1923
|
raise_exception=raise_exception,
|
|
1435
1924
|
class_name="mlrun.serving.ModelRunner",
|
|
1436
|
-
class_args=dict(
|
|
1925
|
+
class_args=dict(
|
|
1926
|
+
model_selector=(model_selector, model_selector_parameters),
|
|
1927
|
+
model_runner_selector=(
|
|
1928
|
+
model_runner_selector,
|
|
1929
|
+
model_runner_selector_parameters,
|
|
1930
|
+
),
|
|
1931
|
+
),
|
|
1437
1932
|
**kwargs,
|
|
1438
1933
|
)
|
|
1439
1934
|
self.raise_exception = raise_exception
|
|
@@ -1449,10 +1944,6 @@ class ModelRunnerStep(MonitoredStep):
|
|
|
1449
1944
|
model_endpoint_creation_strategy: Optional[
|
|
1450
1945
|
schemas.ModelEndpointCreationStrategy
|
|
1451
1946
|
] = schemas.ModelEndpointCreationStrategy.INPLACE,
|
|
1452
|
-
inputs: Optional[list[str]] = None,
|
|
1453
|
-
outputs: Optional[list[str]] = None,
|
|
1454
|
-
input_path: Optional[str] = None,
|
|
1455
|
-
result_path: Optional[str] = None,
|
|
1456
1947
|
override: bool = False,
|
|
1457
1948
|
) -> None:
|
|
1458
1949
|
"""
|
|
@@ -1465,28 +1956,18 @@ class ModelRunnerStep(MonitoredStep):
|
|
|
1465
1956
|
:param shared_model_name: str, the name of the shared model that is already defined within the graph
|
|
1466
1957
|
:param labels: model endpoint labels, should be list of str or mapping of str:str
|
|
1467
1958
|
:param model_endpoint_creation_strategy: Strategy for creating or updating the model endpoint:
|
|
1468
|
-
* **overwrite**:
|
|
1469
|
-
1. If model endpoints with the same name exist, delete the `latest` one.
|
|
1470
|
-
2. Create a new model endpoint entry and set it as `latest`.
|
|
1471
|
-
* **inplace** (default):
|
|
1472
|
-
1. If model endpoints with the same name exist, update the `latest` entry.
|
|
1473
|
-
2. Otherwise, create a new entry.
|
|
1474
|
-
* **archive**:
|
|
1475
|
-
1. If model endpoints with the same name exist, preserve them.
|
|
1476
|
-
2. Create a new model endpoint with the same name and set it to `latest`.
|
|
1477
1959
|
|
|
1478
|
-
|
|
1479
|
-
|
|
1480
|
-
|
|
1481
|
-
|
|
1482
|
-
|
|
1483
|
-
|
|
1484
|
-
|
|
1485
|
-
|
|
1486
|
-
|
|
1487
|
-
notation (e.g "outputs.my_model_outputs") expects list or dictionary type object
|
|
1488
|
-
in path.
|
|
1960
|
+
* **overwrite**: If model endpoints with the same name exist, delete the `latest` one;
|
|
1961
|
+
create a new model endpoint entry and set it as `latest`.
|
|
1962
|
+
|
|
1963
|
+
* **inplace** (default): If model endpoints with the same name exist, update the `latest` entry;
|
|
1964
|
+
otherwise, create a new entry.
|
|
1965
|
+
|
|
1966
|
+
* **archive**: If model endpoints with the same name exist, preserve them;
|
|
1967
|
+
create a new model endpoint with the same name and set it to `latest`.
|
|
1968
|
+
|
|
1489
1969
|
:param override: bool allow override existing model on the current ModelRunnerStep.
|
|
1970
|
+
:raise GraphError: when the shared model is not found in the root flow step shared models.
|
|
1490
1971
|
"""
|
|
1491
1972
|
model_class, model_params = (
|
|
1492
1973
|
"mlrun.serving.Model",
|
|
@@ -1503,11 +1984,21 @@ class ModelRunnerStep(MonitoredStep):
                 "model_artifact must be a string, ModelArtifact or LLMPromptArtifact"
             )
         root = self._extract_root_step()
+        shared_model_params = {}
         if isinstance(root, RootFlowStep):
-
-
-                or root.get_shared_model_name_by_artifact_uri(model_artifact_uri)
+            actual_shared_model_name, shared_model_class, shared_model_params = (
+                root.get_shared_model_by_artifact_uri(model_artifact_uri)
             )
+            if not actual_shared_model_name or (
+                shared_model_name and actual_shared_model_name != shared_model_name
+            ):
+                raise GraphError(
+                    f"ModelRunnerStep can only add proxy models that were added to the root flow step, "
+                    f"model {shared_model_name} is not in the shared models."
+                )
+            elif not shared_model_name:
+                shared_model_name = actual_shared_model_name
+            model_params["shared_runnable_name"] = shared_model_name
         if not root.shared_models or (
             root.shared_models
             and shared_model_name
@@ -1517,17 +2008,31 @@ class ModelRunnerStep(MonitoredStep):
                 f"ModelRunnerStep can only add proxy models that were added to the root flow step, "
                 f"model {shared_model_name} is not in the shared models."
             )
-
+        monitoring_data = self.class_args.get(
+            schemas.ModelRunnerStepData.MONITORING_DATA, {}
+        )
+        monitoring_data.setdefault(endpoint_name, {})[
+            schemas.MonitoringData.MODEL_CLASS
+        ] = (
+            shared_model_class
+            if isinstance(shared_model_class, str)
+            else shared_model_class.__class__.__name__
+        )
+        self.class_args[schemas.ModelRunnerStepData.MONITORING_DATA] = (
+            monitoring_data
+        )
+
+        if shared_model_name and shared_model_name not in self._shared_proxy_mapping:
             self._shared_proxy_mapping[shared_model_name] = {
                 endpoint_name: model_artifact.uri
-                if isinstance(model_artifact,
+                if isinstance(model_artifact, ModelArtifact | LLMPromptArtifact)
                 else model_artifact
             }
-
+        elif override and shared_model_name:
             self._shared_proxy_mapping[shared_model_name].update(
                 {
                     endpoint_name: model_artifact.uri
-                    if isinstance(model_artifact,
+                    if isinstance(model_artifact, ModelArtifact | LLMPromptArtifact)
                    else model_artifact
                 }
             )
@@ -1538,11 +2043,11 @@ class ModelRunnerStep(MonitoredStep):
             model_artifact=model_artifact,
             labels=labels,
             model_endpoint_creation_strategy=model_endpoint_creation_strategy,
+            inputs=shared_model_params.get("inputs"),
+            outputs=shared_model_params.get("outputs"),
+            input_path=shared_model_params.get("input_path"),
+            result_path=shared_model_params.get("result_path"),
             override=override,
-            inputs=inputs,
-            outputs=outputs,
-            input_path=input_path,
-            result_path=result_path,
             **model_params,
         )
 
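Taken together, these hunks mean a ModelRunnerStep can only proxy a model that was first registered on the root flow step, and the proxy now inherits its inputs/outputs and monitoring paths from the shared registration instead of accepting its own. A hedged usage sketch (the function name, image, and artifact URI are placeholders; it assumes the object returned by set_topology is the RootFlowStep carrying add_shared_model):

    import mlrun

    fn = mlrun.code_to_function("serving-demo", kind="serving", image="mlrun/mlrun")
    graph = fn.set_topology("flow", engine="async")

    # Register the model once on the root flow step; ModelRunner steps that
    # reference this artifact URI become proxies of it, and referencing a URI
    # that is not registered here raises GraphError (per the hunk above).
    graph.add_shared_model(
        name="my-shared-model",
        model_class="mlrun.serving.Model",
        execution_mechanism="process_pool",
        model_artifact="store://models/my-project/my-model:latest",
    )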
@@ -1567,48 +2072,52 @@ class ModelRunnerStep(MonitoredStep):
         Add a Model to this ModelRunner.
 
         :param endpoint_name: str, will identify the model in the ModelRunnerStep, and assign model endpoint name
-        :param model_class: Model class name
+        :param model_class: Model class name. If LLModel is chosen
+            (either by name `LLModel` or by its full path, e.g. mlrun.serving.states.LLModel),
+            outputs will be overridden with UsageResponseKeys fields.
         :param execution_mechanism: Parallel execution mechanism to be used to execute this model. Must be one of:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        :param inputs: list of the model inputs (e.g. features) ,if provided will override the inputs
+
+            * **process_pool**: To run in a separate process from a process pool. This is appropriate
+              for CPU or GPU intensive tasks as they would otherwise block the main process by holding
+              Python's Global Interpreter Lock (GIL).
+
+            * **dedicated_process**: To run in a separate dedicated process. This is appropriate for CPU
+              or GPU intensive tasks that also require significant Runnable-specific initialization
+              (e.g. a large model).
+
+            * **thread_pool**: To run in a separate thread. This is appropriate for blocking I/O tasks,
+              as they would otherwise block the main event loop thread.
+
+            * **asyncio**: To run in an asyncio task. This is appropriate for I/O tasks that use
+              asyncio, allowing the event loop to continue running while waiting for a response.
+
+            * **naive**: To run in the main event loop. This is appropriate only for trivial computation
+              and/or file I/O. It means that the runnable will not actually be run in parallel to
+              anything else.
+
+        :param model_artifact: model artifact or mlrun model artifact uri
+        :param labels: model endpoint labels, should be list of str or mapping of str:str
+        :param model_endpoint_creation_strategy: Strategy for creating or updating the model endpoint:
+
+            * **overwrite**: If model endpoints with the same name exist, delete the `latest` one;
+              create a new model endpoint entry and set it as `latest`.
+
+            * **inplace** (default): If model endpoints with the same name exist, update the `latest`
+              entry; otherwise, create a new entry.
+
+            * **archive**: If model endpoints with the same name exist, preserve them;
+              create a new model endpoint with the same name and set it to `latest`.
+
+        :param inputs: list of the model inputs (e.g. features) ,if provided will override the inputs
             that been configured in the model artifact, please note that those inputs need to
             be equal in length and order to the inputs that model_class predict method expects
-
+        :param outputs: list of the model outputs (e.g. labels) ,if provided will override the outputs
             that been configured in the model artifact, please note that those outputs need to
             be equal to the model_class predict method outputs (length, and order)
-
+
+            When using LLModel, the output will be overridden with UsageResponseKeys.fields().
+
+        :param input_path: when specified selects the key/path in the event to use as model monitoring inputs
             this require that the event body will behave like a dict, expects scopes to be
             defined by dot notation (e.g "data.d").
             examples: input_path="data.b"
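The five mechanisms documented above trade isolation against startup cost. A standalone rule-of-thumb helper distilled from that docstring (illustrative only, not part of mlrun; "naive", run in the main event loop, is appropriate only for trivial work and is omitted):

    def pick_mechanism(cpu_bound: bool, heavy_init: bool, uses_asyncio: bool) -> str:
        if cpu_bound and heavy_init:
            return "dedicated_process"  # e.g. a large model loaded once per process
        if cpu_bound:
            return "process_pool"       # keep GIL-holding work off the main process
        if uses_asyncio:
            return "asyncio"            # cooperative I/O inside the event loop
        return "thread_pool"            # blocking I/O off the main event loop thread

    assert pick_mechanism(cpu_bound=True, heavy_init=True, uses_asyncio=False) == "dedicated_process"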
@@ -1618,7 +2127,7 @@ class ModelRunnerStep(MonitoredStep):
             be {"f0": [1, 2]}.
             if a ``list`` or ``list of lists`` is provided, it must follow the order and
             size defined by the input schema.
-
+        :param result_path: when specified selects the key/path in the output event to use as model monitoring
             outputs this require that the output event body will behave like a dict,
             expects scopes to be defined by dot notation (e.g "data.d").
             examples: result_path="out.b"
@@ -1629,14 +2138,22 @@ class ModelRunnerStep(MonitoredStep):
             if a ``list`` or ``list of lists`` is provided, it must follow the order and
             size defined by the output schema.
 
-
-
+        :param override: bool allow override existing model on the current ModelRunnerStep.
+        :param model_parameters: Parameters for model instantiation
         """
         if isinstance(model_class, Model) and model_parameters:
             raise mlrun.errors.MLRunInvalidArgumentError(
                 "Cannot provide a model object as argument to `model_class` and also provide `model_parameters`."
             )
-
+        if type(model_class) is LLModel or (
+            isinstance(model_class, str)
+            and model_class.split(".")[-1] == LLModel.__name__
+        ):
+            if outputs:
+                warnings.warn(
+                    "LLModel with existing outputs detected, overriding to default"
+                )
+            outputs = UsageResponseKeys.fields()
         model_parameters = model_parameters or (
             model_class.to_dict() if isinstance(model_class, Model) else {}
         )
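The LLModel check added above matches either an LLModel instance or a class path whose final segment is the class name; on a match, user-supplied outputs are discarded with a warning in favor of UsageResponseKeys.fields(). A standalone rendering of the name check (simplified; LLModel is stubbed here rather than imported from mlrun):

    class LLModel:  # stand-in for mlrun.serving.states.LLModel
        pass

    def is_llmodel(model_class) -> bool:
        # an LLModel instance, or a dotted path ending in the class name, both match
        if type(model_class) is LLModel:
            return True
        return isinstance(model_class, str) and model_class.split(".")[-1] == LLModel.__name__

    assert is_llmodel("mlrun.serving.states.LLModel")
    assert is_llmodel("LLModel")
    assert not is_llmodel("mlrun.serving.Model")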
@@ -1652,8 +2169,6 @@ class ModelRunnerStep(MonitoredStep):
         except mlrun.errors.MLRunNotFoundError:
             raise mlrun.errors.MLRunInvalidArgumentError("Artifact not found.")
 
-        outputs = outputs or self._get_model_output_schema(model_artifact)
-
         model_artifact = (
             model_artifact.uri
             if isinstance(model_artifact, mlrun.artifacts.Artifact)
@@ -1719,28 +2234,13 @@ class ModelRunnerStep(MonitoredStep):
         self.class_args[schemas.ModelRunnerStepData.MONITORING_DATA] = monitoring_data
 
     @staticmethod
-    def
-        model_artifact: Union[ModelArtifact, LLMPromptArtifact],
-    ) -> Optional[list[str]]:
-        if isinstance(
-            model_artifact,
-            ModelArtifact,
-        ):
-            return [feature.name for feature in model_artifact.spec.outputs]
-        elif isinstance(
-            model_artifact,
-            LLMPromptArtifact,
-        ):
-            _model_artifact = model_artifact.model_artifact
-            return [feature.name for feature in _model_artifact.spec.outputs]
-
-    @staticmethod
-    def _get_model_endpoint_output_schema(
+    def _get_model_endpoint_schema(
         name: str,
         project: str,
         uid: str,
-    ) -> list[str]:
+    ) -> tuple[list[str], list[str]]:
         output_schema = None
+        input_schema = None
         try:
             model_endpoint: mlrun.common.schemas.model_monitoring.ModelEndpoint = (
                 mlrun.db.get_run_db().get_model_endpoint(
@@ -1751,6 +2251,7 @@ class ModelRunnerStep(MonitoredStep):
                 )
             )
             output_schema = model_endpoint.spec.label_names
+            input_schema = model_endpoint.spec.feature_names
         except (
             mlrun.errors.MLRunNotFoundError,
             mlrun.errors.MLRunInvalidArgumentError,
@@ -1759,7 +2260,7 @@ class ModelRunnerStep(MonitoredStep):
                 f"Model endpoint not found, using default output schema for model {name}",
                 error=f"{type(ex).__name__}: {ex}",
             )
-        return output_schema
+        return input_schema, output_schema
 
     def _calculate_monitoring_data(self) -> dict[str, dict[str, str]]:
         monitoring_data = deepcopy(
@@ -1775,47 +2276,154 @@ class ModelRunnerStep(MonitoredStep):
                 monitoring_data[model][schemas.MonitoringData.RESULT_PATH] = split_path(
                     monitoring_data[model][schemas.MonitoringData.RESULT_PATH]
                 )
+
+                mep_output_schema, mep_input_schema = None, None
+
+                output_schema = self.class_args[
+                    mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA
+                ][model][schemas.MonitoringData.OUTPUTS]
+                input_schema = self.class_args[
+                    mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA
+                ][model][schemas.MonitoringData.INPUTS]
+                if not output_schema or not input_schema:
+                    # if output or input schema is not provided, try to get it from the model endpoint
+                    mep_input_schema, mep_output_schema = (
+                        self._get_model_endpoint_schema(
+                            model,
+                            self.context.project,
+                            monitoring_data[model].get(
+                                schemas.MonitoringData.MODEL_ENDPOINT_UID, ""
+                            ),
+                        )
+                    )
+                self.class_args[
+                    mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA
+                ][model][schemas.MonitoringData.OUTPUTS] = (
+                    output_schema or mep_output_schema
+                )
+                self.class_args[
+                    mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA
+                ][model][schemas.MonitoringData.INPUTS] = (
+                    input_schema or mep_input_schema
+                )
             return monitoring_data
         else:
             raise mlrun.errors.MLRunInvalidArgumentError(
                 "Monitoring data must be a dictionary."
             )
 
+    def configure_pool_resource(
+        self,
+        max_processes: Optional[int] = None,
+        max_threads: Optional[int] = None,
+        pool_factor: Optional[int] = None,
+    ) -> None:
+        """
+        Configure the resource limits for the shared models in the graph.
+
+        :param max_processes: Maximum number of processes to spawn (excluding dedicated processes).
+            Defaults to the number of CPUs or 16 if undetectable.
+        :param max_threads: Maximum number of threads to spawn. Defaults to 32.
+        :param pool_factor: Multiplier to scale the number of process/thread workers per runnable. Defaults to 1.
+        """
+        self.max_processes = max_processes
+        self.max_threads = max_threads
+        self.pool_factor = pool_factor
+
     def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs):
         self.context = context
         if not self._is_local_function(context):
             # skip init of non local functions
             return
-        model_selector = self.class_args.get(
+        model_selector, model_selector_params = self.class_args.get(
+            "model_selector", (None, None)
+        )
+        model_runner_selector, model_runner_selector_params = self.class_args.get(
+            "model_runner_selector", (None, None)
+        )
         execution_mechanism_by_model_name = self.class_args.get(
             schemas.ModelRunnerStepData.MODEL_TO_EXECUTION_MECHANISM
         )
         models = self.class_args.get(schemas.ModelRunnerStepData.MODELS, {})
-        if
-            model_selector = get_class(model_selector, namespace)(
+        if model_selector:
+            model_selector = get_class(model_selector, namespace).from_dict(
+                model_selector_params, init_with_params=True
+            )
+            model_runner_selector = (
+                self._convert_model_selector_to_model_runner_selector(
+                    model_selector=model_selector
+                )
+            )
+        elif model_runner_selector:
+            model_runner_selector = get_class(
+                model_runner_selector, namespace
+            ).from_dict(model_runner_selector_params, init_with_params=True)
         model_objects = []
         for model, model_params in models.values():
+            model_name = model_params.get("name")
             model_params[schemas.MonitoringData.INPUT_PATH] = (
                 self.class_args.get(
                    mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA, {}
                )
-                .get(
+                .get(model_name, {})
                .get(schemas.MonitoringData.INPUT_PATH)
            )
+            model_params[schemas.MonitoringData.RESULT_PATH] = (
+                self.class_args.get(
+                    mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA, {}
+                )
+                .get(model_name, {})
+                .get(schemas.MonitoringData.RESULT_PATH)
+            )
            model = get_class(model, namespace).from_dict(
                model_params, init_with_params=True
            )
            model._raise_exception = False
+            model._execution_mechanism = execution_mechanism_by_model_name.get(
+                model_name
+            )
            model_objects.append(model)
        self._async_object = ModelRunner(
-
+            model_runner_selector=model_runner_selector,
            runnables=model_objects,
            execution_mechanism_by_runnable_name=execution_mechanism_by_model_name,
            shared_proxy_mapping=self._shared_proxy_mapping or None,
            name=self.name,
            context=context,
+            max_processes=self.max_processes,
+            max_threads=self.max_threads,
+            pool_factor=self.pool_factor,
+            raise_exception=self.raise_exception,
+            **extra_kwargs,
        )
 
+    def _convert_model_selector_to_model_runner_selector(
+        self,
+        model_selector,
+    ) -> "ModelRunnerSelector":
+        """
+        Wrap a ModelSelector into a ModelRunnerSelector for backward compatibility.
+        """
+
+        class Adapter(ModelRunnerSelector):
+            def __init__(self):
+                self.selector = model_selector
+
+            def select_models(
+                self, event, available_models
+            ) -> Union[list[str], list[Model]]:
+                # Call old ModelSelector logic
+                return self.selector.select(event, available_models)
+
+            def select_outlets(
+                self,
+                event,
+            ) -> Optional[list[str]]:
+                # By default, return all outlets (old ModelSelector didn't control routing)
+                return None
+
+        return Adapter()
+
 
 class ModelRunnerErrorRaiser(storey.MapClass):
     def __init__(self, raise_exception: bool, models_names: list[str], **kwargs):
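configure_pool_resource stores the limits on the step, and init_object now forwards them (together with raise_exception and any extra kwargs) into the ModelRunner. A hedged usage sketch, assuming `runner_step` is a ModelRunnerStep already added to a graph; the values are illustrative:

    runner_step.configure_pool_resource(
        max_processes=4,   # process-pool workers; dedicated processes are not counted
        max_threads=16,    # thread-pool workers (the documented default is 32)
        pool_factor=2,     # scales the number of workers per runnable (default 1)
    )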
@@ -1828,11 +2436,15 @@ class ModelRunnerErrorRaiser(storey.MapClass):
         errors = {}
         should_raise = False
         if len(self._models_names) == 1:
-
-
+            if isinstance(event.body, dict):
+                should_raise = event.body.get("error") is not None
+                errors[self._models_names[0]] = event.body.get("error")
         else:
             for model in event.body:
-
+                body_by_model = event.body.get(model)
+                errors[model] = None
+                if isinstance(body_by_model, dict):
+                    errors[model] = body_by_model.get("error")
                 if errors[model] is not None:
                     should_raise = True
         if should_raise:
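The reworked raiser now guards against non-dict bodies before probing for "error". A self-contained illustration of the two body shapes it handles, a single-model flat dict versus a per-model mapping (the payloads are invented for the demo):

    # Single model: the event body is the model's own output dict.
    single_body = {"outputs": [0.91], "error": None}

    # Multiple models: the body maps model name -> output dict (or a bare value).
    multi_body = {
        "model-a": {"outputs": [0.91], "error": None},
        "model-b": {"error": "predict failed"},
    }

    errors = {
        name: body.get("error") if isinstance(body, dict) else None
        for name, body in multi_body.items()
    }
    should_raise = any(err is not None for err in errors.values())
    assert should_raise and errors["model-b"] == "predict failed"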
@@ -1902,6 +2514,8 @@ class QueueStep(BaseStep, StepToDict):
         model_endpoint_creation_strategy: Optional[
             schemas.ModelEndpointCreationStrategy
         ] = None,
+        cycle_to: Optional[list[str]] = None,
+        max_iterations: Optional[int] = None,
         **class_args,
     ):
         if not function:
@@ -1919,6 +2533,8 @@ class QueueStep(BaseStep, StepToDict):
             input_path,
             result_path,
             model_endpoint_creation_strategy,
+            cycle_to,
+            max_iterations,
             **class_args,
         )
 
@@ -1954,8 +2570,10 @@ class FlowStep(BaseStep):
         after: Optional[list] = None,
         engine=None,
         final_step=None,
+        allow_cyclic: bool = False,
+        max_iterations: Optional[int] = None,
     ):
-        super().__init__(name, after)
+        super().__init__(name, after, max_iterations=max_iterations)
         self._steps = None
         self.steps = steps
         self.engine = engine
@@ -1967,6 +2585,7 @@ class FlowStep(BaseStep):
         self._wait_for_result = False
         self._source = None
         self._start_steps = []
+        self._allow_cyclic = allow_cyclic
 
     def get_children(self):
         return self._steps.values()
@@ -2000,6 +2619,8 @@ class FlowStep(BaseStep):
         model_endpoint_creation_strategy: Optional[
             schemas.ModelEndpointCreationStrategy
         ] = None,
+        cycle_to: Optional[list[str]] = None,
+        max_iterations: Optional[int] = None,
         **class_args,
     ):
         """add task, queue or router step/class to the flow
@@ -2033,21 +2654,17 @@ class FlowStep(BaseStep):
                             to event["y"] resulting in {"x": 5, "y": <result>}
        :param model_endpoint_creation_strategy: Strategy for creating or updating the model endpoint:
 
-
-
-                1. If model endpoints with the same name exist, delete the `latest` one.
-                2. Create a new model endpoint entry and set it as `latest`.
+            * **overwrite**: If model endpoints with the same name exist, delete the `latest` one;
+              create a new model endpoint entry and set it as `latest`.
 
-            * **inplace** (default):
+            * **inplace** (default): If model endpoints with the same name exist, update the `latest`
+              entry; otherwise, create a new entry.
 
-
-
-
-            * **archive**:
-
-                1. If model endpoints with the same name exist, preserve them.
-                2. Create a new model endpoint with the same name and set it to `latest`.
+            * **archive**: If model endpoints with the same name exist, preserve them;
+              create a new model endpoint with the same name and set it to `latest`.
 
+        :param cycle_to: list of step names to create a cycle to (for cyclic graphs)
+        :param max_iterations: maximum number of iterations for this step in case of a cycle graph
        :param class_args: class init arguments
        """
 
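cycle_to and max_iterations make bounded loops first-class in add_step, provided the topology opts in. A hedged sketch (step classes and names are placeholders; it assumes set_topology forwards allow_cyclic to the RootFlowStep constructor shown further below):

    graph = fn.set_topology("flow", engine="async", allow_cyclic=True)

    graph.add_step(name="generate", class_name="GenerateStep")
    graph.add_step(
        name="review",
        class_name="ReviewStep",
        after="generate",
        cycle_to=["generate"],  # loop back until the review passes
        max_iterations=5,       # per-step bound; the root default applies otherwise
    )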
@@ -2073,6 +2690,8 @@ class FlowStep(BaseStep):
         after_list = after if isinstance(after, list) else [after]
         for after in after_list:
             self.insert_step(name, step, after, before)
+        step.cycle_to(cycle_to or [])
+        step._max_iterations = max_iterations
         return step
 
     def insert_step(self, key, step, after, before=None):
@@ -2165,13 +2784,24 @@ class FlowStep(BaseStep):
         for step in self._steps.values():
             step._next = None
             step._visited = False
-            if step.after:
+            if step.after and not step.cycle_from:
+                has_illegal_branches = len(step.after) > 1 and self.engine == "sync"
+                if has_illegal_branches:
+                    raise GraphError(
+                        f"synchronous flow engine doesnt support branches use async for step {step.name}"
+                    )
                 loop_step = has_loop(step, [])
-                if loop_step:
+                if loop_step and not self.allow_cyclic:
                     raise GraphError(
                         f"Error, loop detected in step {loop_step}, graph must be acyclic (DAG)"
                     )
-
+            elif (
+                step.after
+                and step.cycle_from
+                and set(step.after) == set(step.cycle_from)
+            ):
+                start_steps.append(step.name)
+            elif not step.cycle_from:
                 start_steps.append(step.name)
 
         responders = []
@@ -2268,6 +2898,9 @@ class FlowStep(BaseStep):
         def process_step(state, step, root):
             if not state._is_local_function(self.context) or state._visited:
                 return
+            state._visited = (
+                True  # mark visited to avoid re-visit in case of multiple uplinks
+            )
             for item in state.next or []:
                 next_state = root[item]
                 if next_state.async_object:
@@ -2278,7 +2911,7 @@ class FlowStep(BaseStep):
             )
 
         default_source, self._wait_for_result = _init_async_objects(
-            self.context, self._steps.values()
+            self.context, self._steps.values(), self
         )
 
         source = self._source or default_source
@@ -2509,6 +3142,8 @@ class RootFlowStep(FlowStep):
         "shared_models",
         "shared_models_mechanism",
         "pool_factor",
+        "allow_cyclic",
+        "max_iterations",
     ]
 
     def __init__(
@@ -2518,13 +3153,11 @@ class RootFlowStep(FlowStep):
         after: Optional[list] = None,
         engine=None,
         final_step=None,
+        allow_cyclic: bool = False,
+        max_iterations: Optional[int] = 10_000,
     ):
         super().__init__(
-            name,
-            steps,
-            after,
-            engine,
-            final_step,
+            name, steps, after, engine, final_step, allow_cyclic, max_iterations
         )
         self._models = set()
         self._route_models = set()
@@ -2535,48 +3168,102 @@ class RootFlowStep(FlowStep):
         self._shared_max_threads = None
         self._pool_factor = None
 
+    @property
+    def max_iterations(self) -> int:
+        return self._max_iterations
+
+    @max_iterations.setter
+    def max_iterations(self, max_iterations: int):
+        self._max_iterations = max_iterations
+
+    @property
+    def allow_cyclic(self) -> bool:
+        return self._allow_cyclic
+
+    @allow_cyclic.setter
+    def allow_cyclic(self, allow_cyclic: bool):
+        self._allow_cyclic = allow_cyclic
+
     def add_shared_model(
         self,
         name: str,
         model_class: Union[str, Model],
         execution_mechanism: Union[str, ParallelExecutionMechanisms],
         model_artifact: Union[str, ModelArtifact],
+        inputs: Optional[list[str]] = None,
+        outputs: Optional[list[str]] = None,
+        input_path: Optional[str] = None,
+        result_path: Optional[str] = None,
         override: bool = False,
         **model_parameters,
     ) -> None:
         """
         Add a shared model to the graph, this model will be available to all the ModelRunners in the graph
         :param name: Name of the shared model (should be unique in the graph)
-        :param model_class: Model class name
+        :param model_class: Model class name. If LLModel is chosen
+            (either by name `LLModel` or by its full path, e.g. mlrun.serving.states.LLModel),
+            outputs will be overridden with UsageResponseKeys fields.
         :param execution_mechanism: Parallel execution mechanism to be used to execute this model. Must be one of:
-
+
+            * **process_pool**: To run in a separate process from a process pool. This is appropriate for CPU or GPU
              intensive tasks as they would otherwise block the main process by holding Python's Global Interpreter
              Lock (GIL).
-
-
-
+
+            * **dedicated_process**: To run in a separate dedicated process. This is appropriate for CPU or GPU
+              intensive tasks that also require significant Runnable-specific initialization (e.g. a large model).
+
+            * **thread_pool**: To run in a separate thread. This is appropriate for blocking I/O tasks, as they would
              otherwise block the main event loop thread.
-
+
+            * **asyncio**: To run in an asyncio task. This is appropriate for I/O tasks that use asyncio, allowing the
              event loop to continue running while waiting for a response.
-
-
-
+
+            * **shared_executor**: Reuses an external executor (typically managed by the flow or context) to execute
+              the runnable. Should be used only if you have multiple `ParallelExecution` in the same flow and
+              especially useful when:
+
              - You want to share a heavy resource like a large model loaded onto a GPU.
+
              - You want to centralize task scheduling or coordination for multiple lightweight tasks.
+
              - You aim to minimize overhead from creating new executors or processes/threads per runnable.
+
              The runnable is expected to be pre-initialized and reused across events, enabling efficient use of
              memory and hardware accelerators.
-        * "naive" – To run in the main event loop. This is appropriate only for trivial computation and/or file I/O.
-          It means that the runnable will not actually be run in parallel to anything else.
 
-
-
-
+            * **naive**: To run in the main event loop. This is appropriate only for trivial computation and/or file
+              I/O. It means that the runnable will not actually be run in parallel to anything else.
+
+        :param model_artifact: model artifact or mlrun model artifact uri
+        :param inputs: list of the model inputs (e.g. features) ,if provided will override the inputs
+                        that been configured in the model artifact, please note that those inputs need
+                        to be equal in length and order to the inputs that model_class
+                        predict method expects
+        :param outputs: list of the model outputs (e.g. labels) ,if provided will override the outputs
+                        that been configured in the model artifact, please note that those outputs need
+                        to be equal to the model_class
+                        predict method outputs (length, and order)
+        :param input_path: input path inside the user event, expect scopes to be defined by dot notation
+                        (e.g "inputs.my_model_inputs"). expects list or dictionary type object in path.
+        :param result_path: result path inside the user output event, expect scopes to be defined by dot
+                        notation (e.g "outputs.my_model_outputs") expects list or dictionary type object
+                        in path.
+        :param override: bool allow override existing model on the current ModelRunnerStep.
+        :param model_parameters: Parameters for model instantiation
        """
        if isinstance(model_class, Model) and model_parameters:
            raise mlrun.errors.MLRunInvalidArgumentError(
                "Cannot provide a model object as argument to `model_class` and also provide `model_parameters`."
            )
+        if type(model_class) is LLModel or (
+            isinstance(model_class, str)
+            and model_class.split(".")[-1] == LLModel.__name__
+        ):
+            if outputs:
+                warnings.warn(
+                    "LLModel with existing outputs detected, overriding to default"
+                )
+            outputs = UsageResponseKeys.fields()
 
        if execution_mechanism == ParallelExecutionMechanisms.shared_executor:
            raise mlrun.errors.MLRunInvalidArgumentError(
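With the new schema parameters, a shared registration can override the artifact's declared inputs/outputs and monitoring paths up front, and every proxy inherits them (as wired in the earlier ModelRunnerStep hunk). A usage sketch with placeholder URI and schema names:

    graph.add_shared_model(
        name="sentiment",
        model_class="mlrun.serving.Model",
        execution_mechanism="dedicated_process",
        model_artifact="store://models/my-project/sentiment:latest",
        inputs=["text"],                  # overrides the artifact's input schema
        outputs=["label", "score"],       # must match the predict method's outputs
        input_path="inputs.payload",      # dot-notation scope inside the event
        result_path="outputs.sentiment",
    )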
@@ -2604,6 +3291,14 @@ class RootFlowStep(FlowStep):
                 "Inconsistent name for the added model."
             )
         model_parameters["name"] = name
+        model_parameters["inputs"] = inputs or model_parameters.get("inputs", [])
+        model_parameters["outputs"] = outputs or model_parameters.get("outputs", [])
+        model_parameters["input_path"] = input_path or model_parameters.get(
+            "input_path"
+        )
+        model_parameters["result_path"] = result_path or model_parameters.get(
+            "result_path"
+        )
 
         if name in self.shared_models and not override:
             raise mlrun.errors.MLRunInvalidArgumentError(
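The merge above gives explicit arguments precedence over anything already present in model_parameters, with empty-list defaults for the schemas. A standalone rendering of that precedence (values invented for the demo):

    model_parameters = {"inputs": ["a", "b"], "input_path": "data.in"}
    inputs, input_path = None, "override.in"  # only input_path passed explicitly

    model_parameters["inputs"] = inputs or model_parameters.get("inputs", [])
    model_parameters["input_path"] = input_path or model_parameters.get("input_path")

    assert model_parameters == {"inputs": ["a", "b"], "input_path": "override.in"}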
@@ -2618,7 +3313,9 @@ class RootFlowStep(FlowStep):
         self.shared_models[name] = (model_class, model_parameters)
         self.shared_models_mechanism[name] = execution_mechanism
 
-    def
+    def get_shared_model_by_artifact_uri(
+        self, artifact_uri: str
+    ) -> Union[tuple[str, str, dict], tuple[None, None, None]]:
         """
         Get a shared model by its artifact URI.
         :param artifact_uri: The artifact URI of the model.
@@ -2626,10 +3323,10 @@ class RootFlowStep(FlowStep):
         """
         for model_name, (model_class, model_params) in self.shared_models.items():
             if model_params.get("artifact_uri") == artifact_uri:
-                return model_name
-        return None
+                return model_name, model_class, model_params
+        return None, None, None
 
-    def
+    def configure_shared_pool_resource(
         self,
         max_processes: Optional[int] = None,
         max_threads: Optional[int] = None,
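Callers of the renamed lookup now receive the full registration tuple rather than just the name, so a miss must be detected on the first element. A brief sketch (the URI is a placeholder):

    name, model_class, params = graph.get_shared_model_by_artifact_uri(
        "store://models/my-project/sentiment:latest"
    )
    if name is None:
        raise ValueError("artifact is not registered as a shared model on the root step")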
@@ -2637,8 +3334,9 @@ class RootFlowStep(FlowStep):
     ) -> None:
         """
         Configure the resource limits for the shared models in the graph.
+
         :param max_processes: Maximum number of processes to spawn (excluding dedicated processes).
-
+            Defaults to the number of CPUs or 16 if undetectable.
         :param max_threads: Maximum number of threads to spawn. Defaults to 32.
         :param pool_factor: Multiplier to scale the number of process/thread workers per runnable. Defaults to 1.
         """
@@ -2677,6 +3375,7 @@ class RootFlowStep(FlowStep):
                 model_params, init_with_params=True
             )
             model._raise_exception = False
+            model._execution_mechanism = self._shared_models_mechanism[model.name]
             self.context.executor.add_runnable(
                 model, self._shared_models_mechanism[model.name]
             )
@@ -2796,12 +3495,10 @@ def _add_graphviz_router(graph, step, source=None, **kwargs):
         graph.edge(step.fullname, route.fullname)
 
 
-def _add_graphviz_model_runner(graph, step, source=None):
+def _add_graphviz_model_runner(graph, step, source=None, is_monitored=False):
     if source:
         graph.node("_start", source.name, shape=source.shape, style="filled")
         graph.edge("_start", step.fullname)
-
-    is_monitored = step._extract_root_step().track_models
     m_cell = '<FONT POINT-SIZE="9">🄼</FONT>' if is_monitored else ""
 
     number_of_models = len(
@@ -2840,6 +3537,7 @@ def _add_graphviz_flow(
         allow_empty=True
     )
     graph.node("_start", source.name, shape=source.shape, style="filled")
+    is_monitored = step.track_models if isinstance(step, RootFlowStep) else False
    for start_step in start_steps:
        graph.edge("_start", start_step.fullname)
    for child in step.get_children():
@@ -2848,7 +3546,7 @@ def _add_graphviz_flow(
            with graph.subgraph(name="cluster_" + child.fullname) as sg:
                _add_graphviz_router(sg, child)
        elif kind == StepKinds.model_runner:
-            _add_graphviz_model_runner(graph, child)
+            _add_graphviz_model_runner(graph, child, is_monitored=is_monitored)
        else:
            graph.node(child.fullname, label=child.name, shape=child.get_shape())
        _add_edges(child.after or [], step, graph, child)
@@ -3034,7 +3732,7 @@ def params_to_step(
    return name, step
 
 
-def _init_async_objects(context, steps):
+def _init_async_objects(context, steps, root):
    try:
        import storey
    except ImportError:
@@ -3049,6 +3747,7 @@ def _init_async_objects(context, steps):
 
    for step in steps:
        if hasattr(step, "async_object") and step._is_local_function(context):
+            max_iterations = step._max_iterations or root.max_iterations
            if step.kind == StepKinds.queue:
                skip_stream = context.is_mock and step.next
                if step.path and not skip_stream:
@@ -3067,23 +3766,25 @@ def _init_async_objects(context, steps):
                    datastore_profile = datastore_profile_read(stream_path)
                    if isinstance(
                        datastore_profile,
-
+                        DatastoreProfileKafkaTarget | DatastoreProfileKafkaStream,
                    ):
                        step._async_object = KafkaStoreyTarget(
                            path=stream_path,
                            context=context,
+                            max_iterations=max_iterations,
                            **options,
                        )
                    elif isinstance(datastore_profile, DatastoreProfileV3io):
                        step._async_object = StreamStoreyTarget(
                            stream_path=stream_path,
                            context=context,
+                            max_iterations=max_iterations,
                            **options,
                        )
                    else:
                        raise mlrun.errors.MLRunValueError(
                            f"Received an unexpected stream profile type: {type(datastore_profile)}\n"
-                            "Expects `DatastoreProfileV3io` or `
+                            "Expects `DatastoreProfileV3io` or `DatastoreProfileKafkaStream`."
                        )
                elif stream_path.startswith("kafka://") or kafka_brokers:
                    topic, brokers = parse_kafka_url(stream_path, kafka_brokers)
@@ -3097,6 +3798,13 @@ def _init_async_objects(context, steps):
                        brokers=brokers,
                        producer_options=kafka_producer_options,
                        context=context,
+                        max_iterations=max_iterations,
+                        **options,
+                    )
+                elif stream_path.startswith("dummy://"):
+                    step._async_object = _DummyStream(
+                        context=context,
+                        max_iterations=max_iterations,
                        **options,
                    )
                else:
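Besides threading max_iterations into every stream target, this hunk adds a dummy:// scheme that wires a queue step to an in-memory _DummyStream, which looks intended for tests and local runs. A hedged sketch of pointing a queue at it (the ">>" queue shorthand comes from mlrun's graph API; the step and topic names are illustrative):

    graph.to("ProcessStep").to(">>", "q1", path="dummy://debug-topic")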
@@ -3107,10 +3815,14 @@ def _init_async_objects(context, steps):
                         storey.V3ioDriver(endpoint or config.v3io_api),
                         stream_path,
                         context=context,
+                        max_iterations=max_iterations,
                         **options,
                     )
             else:
-                step._async_object = storey.Map(
+                step._async_object = storey.Map(
+                    lambda x: x,
+                    max_iterations=max_iterations,
+                )
 
         elif not step.async_object or not hasattr(step.async_object, "_outlets"):
             # if regular class, wrap with storey Map
@@ -3122,6 +3834,8 @@ def _init_async_objects(context, steps):
                 name=step.name,
                 context=context,
                 pass_context=step._inject_context,
+                fn_select_outlets=step._outlets_selector,
+                max_iterations=max_iterations,
             )
             if (
                 respond_supported