mlrun 1.8.0rc5__py3-none-any.whl → 1.8.0rc6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlrun/artifacts/__init__.py +1 -1
- mlrun/artifacts/base.py +12 -1
- mlrun/artifacts/document.py +59 -38
- mlrun/common/model_monitoring/__init__.py +0 -2
- mlrun/common/model_monitoring/helpers.py +0 -28
- mlrun/common/schemas/__init__.py +1 -4
- mlrun/common/schemas/client_spec.py +0 -1
- mlrun/common/schemas/model_monitoring/__init__.py +0 -6
- mlrun/common/schemas/model_monitoring/constants.py +11 -9
- mlrun/common/schemas/model_monitoring/model_endpoints.py +77 -149
- mlrun/common/schemas/notification.py +6 -0
- mlrun/config.py +0 -2
- mlrun/datastore/datastore_profile.py +57 -17
- mlrun/datastore/vectorstore.py +67 -59
- mlrun/db/base.py +22 -18
- mlrun/db/httpdb.py +116 -148
- mlrun/db/nopdb.py +33 -17
- mlrun/execution.py +11 -4
- mlrun/model.py +3 -0
- mlrun/model_monitoring/__init__.py +3 -2
- mlrun/model_monitoring/api.py +40 -43
- mlrun/model_monitoring/applications/_application_steps.py +3 -1
- mlrun/model_monitoring/applications/context.py +15 -17
- mlrun/model_monitoring/controller.py +43 -37
- mlrun/model_monitoring/db/__init__.py +0 -2
- mlrun/model_monitoring/db/tsdb/base.py +2 -1
- mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +2 -1
- mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +43 -0
- mlrun/model_monitoring/helpers.py +12 -66
- mlrun/model_monitoring/stream_processing.py +83 -270
- mlrun/model_monitoring/writer.py +1 -10
- mlrun/projects/project.py +63 -55
- mlrun/runtimes/nuclio/function.py +7 -6
- mlrun/runtimes/nuclio/serving.py +7 -1
- mlrun/serving/routers.py +158 -145
- mlrun/serving/server.py +6 -0
- mlrun/serving/states.py +2 -0
- mlrun/serving/v2_serving.py +69 -60
- mlrun/utils/helpers.py +14 -30
- mlrun/utils/notifications/notification/mail.py +17 -6
- mlrun/utils/version/version.json +2 -2
- {mlrun-1.8.0rc5.dist-info → mlrun-1.8.0rc6.dist-info}/METADATA +1 -1
- {mlrun-1.8.0rc5.dist-info → mlrun-1.8.0rc6.dist-info}/RECORD +47 -60
- mlrun/common/schemas/model_monitoring/model_endpoint_v2.py +0 -149
- mlrun/model_monitoring/db/stores/__init__.py +0 -136
- mlrun/model_monitoring/db/stores/base/__init__.py +0 -15
- mlrun/model_monitoring/db/stores/base/store.py +0 -154
- mlrun/model_monitoring/db/stores/sqldb/__init__.py +0 -13
- mlrun/model_monitoring/db/stores/sqldb/models/__init__.py +0 -46
- mlrun/model_monitoring/db/stores/sqldb/models/base.py +0 -93
- mlrun/model_monitoring/db/stores/sqldb/models/mysql.py +0 -47
- mlrun/model_monitoring/db/stores/sqldb/models/sqlite.py +0 -25
- mlrun/model_monitoring/db/stores/sqldb/sql_store.py +0 -408
- mlrun/model_monitoring/db/stores/v3io_kv/__init__.py +0 -13
- mlrun/model_monitoring/db/stores/v3io_kv/kv_store.py +0 -464
- mlrun/model_monitoring/model_endpoint.py +0 -120
- {mlrun-1.8.0rc5.dist-info → mlrun-1.8.0rc6.dist-info}/LICENSE +0 -0
- {mlrun-1.8.0rc5.dist-info → mlrun-1.8.0rc6.dist-info}/WHEEL +0 -0
- {mlrun-1.8.0rc5.dist-info → mlrun-1.8.0rc6.dist-info}/entry_points.txt +0 -0
- {mlrun-1.8.0rc5.dist-info → mlrun-1.8.0rc6.dist-info}/top_level.txt +0 -0
mlrun/common/schemas/model_monitoring/model_endpoints.py
CHANGED
@@ -11,27 +11,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-import enum
+import abc
 import json
 from datetime import datetime
 from typing import Any, NamedTuple, Optional, TypeVar

-from pydantic.v1 import BaseModel,
+from pydantic.v1 import BaseModel, Field, constr

 # TODO: remove the unused import below after `mlrun.datastore` and `mlrun.utils` usage is removed.
 # At the moment `make lint` fails if this is removed.
-import
-
-from ..object import ObjectKind, ObjectSpec, ObjectStatus
+from ..object import ObjectKind, ObjectMetadata, ObjectSpec, ObjectStatus
+from . import ModelEndpointSchema
 from .constants import (
     FQN_REGEX,
     MODEL_ENDPOINT_ID_PATTERN,
     PROJECT_PATTERN,
     EndpointType,
-    EventFieldType,
-    EventKeyMetrics,
-    EventLiveStats,
     ModelEndpointMonitoringMetricType,
     ModelMonitoringMode,
     ResultKindApp,
@@ -47,81 +42,6 @@ class ModelMonitoringStoreKinds:
     EVENTS = "events"


-class ModelEndpointMetadata(BaseModel):
-    project: constr(regex=PROJECT_PATTERN)
-    uid: constr(regex=MODEL_ENDPOINT_ID_PATTERN)
-    labels: Optional[dict] = {}
-
-    class Config:
-        extra = Extra.allow
-
-    @classmethod
-    def from_flat_dict(
-        cls, endpoint_dict: dict, json_parse_values: Optional[list] = None
-    ):
-        """Create a `ModelEndpointMetadata` object from an endpoint dictionary
-
-        :param endpoint_dict: Model endpoint dictionary.
-        :param json_parse_values: List of dictionary keys with a JSON string value that will be parsed into a
-                                  dictionary using json.loads().
-        """
-        if json_parse_values is None:
-            json_parse_values = [EventFieldType.LABELS]
-
-        return _mapping_attributes(
-            model_class=cls,
-            flattened_dictionary=endpoint_dict,
-            json_parse_values=json_parse_values,
-        )
-
-
-class ModelEndpointSpec(ObjectSpec):
-    function_uri: Optional[str] = ""  # <project_name>/<function_name>:<tag>
-    model: Optional[str] = ""  # <model_name>:<version>
-    model_class: Optional[str] = ""
-    model_uri: Optional[str] = ""
-    feature_names: Optional[list[str]] = []
-    label_names: Optional[list[str]] = []
-    stream_path: Optional[str] = ""
-    algorithm: Optional[str] = ""
-    monitor_configuration: Optional[dict] = {}
-    active: Optional[bool] = True
-    monitoring_mode: Optional[ModelMonitoringMode] = ModelMonitoringMode.disabled.value
-
-    @classmethod
-    def from_flat_dict(
-        cls, endpoint_dict: dict, json_parse_values: Optional[list] = None
-    ):
-        """Create a `ModelEndpointSpec` object from an endpoint dictionary
-
-        :param endpoint_dict: Model endpoint dictionary.
-        :param json_parse_values: List of dictionary keys with a JSON string value that will be parsed into a
-                                  dictionary using json.loads().
-        """
-        if json_parse_values is None:
-            json_parse_values = [
-                EventFieldType.FEATURE_NAMES,
-                EventFieldType.LABEL_NAMES,
-                EventFieldType.MONITOR_CONFIGURATION,
-            ]
-        return _mapping_attributes(
-            model_class=cls,
-            flattened_dictionary=endpoint_dict,
-            json_parse_values=json_parse_values,
-        )
-
-    @validator("model_uri")
-    @classmethod
-    def validate_model_uri(cls, model_uri):
-        """Validate that the model uri includes the required prefix"""
-        prefix, uri = mlrun.datastore.parse_store_uri(model_uri)
-        if prefix and prefix != mlrun.utils.helpers.StorePrefix.Model:
-            return mlrun.datastore.get_store_uri(
-                mlrun.utils.helpers.StorePrefix.Model, uri
-            )
-        return model_uri
-
-
 class Histogram(BaseModel):
     buckets: list[float]
     counts: list[int]
@@ -167,50 +87,24 @@ class Features(BaseModel):
     )


-class ModelEndpointStatus(ObjectStatus):
-
-
-
-    last_request: Optional[str] = ""
-    error_count: Optional[int] = 0
-    drift_status: Optional[str] = ""
-    drift_measures: Optional[dict] = {}
-    metrics: Optional[dict[str, dict[str, Any]]] = {
-        EventKeyMetrics.GENERIC: {
-            EventLiveStats.LATENCY_AVG_1H: 0,
-            EventLiveStats.PREDICTIONS_PER_SECOND: 0,
-        }
-    }
-    features: Optional[list[Features]] = []
-    children: Optional[list[str]] = []
-    children_uids: Optional[list[str]] = []
-    endpoint_type: Optional[EndpointType] = EndpointType.NODE_EP
-    monitoring_feature_set_uri: Optional[str] = ""
-    state: Optional[str] = ""
-
-    class Config:
-        extra = Extra.allow
+class ModelEndpointParser(abc.ABC, BaseModel):
+    @classmethod
+    def json_parse_values(cls) -> list[str]:
+        return []

     @classmethod
     def from_flat_dict(
         cls, endpoint_dict: dict, json_parse_values: Optional[list] = None
-    ):
-        """Create a `ModelEndpointStatus` object from an endpoint dictionary
+    ) -> "ModelEndpointParser":
+        """Create a `ModelEndpointParser` object from an endpoint dictionary

         :param endpoint_dict: Model endpoint dictionary.
         :param json_parse_values: List of dictionary keys with a JSON string value that will be parsed into a
                                   dictionary using json.loads().
         """
         if json_parse_values is None:
-            json_parse_values = [
-                EventFieldType.CURRENT_STATS,
-                EventFieldType.DRIFT_MEASURES,
-                EventFieldType.METRICS,
-                EventFieldType.CHILDREN,
-                EventFieldType.CHILDREN_UIDS,
-                EventFieldType.ENDPOINT_TYPE,
-            ]
+            json_parse_values = cls.json_parse_values()
+
         return _mapping_attributes(
             model_class=cls,
             flattened_dictionary=endpoint_dict,
@@ -218,16 +112,53 @@ class ModelEndpointStatus(ObjectStatus):
         )


+class ModelEndpointMetadata(ObjectMetadata, ModelEndpointParser):
+    project: constr(regex=PROJECT_PATTERN)
+    endpoint_type: EndpointType = EndpointType.NODE_EP
+    uid: Optional[constr(regex=MODEL_ENDPOINT_ID_PATTERN)]
+
+
+class ModelEndpointSpec(ObjectSpec, ModelEndpointParser):
+    model_uid: Optional[str] = ""
+    model_name: Optional[str] = ""
+    model_tag: Optional[str] = ""
+    model_class: Optional[str] = ""
+    function_name: Optional[str] = ""
+    function_tag: Optional[str] = ""
+    function_uid: Optional[str] = ""
+    feature_names: Optional[list[str]] = []
+    label_names: Optional[list[str]] = []
+    feature_stats: Optional[dict] = {}
+    function_uri: Optional[str] = ""  # <project_name>/<function_hash>
+    model_uri: Optional[str] = ""
+    children: Optional[list[str]] = []
+    children_uids: Optional[list[str]] = []
+    monitoring_feature_set_uri: Optional[str] = ""
+
+
+class ModelEndpointStatus(ObjectStatus, ModelEndpointParser):
+    state: Optional[str] = "unknown"  # will be updated according to the function state
+    first_request: Optional[datetime] = None
+    monitoring_mode: Optional[ModelMonitoringMode] = ModelMonitoringMode.disabled
+
+    # operative
+    last_request: Optional[datetime] = None
+    result_status: Optional[int] = -1
+    avg_latency: Optional[float] = None
+    error_count: Optional[int] = 0
+    current_stats: Optional[dict] = {}
+    current_stats_timestamp: Optional[datetime] = None
+    drift_measures: Optional[dict] = {}
+    drift_measures_timestamp: Optional[datetime] = None
+
+
 class ModelEndpoint(BaseModel):
     kind: ObjectKind = Field(ObjectKind.model_endpoint, const=True)
     metadata: ModelEndpointMetadata
-    spec: ModelEndpointSpec
-    status: ModelEndpointStatus
+    spec: ModelEndpointSpec
+    status: ModelEndpointStatus

-    class Config:
-        extra = Extra.allow
-
-    def flat_dict(self):
+    def flat_dict(self) -> dict[str, Any]:
         """Generate a flattened `ModelEndpoint` dictionary. The flattened dictionary result is important for storing
         the model endpoint object in the database.

@@ -235,35 +166,24 @@ class ModelEndpoint(BaseModel):
         """
         # Convert the ModelEndpoint object into a dictionary using BaseModel dict() function
         # In addition, remove the BaseModel kind as it is not required by the DB schema
-        model_endpoint_dictionary = self.dict(exclude={"kind"})

+        model_endpoint_dictionary = self.dict(exclude={"kind"})
+        exclude = {
+            "tag",
+            ModelEndpointSchema.FEATURE_STATS,
+            ModelEndpointSchema.CURRENT_STATS,
+            ModelEndpointSchema.DRIFT_MEASURES,
+            ModelEndpointSchema.FUNCTION_URI,
+            ModelEndpointSchema.MODEL_URI,
+        }
         # Initialize a flattened dictionary that will be filled with the model endpoint dictionary attributes
         flatten_dict = {}
         for k_object in model_endpoint_dictionary:
             for key in model_endpoint_dictionary[k_object]:
-
-
-
-
-                # for matching the database required format
-                if not isinstance(current_value, (str, bool, int)) or isinstance(
-                    current_value, enum.IntEnum
-                ):
-                    flatten_dict[key] = json.dumps(current_value)
-                else:
-                    flatten_dict[key] = current_value
-
-        if EventFieldType.METRICS not in flatten_dict:
-            # Initialize metrics dictionary
-            flatten_dict[EventFieldType.METRICS] = {
-                EventKeyMetrics.GENERIC: {
-                    EventLiveStats.LATENCY_AVG_1H: 0,
-                    EventLiveStats.PREDICTIONS_PER_SECOND: 0,
-                }
-            }
-
-        # Remove the features from the dictionary as this field will be filled only within the feature analysis process
-        flatten_dict.pop(EventFieldType.FEATURES, None)
+                if key not in exclude:
+                    # Extract the value of the current field
+                    flatten_dict[key] = model_endpoint_dictionary[k_object][key]
+
         return flatten_dict

     @classmethod
@@ -280,9 +200,17 @@ class ModelEndpoint(BaseModel):
             status=ModelEndpointStatus.from_flat_dict(endpoint_dict=endpoint_dict),
         )

+    def get(self, field, default=None):
+        return (
+            getattr(self.metadata, field, None)
+            or getattr(self.spec, field, None)
+            or getattr(self.status, field, None)
+            or default
+        )
+

 class ModelEndpointList(BaseModel):
-    endpoints: list[ModelEndpoint]
+    endpoints: list[ModelEndpoint]


 class ModelEndpointMonitoringMetric(BaseModel):
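For orientation, a minimal usage sketch of the reworked schema (illustrative only, not part of the diff; the `endpoint_dict` keyword and the field values are assumptions modeled on the classmethods above):

    # Sketch only: assumes the classes behave as defined in this diff.
    from mlrun.common.schemas.model_monitoring.model_endpoints import ModelEndpoint

    # A flat record as it might come back from the DB (hypothetical values).
    flat = {
        "name": "churn-serving-endpoint",
        "project": "my-project",
        "model_name": "churn-model",
        "state": "ready",
    }
    endpoint = ModelEndpoint.from_flat_dict(endpoint_dict=flat)

    # `get` falls back across metadata, then spec, then status, then the default.
    assert endpoint.get("model_name") == "churn-model"  # found on spec
    assert endpoint.get("state") == "ready"  # found on status
    assert endpoint.get("missing", default="n/a") == "n/a"

    # `flat_dict` flattens the object back, skipping "tag" and the
    # feature_stats / current_stats / drift_measures / *_uri blobs.
    assert endpoint.flat_dict()["model_name"] == "churn-model"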
mlrun/common/schemas/notification.py
CHANGED
@@ -132,8 +132,14 @@ class SetNotificationRequest(pydantic.v1.BaseModel):
     notifications: list[Notification] = None


+class NotificationSummary(pydantic.v1.BaseModel):
+    failed: int = 0
+    succeeded: int = 0
+
+
 class NotificationState(pydantic.v1.BaseModel):
     kind: str
     err: Optional[
         str
     ]  # empty error means that the notifications were sent successfully
+    summary: NotificationSummary
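A small construction sketch of the new summary field (values are illustrative; the import path assumes the classes stay in this module):

    from mlrun.common.schemas.notification import NotificationState, NotificationSummary

    state = NotificationState(
        kind="slack",
        err="",  # empty error means the notifications were sent successfully
        summary=NotificationSummary(succeeded=3, failed=1),
    )
    assert state.summary.succeeded == 3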
mlrun/config.py
CHANGED
@@ -607,8 +607,6 @@ default_config = {
         "default_http_sink_app": "http://nuclio-{project}-{application_name}.{namespace}.svc.cluster.local:8080",
         "parquet_batching_max_events": 10_000,
         "parquet_batching_timeout_secs": timedelta(minutes=1).total_seconds(),
-        # See mlrun.model_monitoring.db.stores.ObjectStoreFactory for available options
-        "endpoint_store_connection": "",
         # See mlrun.model_monitoring.db.tsdb.ObjectTSDBFactory for available options
         "tsdb_connection": "",
         # See mlrun.common.schemas.model_monitoring.constants.StreamKind for available options
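The `endpoint_store_connection` key is removed; the TSDB connection setting remains. A quick read sketch (the `model_endpoint_monitoring` nesting is an assumption inferred from the surrounding keys, which this hunk does not show):

    import mlrun

    # Assumed attribute path for the config block shown above.
    print(mlrun.mlconf.model_endpoint_monitoring.tsdb_connection)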
mlrun/datastore/datastore_profile.py
CHANGED
@@ -81,22 +81,62 @@ class DatastoreProfileBasic(DatastoreProfile):
     private: typing.Optional[str] = None


-class
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+class ConfigProfile(DatastoreProfile):
+    """
+    A profile class for managing configuration data with nested public and private attributes.
+    This class extends DatastoreProfile to handle configuration settings, separating them into
+    public and private dictionaries. Both dictionaries support nested structures, and the class
+    provides functionality to merge these attributes when needed.
+
+    Args:
+        public (Optional[dict]): Dictionary containing public configuration settings,
+            supporting nested structures
+        private (Optional[dict]): Dictionary containing private/sensitive configuration settings,
+            supporting nested structures
+
+    Example:
+        >>> public = {
+                "database": {
+                    "host": "localhost",
+                    "port": 5432
+                },
+                "api_version": "v1"
+            }
+        >>> private = {
+                "database": {
+                    "password": "secret123",
+                    "username": "admin"
+                },
+                "api_key": "xyz789"
+            }
+        >>> config = ConfigProfile("myconfig", public=public, private=private)

+        # When attributes() is called, it merges public and private:
+        # {
+        #     "database": {
+        #         "host": "localhost",
+        #         "port": 5432,
+        #         "password": "secret123",
+        #         "username": "admin"
+        #     },
+        #     "api_version": "v1",
+        #     "api_key": "xyz789"
+        # }
+
+    """
+
+    type = "config"
+    _private_attributes = "private"
+    public: typing.Optional[dict] = None
+    private: typing.Optional[dict] = None
+
+    def attributes(self):
+        res = {}
+        if self.public:
+            res = merge(res, self.public)
+        if self.private:
+            res = merge(res, self.private)
+        return res


 class DatastoreProfileKafkaTarget(DatastoreProfile):
@@ -494,7 +534,7 @@ class DatastoreProfile2Json(pydantic.v1.BaseModel):
         "gcs": DatastoreProfileGCS,
         "az": DatastoreProfileAzureBlob,
         "hdfs": DatastoreProfileHdfs,
-        "
+        "config": ConfigProfile,
     }
     if datastore_type in ds_profile_factory:
         return ds_profile_factory[datastore_type].parse_obj(decoded_dict)
mlrun/datastore/vectorstore.py
CHANGED
@@ -13,7 +13,7 @@
 # limitations under the License.

 import inspect
-from importlib import import_module
+from collections.abc import Iterable
 from typing import Union

 from mlrun.artifacts import DocumentArtifact
@@ -21,57 +21,27 @@ from mlrun.artifacts import DocumentArtifact

 class VectorStoreCollection:
     """
-
-
-
-
-
-
-
-
-
-
-    add_documents(documents: list["Document"], **kwargs):
-        Adds a list of documents to the collection and updates the MLRun artifacts associated with the documents
-        if an MLRun context is present.
-
-    add_artifacts(artifacts: list[DocumentArtifact], splitter=None, **kwargs):
-        Adds a list of DocumentArtifact objects to the collection, optionally using a splitter to convert
-        artifacts to documents.
-
-    remove_itself_from_artifact(artifact: DocumentArtifact):
-        Removes the current object from the given artifact's collection and updates the artifact.
-
-    delete_artifacts(artifacts: list[DocumentArtifact]):
-        Deletes a list of DocumentArtifact objects from the collection and updates the MLRun context.
-        Raises NotImplementedError if the delete operation is not supported for the collection implementation.
+    A wrapper class for vector store collections with MLRun integration.
+
+    This class wraps a vector store implementation (like Milvus, Chroma) and provides
+    integration with MLRun context for document and artifact management. It delegates
+    most operations to the underlying vector store while handling MLRun-specific
+    functionality.
+
+    The class implements attribute delegation through __getattr__ and __setattr__,
+    allowing direct access to the underlying vector store's methods and attributes
+    while maintaining MLRun integration.
     """

     def __init__(
         self,
-        vector_store_class: str,
         mlrun_context: Union["MlrunProject", "MLClientCtx"],  # noqa: F821
-        datastore_profile: str,
         collection_name: str,
-        **kwargs,
+        vector_store: "VectorStore",  # noqa: F821
     ):
-        # Import the vector store class dynamically
-        module_name, class_name = vector_store_class.rsplit(".", 1)
-        module = import_module(module_name)
-        vector_store_class = getattr(module, class_name)
-
-        signature = inspect.signature(vector_store_class)
-
-        # Create the vector store instance
-        if "collection_name" in signature.parameters.keys():
-            vector_store = vector_store_class(collection_name=collection_name, **kwargs)
-        else:
-            vector_store = vector_store_class(**kwargs)
-
         self._collection_impl = vector_store
         self._mlrun_context = mlrun_context
         self.collection_name = collection_name
-        self.id = datastore_profile + "/" + collection_name

     def __getattr__(self, name):
         # This method is called when an attribute is not found in the usual places
@@ -112,40 +82,74 @@ class VectorStoreCollection:
         )
         if mlrun_uri:
             artifact = self._mlrun_context.get_store_resource(mlrun_uri)
-            artifact.collection_add(self.id)
+            artifact.collection_add(self.collection_name)
             self._mlrun_context.update_artifact(artifact)
+
         return self._collection_impl.add_documents(documents, **kwargs)

     def add_artifacts(self, artifacts: list[DocumentArtifact], splitter=None, **kwargs):
         """
-        Add a list of DocumentArtifact objects to the collection.
+        Add a list of DocumentArtifact objects to the vector store collection.
+
+        Converts artifacts to LangChain documents, adds them to the vector store, and
+        updates the MLRun context. If documents are split, the IDs are handled appropriately.

         Args:
-            artifacts (list[DocumentArtifact]):
-            splitter (optional):
-
+            artifacts (list[DocumentArtifact]): List of DocumentArtifact objects to add
+            splitter (optional): Document splitter to break artifacts into smaller chunks.
+                If None, each artifact becomes a single document.
+            **kwargs: Additional arguments passed to the underlying add_documents method.
+                Special handling for 'ids' kwarg:
+                - If provided and document is split, IDs are generated as "{original_id}_{i}"
+                  where i starts from 1 (e.g., "doc1_1", "doc1_2", etc.)
+                - If provided and document isn't split, original IDs are used as-is

         Returns:
-            list:
+            list: List of IDs for all added documents. When no custom IDs are provided:
+                - Without splitting: Vector store generates IDs automatically
+                - With splitting: Vector store generates separate IDs for each chunk
+                When custom IDs are provided:
+                - Without splitting: Uses provided IDs directly
+                - With splitting: Generates sequential IDs as "{original_id}_{i}" for each chunk
         """
         all_ids = []
-        for artifact in artifacts:
+        user_ids = kwargs.pop("ids", None)
+
+        if user_ids:
+            if not isinstance(user_ids, Iterable):
+                raise ValueError("IDs must be an iterable collection")
+            if len(user_ids) != len(artifacts):
+                raise ValueError(
+                    "The number of IDs should match the number of artifacts"
+                )
+        for index, artifact in enumerate(artifacts):
             documents = artifact.to_langchain_documents(splitter)
-            artifact.collection_add(self.id)
-            self._mlrun_context.update_artifact(artifact)
+            artifact.collection_add(self.collection_name)
+            if self._mlrun_context:
+                self._mlrun_context.update_artifact(artifact)
+            if user_ids:
+                num_of_documents = len(documents)
+                if num_of_documents > 1:
+                    ids_to_pass = [
+                        f"{user_ids[index]}_{i}" for i in range(1, num_of_documents + 1)
+                    ]
+                else:
+                    ids_to_pass = [user_ids[index]]
+                kwargs["ids"] = ids_to_pass
             ids = self._collection_impl.add_documents(documents, **kwargs)
             all_ids.extend(ids)
         return all_ids

-    def remove_itself_from_artifact(self, artifact: DocumentArtifact):
+    def remove_from_artifact(self, artifact: DocumentArtifact):
         """
         Remove the current object from the given artifact's collection and update the artifact.

         Args:
             artifact (DocumentArtifact): The artifact from which the current object should be removed.
         """
-        artifact.collection_remove(self.id)
-        self._mlrun_context.update_artifact(artifact)
+        artifact.collection_remove(self.collection_name)
+        if self._mlrun_context:
+            self._mlrun_context.update_artifact(artifact)

     def delete_artifacts(self, artifacts: list[DocumentArtifact]):
         """
@@ -162,13 +166,15 @@ class VectorStoreCollection:
         """
         store_class = self._collection_impl.__class__.__name__.lower()
         for artifact in artifacts:
-            artifact.collection_remove(self.id)
-            self._mlrun_context.update_artifact(artifact)
+            artifact.collection_remove(self.collection_name)
+            if self._mlrun_context:
+                self._mlrun_context.update_artifact(artifact)
+
         if store_class == "milvus":
-            expr = f"{DocumentArtifact.METADATA_SOURCE_KEY} == '{artifact.
+            expr = f"{DocumentArtifact.METADATA_SOURCE_KEY} == '{artifact.get_source()}'"
             return self._collection_impl.delete(expr=expr)
         elif store_class == "chroma":
-            where = {DocumentArtifact.METADATA_SOURCE_KEY: artifact.
+            where = {DocumentArtifact.METADATA_SOURCE_KEY: artifact.get_source()}
             return self._collection_impl.delete(where=where)

         elif (
@@ -177,7 +183,9 @@ class VectorStoreCollection:
             in inspect.signature(self._collection_impl.delete).parameters
         ):
             filter = {
-                "metadata": {
+                "metadata": {
+                    DocumentArtifact.METADATA_SOURCE_KEY: artifact.get_source()
+                }
             }
             return self._collection_impl.delete(filter=filter)
         else: