mlrun 1.10.0rc13__py3-none-any.whl → 1.10.0rc15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mlrun might be problematic.
- mlrun/artifacts/base.py +0 -31
- mlrun/artifacts/llm_prompt.py +106 -20
- mlrun/artifacts/manager.py +0 -5
- mlrun/common/constants.py +0 -1
- mlrun/common/schemas/__init__.py +1 -0
- mlrun/common/schemas/model_monitoring/__init__.py +1 -0
- mlrun/common/schemas/model_monitoring/functions.py +1 -1
- mlrun/common/schemas/model_monitoring/model_endpoints.py +10 -0
- mlrun/common/schemas/workflow.py +0 -1
- mlrun/config.py +1 -1
- mlrun/datastore/model_provider/model_provider.py +42 -14
- mlrun/datastore/model_provider/openai_provider.py +96 -15
- mlrun/db/base.py +14 -0
- mlrun/db/httpdb.py +42 -9
- mlrun/db/nopdb.py +8 -0
- mlrun/execution.py +16 -7
- mlrun/model.py +15 -0
- mlrun/model_monitoring/__init__.py +1 -0
- mlrun/model_monitoring/applications/base.py +176 -20
- mlrun/model_monitoring/db/_schedules.py +84 -24
- mlrun/model_monitoring/db/tsdb/base.py +72 -1
- mlrun/model_monitoring/db/tsdb/tdengine/schemas.py +7 -1
- mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +37 -0
- mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +25 -0
- mlrun/model_monitoring/helpers.py +26 -4
- mlrun/projects/project.py +38 -12
- mlrun/runtimes/daskjob.py +6 -0
- mlrun/runtimes/mpijob/abstract.py +6 -0
- mlrun/runtimes/mpijob/v1.py +6 -0
- mlrun/runtimes/nuclio/application/application.py +2 -0
- mlrun/runtimes/nuclio/function.py +6 -0
- mlrun/runtimes/nuclio/serving.py +12 -11
- mlrun/runtimes/pod.py +21 -0
- mlrun/runtimes/remotesparkjob.py +6 -0
- mlrun/runtimes/sparkjob/spark3job.py +6 -0
- mlrun/serving/__init__.py +2 -0
- mlrun/serving/server.py +95 -26
- mlrun/serving/states.py +130 -10
- mlrun/utils/helpers.py +36 -12
- mlrun/utils/retryer.py +15 -2
- mlrun/utils/version/version.json +2 -2
- {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc15.dist-info}/METADATA +3 -8
- {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc15.dist-info}/RECORD +47 -47
- {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc15.dist-info}/WHEEL +0 -0
- {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc15.dist-info}/entry_points.txt +0 -0
- {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc15.dist-info}/licenses/LICENSE +0 -0
- {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc15.dist-info}/top_level.txt +0 -0
mlrun/runtimes/nuclio/serving.py
CHANGED
@@ -720,6 +720,7 @@ class ServingRuntime(RemoteRuntime):
             "track_models": self.spec.track_models,
             "default_content_type": self.spec.default_content_type,
             "model_endpoint_creation_task_name": self.spec.model_endpoint_creation_task_name,
+            # TODO: find another way to pass this (needed for local run)
             "filename": getattr(self.spec, "filename", None),
         }

@@ -788,17 +789,13 @@ class ServingRuntime(RemoteRuntime):
             monitoring_mock=self.spec.track_models,
         )

-
-
-
-
-        server.
-
-
-            self.spec.track_models,
-            server.context,
-            self.spec,
-        )
+        server.graph = add_system_steps_to_graph(
+            server.project,
+            server.graph,
+            self.spec.track_models,
+            server.context,
+            self.spec,
+        )

         if workdir:
             os.chdir(old_workdir)
@@ -858,6 +855,7 @@ class ServingRuntime(RemoteRuntime):
             description=self.spec.description,
             workdir=self.spec.workdir,
             image_pull_secret=self.spec.image_pull_secret,
+            build=self.spec.build,
             node_name=self.spec.node_name,
             node_selector=self.spec.node_selector,
             affinity=self.spec.affinity,
@@ -868,6 +866,9 @@ class ServingRuntime(RemoteRuntime):
             security_context=self.spec.security_context,
             state_thresholds=self.spec.state_thresholds,
             serving_spec=self._get_serving_spec(),
+            track_models=self.spec.track_models,
+            parameters=self.spec.parameters,
+            graph=self.spec.graph,
         )
         job = KubejobRuntime(
             spec=spec,
mlrun/runtimes/pod.py
CHANGED
@@ -104,6 +104,9 @@ class KubeResourceSpec(FunctionSpec):
         "security_context",
         "state_thresholds",
         "serving_spec",
+        "track_models",
+        "parameters",
+        "graph",
     ]
     _default_fields_to_strip = FunctionSpec._default_fields_to_strip + [
         "volumes",
@@ -180,6 +183,9 @@ class KubeResourceSpec(FunctionSpec):
         security_context=None,
         state_thresholds=None,
         serving_spec=None,
+        track_models=None,
+        parameters=None,
+        graph=None,
     ):
         super().__init__(
             command=command,
@@ -226,6 +232,10 @@ class KubeResourceSpec(FunctionSpec):
             or mlrun.mlconf.function.spec.state_thresholds.default.to_dict()
         )
         self.serving_spec = serving_spec
+        self.track_models = track_models
+        self.parameters = parameters
+        self._graph = None
+        self.graph = graph
         # Termination grace period is internal for runtimes that have a pod termination hook hence it is not in the
         # _dict_fields and doesn't have a setter.
         self._termination_grace_period_seconds = None
@@ -303,6 +313,17 @@ class KubeResourceSpec(FunctionSpec):
     def termination_grace_period_seconds(self) -> typing.Optional[int]:
         return self._termination_grace_period_seconds

+    @property
+    def graph(self):
+        """states graph, holding the serving workflow/DAG topology"""
+        return self._graph
+
+    @graph.setter
+    def graph(self, graph):
+        from ..serving.states import graph_root_setter
+
+        graph_root_setter(self, graph)
+
     def _serialize_field(
         self, struct: dict, field_name: typing.Optional[str] = None, strip: bool = False
     ) -> typing.Any:
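
A minimal sketch of the new KubeResourceSpec.graph property shown above. The dict passed in is a hypothetical serving-graph spec; setting .graph delegates to graph_root_setter from mlrun.serving.states, assuming it accepts the same dict/step forms it handles for the serving spec:

from mlrun.runtimes.pod import KubeResourceSpec

spec = KubeResourceSpec()
spec.graph = {"kind": "flow"}   # routed through graph_root_setter(spec, ...)
print(type(spec.graph))         # the normalized root step object stored in spec._graph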
mlrun/runtimes/remotesparkjob.py
CHANGED
@@ -59,6 +59,9 @@ class RemoteSparkSpec(KubeResourceSpec):
         security_context=None,
         state_thresholds=None,
         serving_spec=None,
+        graph=None,
+        parameters=None,
+        track_models=None,
     ):
         super().__init__(
             command=command,
@@ -89,6 +92,9 @@ class RemoteSparkSpec(KubeResourceSpec):
             security_context=security_context,
             state_thresholds=state_thresholds,
             serving_spec=serving_spec,
+            graph=graph,
+            parameters=parameters,
+            track_models=track_models,
         )
         self.provider = provider

mlrun/runtimes/sparkjob/spark3job.py
CHANGED
@@ -169,6 +169,9 @@ class Spark3JobSpec(KubeResourceSpec):
         security_context=None,
         state_thresholds=None,
         serving_spec=None,
+        graph=None,
+        parameters=None,
+        track_models=None,
     ):
         super().__init__(
             command=command,
@@ -199,6 +202,9 @@ class Spark3JobSpec(KubeResourceSpec):
             security_context=security_context,
             state_thresholds=state_thresholds,
             serving_spec=serving_spec,
+            graph=graph,
+            parameters=parameters,
+            track_models=track_models,
         )

         self.driver_resources = driver_resources or {}
mlrun/serving/__init__.py
CHANGED
@@ -28,6 +28,7 @@ __all__ = [
     "Model",
     "ModelSelector",
     "MonitoredStep",
+    "LLModel",
 ]

 from .routers import ModelRouter, VotingEnsemble  # noqa
@@ -47,6 +48,7 @@ from .states import (
     Model,
     ModelSelector,
     MonitoredStep,
+    LLModel,
 )  # noqa
 from .v1_serving import MLModelServer, new_v1_model_server  # noqa
 from .v2_serving import V2ModelServer  # noqa
mlrun/serving/server.py
CHANGED
@@ -15,6 +15,7 @@
 __all__ = ["GraphServer", "create_graph_server", "GraphContext", "MockEvent"]

 import asyncio
+import base64
 import copy
 import json
 import os
@@ -384,6 +385,7 @@ def add_monitoring_general_steps(
     graph: RootFlowStep,
     context,
     serving_spec,
+    pause_until_background_task_completion: bool,
 ) -> tuple[RootFlowStep, FlowStep]:
     """
     Adding the monitoring flow connection steps, this steps allow the graph to reconstruct the serving event enrich it
@@ -392,18 +394,22 @@ def add_monitoring_general_steps(
     "background_task_status_step" --> "filter_none" --> "monitoring_pre_processor_step" --> "flatten_events"
     --> "sampling_step" --> "filter_none_sampling" --> "model_monitoring_stream"
     """
+    background_task_status_step = None
+    if pause_until_background_task_completion:
+        background_task_status_step = graph.add_step(
+            "mlrun.serving.system_steps.BackgroundTaskStatus",
+            "background_task_status_step",
+            model_endpoint_creation_strategy=mlrun.common.schemas.ModelEndpointCreationStrategy.SKIP,
+        )
     monitor_flow_step = graph.add_step(
-        "mlrun.serving.system_steps.BackgroundTaskStatus",
-        "background_task_status_step",
-        model_endpoint_creation_strategy=mlrun.common.schemas.ModelEndpointCreationStrategy.SKIP,
-    )
-    graph.add_step(
         "storey.Filter",
         "filter_none",
         _fn="(event is not None)",
-        after="background_task_status_step",
+        after="background_task_status_step" if background_task_status_step else None,
         model_endpoint_creation_strategy=mlrun.common.schemas.ModelEndpointCreationStrategy.SKIP,
     )
+    if background_task_status_step:
+        monitor_flow_step = background_task_status_step
     graph.add_step(
         "mlrun.serving.system_steps.MonitoringPreProcessor",
         "monitoring_pre_processor_step",
@@ -466,14 +472,28 @@ def add_monitoring_general_steps(


 def add_system_steps_to_graph(
-    project: str,
+    project: str,
+    graph: RootFlowStep,
+    track_models: bool,
+    context,
+    serving_spec,
+    pause_until_background_task_completion: bool = True,
 ) -> RootFlowStep:
+    if not (isinstance(graph, RootFlowStep) and graph.include_monitored_step()):
+        return graph
     monitored_steps = graph.get_monitored_steps()
     graph = add_error_raiser_step(graph, monitored_steps)
     if track_models:
+        background_task_status_step = None
         graph, monitor_flow_step = add_monitoring_general_steps(
-            project,
+            project,
+            graph,
+            context,
+            serving_spec,
+            pause_until_background_task_completion,
         )
+        if background_task_status_step:
+            monitor_flow_step = background_task_status_step
         # Connect each model runner to the monitoring step:
         for step_name, step in monitored_steps.items():
             if monitor_flow_step.after:
@@ -485,6 +505,10 @@ def add_system_steps_to_graph(
                 monitor_flow_step.after = [
                     step_name,
                 ]
+    context.logger.info_with(
+        "Server graph after adding system steps",
+        graph=str(graph.steps),
+    )
     return graph


@@ -494,18 +518,13 @@ def v2_serving_init(context, namespace=None):
     context.logger.info("Initializing server from spec")
     spec = mlrun.utils.get_serving_spec()
     server = GraphServer.from_dict(spec)
-
-    server.
-
-
-
-
-
-    )
-    context.logger.info_with(
-        "Server graph after adding system steps",
-        graph=str(server.graph.steps),
-    )
+    server.graph = add_system_steps_to_graph(
+        server.project,
+        copy.deepcopy(server.graph),
+        spec.get("track_models"),
+        context,
+        spec,
+    )

     if config.log_level.lower() == "debug":
         server.verbose = True
@@ -544,17 +563,57 @@ async def async_execute_graph(
     data: DataItem,
     batching: bool,
     batch_size: Optional[int],
+    read_as_lists: bool,
+    nest_under_inputs: bool,
 ) -> list[Any]:
     spec = mlrun.utils.get_serving_spec()

-    source_filename = spec.get("filename", None)
     namespace = {}
-
-
-
+    code = os.getenv("MLRUN_EXEC_CODE")
+    if code:
+        code = base64.b64decode(code).decode("utf-8")
+        exec(code, namespace)
+    else:
+        # TODO: find another way to get the local file path, or ensure that MLRUN_EXEC_CODE
+        # gets set in local flow and not just in the remote pod
+        source_filename = spec.get("filename", None)
+        if source_filename:
+            with open(source_filename) as f:
+                exec(f.read(), namespace)

     server = GraphServer.from_dict(spec)

+    if server.model_endpoint_creation_task_name:
+        context.logger.info(
+            f"Waiting for model endpoint creation task '{server.model_endpoint_creation_task_name}'..."
+        )
+        background_task = (
+            mlrun.get_run_db().wait_for_background_task_to_reach_terminal_state(
+                project=server.project,
+                name=server.model_endpoint_creation_task_name,
+            )
+        )
+        task_state = background_task.status.state
+        if task_state == mlrun.common.schemas.BackgroundTaskState.failed:
+            raise mlrun.errors.MLRunRuntimeError(
+                "Aborting job due to model endpoint creation background task failure"
+            )
+        elif task_state != mlrun.common.schemas.BackgroundTaskState.succeeded:
+            # this shouldn't happen, but we need to know if it does
+            raise mlrun.errors.MLRunRuntimeError(
+                "Aborting job because the model endpoint creation background task did not succeed "
+                f"(status='{task_state}')"
+            )
+
+    server.graph = add_system_steps_to_graph(
+        server.project,
+        copy.deepcopy(server.graph),
+        spec.get("track_models"),
+        context,
+        spec,
+        pause_until_background_task_completion=False,  # we've already awaited it
+    )
+
     if config.log_level.lower() == "debug":
         server.verbose = True
     context.logger.info_with("Initializing states", namespace=namespace)
@@ -588,7 +647,9 @@ async def async_execute_graph(

     batch = []
     for index, row in df.iterrows():
-        data = row.to_dict()
+        data = row.to_list() if read_as_lists else row.to_dict()
+        if nest_under_inputs:
+            data = {"inputs": data}
         if batching:
             batch.append(data)
             if len(batch) == batch_size:
@@ -612,6 +673,8 @@ def execute_graph(
     data: DataItem,
     batching: bool = False,
     batch_size: Optional[int] = None,
+    read_as_lists: bool = False,
+    nest_under_inputs: bool = False,
 ) -> (list[Any], Any):
     """
     Execute graph as a job, from start to finish.
@@ -621,10 +684,16 @@ def execute_graph(
     :param batching: Whether to push one or more batches into the graph rather than row by row.
     :param batch_size: The number of rows to push per batch. If not set, and batching=True, the entire dataset will
         be pushed into the graph in one batch.
+    :param read_as_lists: Whether to read each row as a list instead of a dictionary.
+    :param nest_under_inputs: Whether to wrap each row with {"inputs": ...}.

     :return: A list of responses.
     """
-    return asyncio.run(
+    return asyncio.run(
+        async_execute_graph(
+            context, data, batching, batch_size, read_as_lists, nest_under_inputs
+        )
+    )


 def _set_callbacks(server, context):
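
The execute_graph changes above add read_as_lists and nest_under_inputs flags and thread them through async_execute_graph. A minimal sketch of a job handler using them (the handler name and batch size are hypothetical; per the docstring, the call returns the list of graph responses):

from typing import Any

from mlrun.datastore import DataItem
from mlrun.execution import MLClientCtx
from mlrun.serving.server import execute_graph


def handler(context: MLClientCtx, data: DataItem) -> list[Any]:
    # Each dataframe row is read as a list and wrapped as {"inputs": [...]} before
    # being pushed through the graph in batches of 100 rows.
    responses = execute_graph(
        context,
        data,
        batching=True,
        batch_size=100,
        read_as_lists=True,
        nest_under_inputs=True,
    )
    return responses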
mlrun/serving/states.py
CHANGED
@@ -1081,6 +1081,7 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
         "raise_exception",
         "artifact_uri",
         "shared_runnable_name",
+        "shared_proxy_mapping",
     ]
     kind = "model"

@@ -1089,12 +1090,16 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
         name: str,
         raise_exception: bool = True,
         artifact_uri: Optional[str] = None,
+        shared_proxy_mapping: Optional[dict] = None,
         **kwargs,
     ):
         super().__init__(name=name, raise_exception=raise_exception, **kwargs)
         if artifact_uri is not None and not isinstance(artifact_uri, str):
             raise MLRunInvalidArgumentError("'artifact_uri' argument must be a string")
         self.artifact_uri = artifact_uri
+        self.shared_proxy_mapping: dict[
+            str : Union[str, ModelArtifact, LLMPromptArtifact]
+        ] = shared_proxy_mapping
         self.invocation_artifact: Optional[LLMPromptArtifact] = None
         self.model_artifact: Optional[ModelArtifact] = None
         self.model_provider: Optional[ModelProvider] = None
@@ -1125,10 +1130,13 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
         else:
             self.model_artifact = artifact

-    def _get_artifact_object(
-
-
-
+    def _get_artifact_object(
+        self, proxy_uri: Optional[str] = None
+    ) -> Union[ModelArtifact, LLMPromptArtifact, None]:
+        uri = proxy_uri or self.artifact_uri
+        if uri:
+            if mlrun.datastore.is_store_uri(uri):
+                artifact, _ = mlrun.store_manager.get_store_artifact(uri)
                 return artifact
             else:
                 raise ValueError(
@@ -1148,10 +1156,12 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
         """Override to implement prediction logic if the logic requires asyncio."""
         return body

-    def run(self, body: Any, path: str) -> Any:
+    def run(self, body: Any, path: str, origin_name: Optional[str] = None) -> Any:
         return self.predict(body)

-    async def run_async(
+    async def run_async(
+        self, body: Any, path: str, origin_name: Optional[str] = None
+    ) -> Any:
         return await self.predict_async(body)

     def get_local_model_path(self, suffix="") -> (str, dict):
@@ -1186,6 +1196,81 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
         return None, None


+class LLModel(Model):
+    def __init__(self, name: str, **kwargs):
+        super().__init__(name, **kwargs)
+
+    def predict(
+        self, body: Any, messages: list[dict], model_configuration: dict
+    ) -> Any:
+        if isinstance(
+            self.invocation_artifact, mlrun.artifacts.LLMPromptArtifact
+        ) and isinstance(self.model_provider, ModelProvider):
+            body["result"] = self.model_provider.invoke(
+                messages=messages,
+                as_str=True,
+                **(model_configuration or {}),
+            )
+        return body
+
+    async def predict_async(
+        self, body: Any, messages: list[dict], model_configuration: dict
+    ) -> Any:
+        if isinstance(
+            self.invocation_artifact, mlrun.artifacts.LLMPromptArtifact
+        ) and isinstance(self.model_provider, ModelProvider):
+            body["result"] = await self.model_provider.async_invoke(
+                messages=messages,
+                as_str=True,
+                **(model_configuration or {}),
+            )
+        return body
+
+    def run(self, body: Any, path: str, origin_name: Optional[str] = None) -> Any:
+        messages, model_configuration = self.enrich_prompt(body, origin_name)
+        return self.predict(
+            body, messages=messages, model_configuration=model_configuration
+        )
+
+    async def run_async(
+        self, body: Any, path: str, origin_name: Optional[str] = None
+    ) -> Any:
+        messages, model_configuration = self.enrich_prompt(body, origin_name)
+        return await self.predict_async(
+            body, messages=messages, model_configuration=model_configuration
+        )
+
+    def enrich_prompt(
+        self, body: dict, origin_name: str
+    ) -> Union[tuple[list[dict], dict], tuple[None, None]]:
+        if origin_name and self.shared_proxy_mapping:
+            llm_prompt_artifact = self.shared_proxy_mapping.get(origin_name)
+            if isinstance(llm_prompt_artifact, str):
+                llm_prompt_artifact = self._get_artifact_object(llm_prompt_artifact)
+                self.shared_proxy_mapping[origin_name] = llm_prompt_artifact
+        else:
+            llm_prompt_artifact = (
+                self.invocation_artifact or self._get_artifact_object()
+            )
+        if not (
+            llm_prompt_artifact and isinstance(llm_prompt_artifact, LLMPromptArtifact)
+        ):
+            logger.warning(
+                "LLMModel must be provided with LLMPromptArtifact",
+                llm_prompt_artifact=llm_prompt_artifact,
+            )
+            return None, None
+        prompt_legend = llm_prompt_artifact.spec.prompt_legend
+        prompt_template = deepcopy(llm_prompt_artifact.read_prompt())
+        kwargs = {
+            place_holder: body.get(body_map["field"])
+            for place_holder, body_map in prompt_legend.items()
+        }
+        for d in prompt_template:
+            d["content"] = d["content"].format(**kwargs)
+        return prompt_template, llm_prompt_artifact.spec.model_configuration
+
+
 class ModelSelector:
     """Used to select which models to run on each event."""

@@ -1292,6 +1377,7 @@ class ModelRunnerStep(MonitoredStep):
     """

     kind = "model_runner"
+    _dict_fields = MonitoredStep._dict_fields + ["_shared_proxy_mapping"]

     def __init__(
         self,
@@ -1311,6 +1397,7 @@ class ModelRunnerStep(MonitoredStep):
         )
         self.raise_exception = raise_exception
         self.shape = "folder"
+        self._shared_proxy_mapping = {}

     def add_shared_model_proxy(
         self,
@@ -1360,9 +1447,9 @@ class ModelRunnerStep(MonitoredStep):
            in path.
        :param override: bool allow override existing model on the current ModelRunnerStep.
        """
-        model_class =
-
-            shared_runnable_name
+        model_class, model_params = (
+            "mlrun.serving.Model",
+            {"name": endpoint_name, "shared_runnable_name": shared_model_name},
        )
        if isinstance(model_artifact, str):
            model_artifact_uri = model_artifact
@@ -1389,6 +1476,20 @@ class ModelRunnerStep(MonitoredStep):
                 f"ModelRunnerStep can only add proxy models that were added to the root flow step, "
                 f"model {shared_model_name} is not in the shared models."
             )
+        if shared_model_name not in self._shared_proxy_mapping:
+            self._shared_proxy_mapping[shared_model_name] = {
+                endpoint_name: model_artifact.uri
+                if isinstance(model_artifact, (ModelArtifact, LLMPromptArtifact))
+                else model_artifact
+            }
+        else:
+            self._shared_proxy_mapping[shared_model_name].update(
+                {
+                    endpoint_name: model_artifact.uri
+                    if isinstance(model_artifact, (ModelArtifact, LLMPromptArtifact))
+                    else model_artifact
+                }
+            )
         self.add_model(
             endpoint_name=endpoint_name,
             model_class=model_class,
@@ -1401,6 +1502,7 @@ class ModelRunnerStep(MonitoredStep):
             outputs=outputs,
             input_path=input_path,
             result_path=result_path,
+            **model_params,
         )

     def add_model(
@@ -1659,6 +1761,7 @@ class ModelRunnerStep(MonitoredStep):
             model_selector=model_selector,
             runnables=model_objects,
             execution_mechanism_by_runnable_name=execution_mechanism_by_model_name,
+            shared_proxy_mapping=self._shared_proxy_mapping or None,
             name=self.name,
             context=context,
         )
@@ -2494,7 +2597,24 @@ class RootFlowStep(FlowStep):
             max_threads=self.shared_max_threads,
             pool_factor=self.pool_factor,
         )
-
+        monitored_steps = self.get_monitored_steps().values()
+        for monitored_step in monitored_steps:
+            if isinstance(monitored_step, ModelRunnerStep):
+                for model, model_params in self.shared_models.values():
+                    if "shared_proxy_mapping" in model_params:
+                        model_params["shared_proxy_mapping"].update(
+                            deepcopy(
+                                monitored_step._shared_proxy_mapping.get(
+                                    model_params.get("name"), {}
+                                )
+                            )
+                        )
+                    else:
+                        model_params["shared_proxy_mapping"] = deepcopy(
+                            monitored_step._shared_proxy_mapping.get(
+                                model_params.get("name"), {}
+                            )
+                        )
         for model, model_params in self.shared_models.values():
             model = get_class(model, namespace).from_dict(
                 model_params, init_with_params=True
CHANGED
|
@@ -162,14 +162,6 @@ def get_artifact_target(item: dict, project=None):
|
|
|
162
162
|
return item["spec"].get("target_path")
|
|
163
163
|
|
|
164
164
|
|
|
165
|
-
# TODO: Remove once data migration v5 is obsolete
|
|
166
|
-
def is_legacy_artifact(artifact):
|
|
167
|
-
if isinstance(artifact, dict):
|
|
168
|
-
return "metadata" not in artifact
|
|
169
|
-
else:
|
|
170
|
-
return not hasattr(artifact, "metadata")
|
|
171
|
-
|
|
172
|
-
|
|
173
165
|
logger = create_logger(config.log_level, config.log_formatter, "mlrun", sys.stdout)
|
|
174
166
|
missing = object()
|
|
175
167
|
|
|
@@ -1050,7 +1042,14 @@ def fill_function_hash(function_dict, tag=""):
|
|
|
1050
1042
|
|
|
1051
1043
|
|
|
1052
1044
|
def retry_until_successful(
|
|
1053
|
-
backoff: int,
|
|
1045
|
+
backoff: int,
|
|
1046
|
+
timeout: int,
|
|
1047
|
+
logger,
|
|
1048
|
+
verbose: bool,
|
|
1049
|
+
_function,
|
|
1050
|
+
*args,
|
|
1051
|
+
fatal_exceptions=(),
|
|
1052
|
+
**kwargs,
|
|
1054
1053
|
):
|
|
1055
1054
|
"""
|
|
1056
1055
|
Runs function with given *args and **kwargs.
|
|
@@ -1063,14 +1062,31 @@ def retry_until_successful(
|
|
|
1063
1062
|
:param verbose: whether to log the failure on each retry
|
|
1064
1063
|
:param _function: function to run
|
|
1065
1064
|
:param args: functions args
|
|
1065
|
+
:param fatal_exceptions: exception types that should not be retried
|
|
1066
1066
|
:param kwargs: functions kwargs
|
|
1067
1067
|
:return: function result
|
|
1068
1068
|
"""
|
|
1069
|
-
return Retryer(
|
|
1069
|
+
return Retryer(
|
|
1070
|
+
backoff,
|
|
1071
|
+
timeout,
|
|
1072
|
+
logger,
|
|
1073
|
+
verbose,
|
|
1074
|
+
_function,
|
|
1075
|
+
*args,
|
|
1076
|
+
fatal_exceptions=fatal_exceptions,
|
|
1077
|
+
**kwargs,
|
|
1078
|
+
).run()
|
|
1070
1079
|
|
|
1071
1080
|
|
|
1072
1081
|
async def retry_until_successful_async(
|
|
1073
|
-
backoff: int,
|
|
1082
|
+
backoff: int,
|
|
1083
|
+
timeout: int,
|
|
1084
|
+
logger,
|
|
1085
|
+
verbose: bool,
|
|
1086
|
+
_function,
|
|
1087
|
+
*args,
|
|
1088
|
+
fatal_exceptions=(),
|
|
1089
|
+
**kwargs,
|
|
1074
1090
|
):
|
|
1075
1091
|
"""
|
|
1076
1092
|
Runs function with given *args and **kwargs.
|
|
@@ -1082,12 +1098,20 @@ async def retry_until_successful_async(
|
|
|
1082
1098
|
:param logger: a logger so we can log the failures
|
|
1083
1099
|
:param verbose: whether to log the failure on each retry
|
|
1084
1100
|
:param _function: function to run
|
|
1101
|
+
:param fatal_exceptions: exception types that should not be retried
|
|
1085
1102
|
:param args: functions args
|
|
1086
1103
|
:param kwargs: functions kwargs
|
|
1087
1104
|
:return: function result
|
|
1088
1105
|
"""
|
|
1089
1106
|
return await AsyncRetryer(
|
|
1090
|
-
backoff,
|
|
1107
|
+
backoff,
|
|
1108
|
+
timeout,
|
|
1109
|
+
logger,
|
|
1110
|
+
verbose,
|
|
1111
|
+
_function,
|
|
1112
|
+
*args,
|
|
1113
|
+
fatal_exceptions=fatal_exceptions,
|
|
1114
|
+
**kwargs,
|
|
1091
1115
|
).run()
|
|
1092
1116
|
|
|
1093
1117
|
|
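
A minimal sketch of the retry helper with the new fatal_exceptions argument shown above; the fetch() function and its exception choices are hypothetical, the argument order follows the signature in the hunk:

import random

from mlrun.utils import logger
from mlrun.utils.helpers import retry_until_successful


def fetch() -> str:
    if random.random() < 0.5:
        raise ConnectionError("transient network error")  # retried until the timeout
    return "ok"


result = retry_until_successful(
    1,          # backoff: seconds to wait between attempts
    10,         # timeout: overall retry budget in seconds
    logger,     # used to log failures when verbose=True
    True,       # verbose
    fetch,
    fatal_exceptions=(ValueError,),  # a ValueError would abort immediately, no retries
)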