mlrun 1.8.0rc4__py3-none-any.whl → 1.8.0rc6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mlrun might be problematic.

Files changed (69):
  1. mlrun/__init__.py +4 -3
  2. mlrun/alerts/alert.py +129 -2
  3. mlrun/artifacts/__init__.py +1 -1
  4. mlrun/artifacts/base.py +12 -1
  5. mlrun/artifacts/document.py +59 -38
  6. mlrun/common/model_monitoring/__init__.py +0 -2
  7. mlrun/common/model_monitoring/helpers.py +0 -28
  8. mlrun/common/schemas/__init__.py +1 -4
  9. mlrun/common/schemas/alert.py +3 -0
  10. mlrun/common/schemas/artifact.py +4 -0
  11. mlrun/common/schemas/client_spec.py +0 -1
  12. mlrun/common/schemas/model_monitoring/__init__.py +0 -6
  13. mlrun/common/schemas/model_monitoring/constants.py +11 -9
  14. mlrun/common/schemas/model_monitoring/model_endpoints.py +77 -149
  15. mlrun/common/schemas/notification.py +6 -0
  16. mlrun/config.py +0 -2
  17. mlrun/datastore/datastore_profile.py +57 -17
  18. mlrun/datastore/vectorstore.py +67 -59
  19. mlrun/db/base.py +22 -18
  20. mlrun/db/factory.py +0 -3
  21. mlrun/db/httpdb.py +122 -150
  22. mlrun/db/nopdb.py +33 -17
  23. mlrun/execution.py +43 -29
  24. mlrun/model.py +7 -0
  25. mlrun/model_monitoring/__init__.py +3 -2
  26. mlrun/model_monitoring/api.py +40 -43
  27. mlrun/model_monitoring/applications/_application_steps.py +4 -2
  28. mlrun/model_monitoring/applications/base.py +65 -6
  29. mlrun/model_monitoring/applications/context.py +64 -33
  30. mlrun/model_monitoring/applications/evidently_base.py +0 -1
  31. mlrun/model_monitoring/applications/histogram_data_drift.py +2 -6
  32. mlrun/model_monitoring/controller.py +43 -37
  33. mlrun/model_monitoring/db/__init__.py +0 -2
  34. mlrun/model_monitoring/db/tsdb/base.py +2 -1
  35. mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +2 -1
  36. mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +43 -0
  37. mlrun/model_monitoring/helpers.py +12 -66
  38. mlrun/model_monitoring/stream_processing.py +83 -270
  39. mlrun/model_monitoring/writer.py +1 -10
  40. mlrun/projects/project.py +87 -74
  41. mlrun/runtimes/nuclio/function.py +7 -6
  42. mlrun/runtimes/nuclio/serving.py +7 -1
  43. mlrun/serving/routers.py +158 -145
  44. mlrun/serving/server.py +6 -0
  45. mlrun/serving/states.py +2 -0
  46. mlrun/serving/v2_serving.py +69 -60
  47. mlrun/utils/helpers.py +14 -30
  48. mlrun/utils/notifications/notification/mail.py +36 -9
  49. mlrun/utils/notifications/notification_pusher.py +34 -13
  50. mlrun/utils/version/version.json +2 -2
  51. {mlrun-1.8.0rc4.dist-info → mlrun-1.8.0rc6.dist-info}/METADATA +5 -4
  52. {mlrun-1.8.0rc4.dist-info → mlrun-1.8.0rc6.dist-info}/RECORD +56 -69
  53. mlrun/common/schemas/model_monitoring/model_endpoint_v2.py +0 -149
  54. mlrun/model_monitoring/db/stores/__init__.py +0 -136
  55. mlrun/model_monitoring/db/stores/base/__init__.py +0 -15
  56. mlrun/model_monitoring/db/stores/base/store.py +0 -154
  57. mlrun/model_monitoring/db/stores/sqldb/__init__.py +0 -13
  58. mlrun/model_monitoring/db/stores/sqldb/models/__init__.py +0 -46
  59. mlrun/model_monitoring/db/stores/sqldb/models/base.py +0 -93
  60. mlrun/model_monitoring/db/stores/sqldb/models/mysql.py +0 -47
  61. mlrun/model_monitoring/db/stores/sqldb/models/sqlite.py +0 -25
  62. mlrun/model_monitoring/db/stores/sqldb/sql_store.py +0 -408
  63. mlrun/model_monitoring/db/stores/v3io_kv/__init__.py +0 -13
  64. mlrun/model_monitoring/db/stores/v3io_kv/kv_store.py +0 -464
  65. mlrun/model_monitoring/model_endpoint.py +0 -120
  66. {mlrun-1.8.0rc4.dist-info → mlrun-1.8.0rc6.dist-info}/LICENSE +0 -0
  67. {mlrun-1.8.0rc4.dist-info → mlrun-1.8.0rc6.dist-info}/WHEEL +0 -0
  68. {mlrun-1.8.0rc4.dist-info → mlrun-1.8.0rc6.dist-info}/entry_points.txt +0 -0
  69. {mlrun-1.8.0rc4.dist-info → mlrun-1.8.0rc6.dist-info}/top_level.txt +0 -0

mlrun/common/schemas/model_monitoring/model_endpoints.py CHANGED

@@ -11,27 +11,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-import enum
+import abc
 import json
 from datetime import datetime
 from typing import Any, NamedTuple, Optional, TypeVar
 
-from pydantic.v1 import BaseModel, Extra, Field, constr, validator
+from pydantic.v1 import BaseModel, Field, constr
 
 # TODO: remove the unused import below after `mlrun.datastore` and `mlrun.utils` usage is removed.
 # At the moment `make lint` fails if this is removed.
-import mlrun.common.model_monitoring
-
-from ..object import ObjectKind, ObjectSpec, ObjectStatus
+from ..object import ObjectKind, ObjectMetadata, ObjectSpec, ObjectStatus
+from . import ModelEndpointSchema
 from .constants import (
     FQN_REGEX,
     MODEL_ENDPOINT_ID_PATTERN,
     PROJECT_PATTERN,
     EndpointType,
-    EventFieldType,
-    EventKeyMetrics,
-    EventLiveStats,
     ModelEndpointMonitoringMetricType,
     ModelMonitoringMode,
     ResultKindApp,
@@ -47,81 +42,6 @@ class ModelMonitoringStoreKinds:
     EVENTS = "events"
 
 
-class ModelEndpointMetadata(BaseModel):
-    project: constr(regex=PROJECT_PATTERN)
-    uid: constr(regex=MODEL_ENDPOINT_ID_PATTERN)
-    labels: Optional[dict] = {}
-
-    class Config:
-        extra = Extra.allow
-
-    @classmethod
-    def from_flat_dict(
-        cls, endpoint_dict: dict, json_parse_values: Optional[list] = None
-    ):
-        """Create a `ModelEndpointMetadata` object from an endpoint dictionary
-
-        :param endpoint_dict:     Model endpoint dictionary.
-        :param json_parse_values: List of dictionary keys with a JSON string value that will be parsed into a
-                                  dictionary using json.loads().
-        """
-        if json_parse_values is None:
-            json_parse_values = [EventFieldType.LABELS]
-
-        return _mapping_attributes(
-            model_class=cls,
-            flattened_dictionary=endpoint_dict,
-            json_parse_values=json_parse_values,
-        )
-
-
-class ModelEndpointSpec(ObjectSpec):
-    function_uri: Optional[str] = ""  # <project_name>/<function_name>:<tag>
-    model: Optional[str] = ""  # <model_name>:<version>
-    model_class: Optional[str] = ""
-    model_uri: Optional[str] = ""
-    feature_names: Optional[list[str]] = []
-    label_names: Optional[list[str]] = []
-    stream_path: Optional[str] = ""
-    algorithm: Optional[str] = ""
-    monitor_configuration: Optional[dict] = {}
-    active: Optional[bool] = True
-    monitoring_mode: Optional[ModelMonitoringMode] = ModelMonitoringMode.disabled.value
-
-    @classmethod
-    def from_flat_dict(
-        cls, endpoint_dict: dict, json_parse_values: Optional[list] = None
-    ):
-        """Create a `ModelEndpointSpec` object from an endpoint dictionary
-
-        :param endpoint_dict:     Model endpoint dictionary.
-        :param json_parse_values: List of dictionary keys with a JSON string value that will be parsed into a
-                                  dictionary using json.loads().
-        """
-        if json_parse_values is None:
-            json_parse_values = [
-                EventFieldType.FEATURE_NAMES,
-                EventFieldType.LABEL_NAMES,
-                EventFieldType.MONITOR_CONFIGURATION,
-            ]
-        return _mapping_attributes(
-            model_class=cls,
-            flattened_dictionary=endpoint_dict,
-            json_parse_values=json_parse_values,
-        )
-
-    @validator("model_uri")
-    @classmethod
-    def validate_model_uri(cls, model_uri):
-        """Validate that the model uri includes the required prefix"""
-        prefix, uri = mlrun.datastore.parse_store_uri(model_uri)
-        if prefix and prefix != mlrun.utils.helpers.StorePrefix.Model:
-            return mlrun.datastore.get_store_uri(
-                mlrun.utils.helpers.StorePrefix.Model, uri
-            )
-        return model_uri
-
-
 class Histogram(BaseModel):
     buckets: list[float]
     counts: list[int]
@@ -167,50 +87,24 @@ class Features(BaseModel):
     )
 
 
-class ModelEndpointStatus(ObjectStatus):
-    feature_stats: Optional[dict] = {}
-    current_stats: Optional[dict] = {}
-    first_request: Optional[str] = ""
-    last_request: Optional[str] = ""
-    error_count: Optional[int] = 0
-    drift_status: Optional[str] = ""
-    drift_measures: Optional[dict] = {}
-    metrics: Optional[dict[str, dict[str, Any]]] = {
-        EventKeyMetrics.GENERIC: {
-            EventLiveStats.LATENCY_AVG_1H: 0,
-            EventLiveStats.PREDICTIONS_PER_SECOND: 0,
-        }
-    }
-    features: Optional[list[Features]] = []
-    children: Optional[list[str]] = []
-    children_uids: Optional[list[str]] = []
-    endpoint_type: Optional[EndpointType] = EndpointType.NODE_EP
-    monitoring_feature_set_uri: Optional[str] = ""
-    state: Optional[str] = ""
-
-    class Config:
-        extra = Extra.allow
+class ModelEndpointParser(abc.ABC, BaseModel):
+    @classmethod
+    def json_parse_values(cls) -> list[str]:
+        return []
 
     @classmethod
     def from_flat_dict(
         cls, endpoint_dict: dict, json_parse_values: Optional[list] = None
-    ):
-        """Create a `ModelEndpointStatus` object from an endpoint dictionary
+    ) -> "ModelEndpointParser":
+        """Create a `ModelEndpointParser` object from an endpoint dictionary
 
         :param endpoint_dict:     Model endpoint dictionary.
         :param json_parse_values: List of dictionary keys with a JSON string value that will be parsed into a
                                   dictionary using json.loads().
         """
        if json_parse_values is None:
-            json_parse_values = [
-                EventFieldType.FEATURE_STATS,
-                EventFieldType.CURRENT_STATS,
-                EventFieldType.DRIFT_MEASURES,
-                EventFieldType.METRICS,
-                EventFieldType.CHILDREN,
-                EventFieldType.CHILDREN_UIDS,
-                EventFieldType.ENDPOINT_TYPE,
-            ]
+            json_parse_values = cls.json_parse_values()
+
         return _mapping_attributes(
             model_class=cls,
            flattened_dictionary=endpoint_dict,
@@ -218,16 +112,53 @@ class ModelEndpointStatus(ObjectStatus):
         )
 
 
+class ModelEndpointMetadata(ObjectMetadata, ModelEndpointParser):
+    project: constr(regex=PROJECT_PATTERN)
+    endpoint_type: EndpointType = EndpointType.NODE_EP
+    uid: Optional[constr(regex=MODEL_ENDPOINT_ID_PATTERN)]
+
+
+class ModelEndpointSpec(ObjectSpec, ModelEndpointParser):
+    model_uid: Optional[str] = ""
+    model_name: Optional[str] = ""
+    model_tag: Optional[str] = ""
+    model_class: Optional[str] = ""
+    function_name: Optional[str] = ""
+    function_tag: Optional[str] = ""
+    function_uid: Optional[str] = ""
+    feature_names: Optional[list[str]] = []
+    label_names: Optional[list[str]] = []
+    feature_stats: Optional[dict] = {}
+    function_uri: Optional[str] = ""  # <project_name>/<function_hash>
+    model_uri: Optional[str] = ""
+    children: Optional[list[str]] = []
+    children_uids: Optional[list[str]] = []
+    monitoring_feature_set_uri: Optional[str] = ""
+
+
+class ModelEndpointStatus(ObjectStatus, ModelEndpointParser):
+    state: Optional[str] = "unknown"  # will be updated according to the function state
+    first_request: Optional[datetime] = None
+    monitoring_mode: Optional[ModelMonitoringMode] = ModelMonitoringMode.disabled
+
+    # operative
+    last_request: Optional[datetime] = None
+    result_status: Optional[int] = -1
+    avg_latency: Optional[float] = None
+    error_count: Optional[int] = 0
+    current_stats: Optional[dict] = {}
+    current_stats_timestamp: Optional[datetime] = None
+    drift_measures: Optional[dict] = {}
+    drift_measures_timestamp: Optional[datetime] = None
+
+
 class ModelEndpoint(BaseModel):
     kind: ObjectKind = Field(ObjectKind.model_endpoint, const=True)
     metadata: ModelEndpointMetadata
-    spec: ModelEndpointSpec = ModelEndpointSpec()
-    status: ModelEndpointStatus = ModelEndpointStatus()
+    spec: ModelEndpointSpec
+    status: ModelEndpointStatus
 
-    class Config:
-        extra = Extra.allow
-
-    def flat_dict(self):
+    def flat_dict(self) -> dict[str, Any]:
         """Generate a flattened `ModelEndpoint` dictionary. The flattened dictionary result is important for storing
         the model endpoint object in the database.
 
@@ -235,35 +166,24 @@ class ModelEndpoint(BaseModel):
         """
         # Convert the ModelEndpoint object into a dictionary using BaseModel dict() function
         # In addition, remove the BaseModel kind as it is not required by the DB schema
-        model_endpoint_dictionary = self.dict(exclude={"kind"})
 
+        model_endpoint_dictionary = self.dict(exclude={"kind"})
+        exclude = {
+            "tag",
+            ModelEndpointSchema.FEATURE_STATS,
+            ModelEndpointSchema.CURRENT_STATS,
+            ModelEndpointSchema.DRIFT_MEASURES,
+            ModelEndpointSchema.FUNCTION_URI,
+            ModelEndpointSchema.MODEL_URI,
+        }
         # Initialize a flattened dictionary that will be filled with the model endpoint dictionary attributes
         flatten_dict = {}
         for k_object in model_endpoint_dictionary:
             for key in model_endpoint_dictionary[k_object]:
-                # Extract the value of the current field
-                current_value = model_endpoint_dictionary[k_object][key]
-
-                # If the value is not from type str or bool (e.g. dict), convert it into a JSON string
-                # for matching the database required format
-                if not isinstance(current_value, (str, bool, int)) or isinstance(
-                    current_value, enum.IntEnum
-                ):
-                    flatten_dict[key] = json.dumps(current_value)
-                else:
-                    flatten_dict[key] = current_value
-
-        if EventFieldType.METRICS not in flatten_dict:
-            # Initialize metrics dictionary
-            flatten_dict[EventFieldType.METRICS] = {
-                EventKeyMetrics.GENERIC: {
-                    EventLiveStats.LATENCY_AVG_1H: 0,
-                    EventLiveStats.PREDICTIONS_PER_SECOND: 0,
-                }
-            }
-
-        # Remove the features from the dictionary as this field will be filled only within the feature analysis process
-        flatten_dict.pop(EventFieldType.FEATURES, None)
+                if key not in exclude:
+                    # Extract the value of the current field
+                    flatten_dict[key] = model_endpoint_dictionary[k_object][key]
+
         return flatten_dict
 
     @classmethod
@@ -280,9 +200,17 @@ class ModelEndpoint(BaseModel):
             status=ModelEndpointStatus.from_flat_dict(endpoint_dict=endpoint_dict),
         )
 
+    def get(self, field, default=None):
+        return (
+            getattr(self.metadata, field, None)
+            or getattr(self.spec, field, None)
+            or getattr(self.status, field, None)
+            or default
+        )
+
 
 class ModelEndpointList(BaseModel):
-    endpoints: list[ModelEndpoint] = []
+    endpoints: list[ModelEndpoint]
 
 
 class ModelEndpointMonitoringMetric(BaseModel):
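
Note: rc6 replaces the three hand-rolled from_flat_dict implementations with a single ModelEndpointParser mixin; each section now only declares which flattened DB keys hold JSON strings. Below is a minimal standalone sketch of that pattern (the names Parser and Spec are illustrative, and the body stands in for mlrun's _mapping_attributes helper, which is not shown in this diff):

    import abc
    import json
    from typing import Optional

    from pydantic.v1 import BaseModel


    class Parser(abc.ABC, BaseModel):
        @classmethod
        def json_parse_values(cls) -> list[str]:
            return []

        @classmethod
        def from_flat_dict(cls, flat: dict, json_parse_values: Optional[list] = None) -> "Parser":
            # Keep only the keys this model declares, decoding JSON-string columns.
            json_parse_values = json_parse_values or cls.json_parse_values()
            fields = {
                key: json.loads(value) if key in json_parse_values else value
                for key, value in flat.items()
                if key in cls.__fields__
            }
            return cls(**fields)


    class Spec(Parser):
        feature_names: list[str] = []

        @classmethod
        def json_parse_values(cls) -> list[str]:
            return ["feature_names"]  # stored as a JSON string in the DB row


    assert Spec.from_flat_dict({"feature_names": '["f0", "f1"]'}).feature_names == ["f0", "f1"]

The new ModelEndpoint.get(field, default) helper resolves a field name against metadata, spec, and status in that order, so callers no longer need to know which section a flattened key lives in.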

mlrun/common/schemas/notification.py CHANGED

@@ -132,8 +132,14 @@ class SetNotificationRequest(pydantic.v1.BaseModel):
     notifications: list[Notification] = None
 
 
+class NotificationSummary(pydantic.v1.BaseModel):
+    failed: int = 0
+    succeeded: int = 0
+
+
 class NotificationState(pydantic.v1.BaseModel):
     kind: str
     err: Optional[
         str
     ]  # empty error means that the notifications were sent successfully
+    summary: NotificationSummary
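
Note: summary is a new required field on NotificationState, so payloads that previously carried only kind and err now also report per-kind send counts. A hypothetical instance (the values are illustrative, not taken from the diff):

    state = NotificationState(
        kind="slack",
        err="2 of 5 notifications failed",  # example error text
        summary=NotificationSummary(failed=2, succeeded=3),
    )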

mlrun/config.py CHANGED

@@ -607,8 +607,6 @@ default_config = {
        "default_http_sink_app": "http://nuclio-{project}-{application_name}.{namespace}.svc.cluster.local:8080",
        "parquet_batching_max_events": 10_000,
        "parquet_batching_timeout_secs": timedelta(minutes=1).total_seconds(),
-        # See mlrun.model_monitoring.db.stores.ObjectStoreFactory for available options
-        "endpoint_store_connection": "",
        # See mlrun.model_monitoring.db.tsdb.ObjectTSDBFactory for available options
        "tsdb_connection": "",
        # See mlrun.common.schemas.model_monitoring.constants.StreamKind for available options

mlrun/datastore/datastore_profile.py CHANGED

@@ -81,22 +81,62 @@ class DatastoreProfileBasic(DatastoreProfile):
     private: typing.Optional[str] = None
 
 
-class VectorStoreProfile(DatastoreProfile):
-    type: str = pydantic.Field("vector")
-    _private_attributes = ("kwargs_private",)
-    vector_store_class: str
-    kwargs_public: typing.Optional[dict] = None
-    kwargs_private: typing.Optional[dict] = None
-
-    def attributes(self, kwargs=None):
-        attributes = {}
-        if self.kwargs_public:
-            attributes = merge(attributes, self.kwargs_public)
-        if self.kwargs_private:
-            attributes = merge(attributes, self.kwargs_private)
-        if kwargs:
-            attributes = merge(attributes, kwargs)
-        return attributes
+class ConfigProfile(DatastoreProfile):
+    """
+    A profile class for managing configuration data with nested public and private attributes.
+    This class extends DatastoreProfile to handle configuration settings, separating them into
+    public and private dictionaries. Both dictionaries support nested structures, and the class
+    provides functionality to merge these attributes when needed.
+
+    Args:
+        public (Optional[dict]): Dictionary containing public configuration settings,
+            supporting nested structures
+        private (Optional[dict]): Dictionary containing private/sensitive configuration settings,
+            supporting nested structures
+
+    Example:
+        >>> public = {
+                "database": {
+                    "host": "localhost",
+                    "port": 5432
+                },
+                "api_version": "v1"
+            }
+        >>> private = {
+                "database": {
+                    "password": "secret123",
+                    "username": "admin"
+                },
+                "api_key": "xyz789"
+            }
+        >>> config = ConfigProfile("myconfig", public=public, private=private)
+
+        # When attributes() is called, it merges public and private:
+        # {
+        #     "database": {
+        #         "host": "localhost",
+        #         "port": 5432,
+        #         "password": "secret123",
+        #         "username": "admin"
+        #     },
+        #     "api_version": "v1",
+        #     "api_key": "xyz789"
+        # }
+    """
+
+    type = "config"
+    _private_attributes = "private"
+    public: typing.Optional[dict] = None
+    private: typing.Optional[dict] = None
+
+    def attributes(self):
+        res = {}
+        if self.public:
+            res = merge(res, self.public)
+        if self.private:
+            res = merge(res, self.private)
+        return res
 
 
 class DatastoreProfileKafkaTarget(DatastoreProfile):
@@ -494,7 +534,7 @@ class DatastoreProfile2Json(pydantic.v1.BaseModel):
             "gcs": DatastoreProfileGCS,
             "az": DatastoreProfileAzureBlob,
             "hdfs": DatastoreProfileHdfs,
-            "vector": VectorStoreProfile,
+            "config": ConfigProfile,
         }
         if datastore_type in ds_profile_factory:
            return ds_profile_factory[datastore_type].parse_obj(decoded_dict)
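
Note: ConfigProfile replaces VectorStoreProfile in the profile factory, and attributes() deep-merges the private dict into the public one (nested keys are combined rather than overwritten, per the docstring above). A short usage sketch, assuming the standard datastore-profile registration flow also applies to the new type:

    from mlrun.datastore.datastore_profile import ConfigProfile

    profile = ConfigProfile(
        name="myconfig",
        public={"database": {"host": "localhost", "port": 5432}},
        private={"database": {"password": "secret123"}},
    )
    # Nested dicts are merged, so both halves of "database" survive:
    assert profile.attributes()["database"] == {
        "host": "localhost",
        "port": 5432,
        "password": "secret123",
    }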

mlrun/datastore/vectorstore.py CHANGED

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import inspect
-from importlib import import_module
+from collections.abc import Iterable
 from typing import Union
 
 from mlrun.artifacts import DocumentArtifact
@@ -21,57 +21,27 @@ from mlrun.artifacts import DocumentArtifact
 
 class VectorStoreCollection:
     """
-    VectorStoreCollection is a class that manages a collection of vector stores, providing methods to add and delete
-    documents and artifacts, and to interact with an MLRun context.
-
-    Attributes:
-        _collection_impl (object): The underlying collection implementation.
-        _mlrun_context (Union[MlrunProject, MLClientCtx]): The MLRun context associated with the collection.
-        collection_name (str): The name of the collection.
-        id (str): The unique identifier of the collection, composed of the datastore profile and collection name.
-
-    Methods:
-        add_documents(documents: list["Document"], **kwargs):
-            Adds a list of documents to the collection and updates the MLRun artifacts associated with the documents
-            if an MLRun context is present.
-
-        add_artifacts(artifacts: list[DocumentArtifact], splitter=None, **kwargs):
-            Adds a list of DocumentArtifact objects to the collection, optionally using a splitter to convert
-            artifacts to documents.
-
-        remove_itself_from_artifact(artifact: DocumentArtifact):
-            Removes the current object from the given artifact's collection and updates the artifact.
-
-        delete_artifacts(artifacts: list[DocumentArtifact]):
-            Deletes a list of DocumentArtifact objects from the collection and updates the MLRun context.
-            Raises NotImplementedError if the delete operation is not supported for the collection implementation.
+    A wrapper class for vector store collections with MLRun integration.
+
+    This class wraps a vector store implementation (like Milvus, Chroma) and provides
+    integration with MLRun context for document and artifact management. It delegates
+    most operations to the underlying vector store while handling MLRun-specific
+    functionality.
+
+    The class implements attribute delegation through __getattr__ and __setattr__,
+    allowing direct access to the underlying vector store's methods and attributes
+    while maintaining MLRun integration.
     """
 
     def __init__(
         self,
-        vector_store_class: str,
         mlrun_context: Union["MlrunProject", "MLClientCtx"],  # noqa: F821
-        datastore_profile: str,
         collection_name: str,
-        **kwargs,
+        vector_store: "VectorStore",  # noqa: F821
     ):
-        # Import the vector store class dynamically
-        module_name, class_name = vector_store_class.rsplit(".", 1)
-        module = import_module(module_name)
-        vector_store_class = getattr(module, class_name)
-
-        signature = inspect.signature(vector_store_class)
-
-        # Create the vector store instance
-        if "collection_name" in signature.parameters.keys():
-            vector_store = vector_store_class(collection_name=collection_name, **kwargs)
-        else:
-            vector_store = vector_store_class(**kwargs)
-
         self._collection_impl = vector_store
         self._mlrun_context = mlrun_context
         self.collection_name = collection_name
-        self.id = datastore_profile + "/" + collection_name
 
     def __getattr__(self, name):
        # This method is called when an attribute is not found in the usual places
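
Note: construction is inverted in rc6. The caller now instantiates the vector store itself and passes the live object, rather than handing VectorStoreCollection a dotted class path to import and build. A hedged sketch of the new flow (the Chroma/OpenAIEmbeddings classes and the project handle are illustrative, not mandated by this diff):

    from langchain_community.vectorstores import Chroma
    from langchain_openai import OpenAIEmbeddings

    store = Chroma(collection_name="docs", embedding_function=OpenAIEmbeddings())
    collection = VectorStoreCollection(
        mlrun_context=project,  # an MlrunProject or MLClientCtx
        collection_name="docs",
        vector_store=store,
    )

Dropping the datastore_profile argument also removes the composite self.id; artifacts are now tagged with the plain collection name, as the collection_add/collection_remove changes below show.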
@@ -112,40 +82,74 @@ class VectorStoreCollection:
             )
             if mlrun_uri:
                 artifact = self._mlrun_context.get_store_resource(mlrun_uri)
-                artifact.collection_add(self.id)
+                artifact.collection_add(self.collection_name)
                 self._mlrun_context.update_artifact(artifact)
+
         return self._collection_impl.add_documents(documents, **kwargs)
 
     def add_artifacts(self, artifacts: list[DocumentArtifact], splitter=None, **kwargs):
         """
-        Add a list of DocumentArtifact objects to the collection.
+        Add a list of DocumentArtifact objects to the vector store collection.
+
+        Converts artifacts to LangChain documents, adds them to the vector store, and
+        updates the MLRun context. If documents are split, the IDs are handled appropriately.
 
         Args:
-            artifacts (list[DocumentArtifact]): A list of DocumentArtifact objects to be added.
-            splitter (optional): An optional splitter to be used when converting artifacts to documents.
-            **kwargs: Additional keyword arguments to be passed to the collection's add_documents method.
+            artifacts (list[DocumentArtifact]): List of DocumentArtifact objects to add
+            splitter (optional): Document splitter to break artifacts into smaller chunks.
+                If None, each artifact becomes a single document.
+            **kwargs: Additional arguments passed to the underlying add_documents method.
+                Special handling for 'ids' kwarg:
+                - If provided and document is split, IDs are generated as "{original_id}_{i}"
+                  where i starts from 1 (e.g., "doc1_1", "doc1_2", etc.)
+                - If provided and document isn't split, original IDs are used as-is
 
         Returns:
-            list: A list of IDs of the added documents.
+            list: List of IDs for all added documents. When no custom IDs are provided:
+                - Without splitting: Vector store generates IDs automatically
+                - With splitting: Vector store generates separate IDs for each chunk
+                When custom IDs are provided:
+                - Without splitting: Uses provided IDs directly
+                - With splitting: Generates sequential IDs as "{original_id}_{i}" for each chunk
         """
         all_ids = []
-        for artifact in artifacts:
+        user_ids = kwargs.pop("ids", None)
+
+        if user_ids:
+            if not isinstance(user_ids, Iterable):
+                raise ValueError("IDs must be an iterable collection")
+            if len(user_ids) != len(artifacts):
+                raise ValueError(
+                    "The number of IDs should match the number of artifacts"
+                )
+        for index, artifact in enumerate(artifacts):
             documents = artifact.to_langchain_documents(splitter)
-            artifact.collection_add(self.id)
-            self._mlrun_context.update_artifact(artifact)
+            artifact.collection_add(self.collection_name)
+            if self._mlrun_context:
+                self._mlrun_context.update_artifact(artifact)
+            if user_ids:
+                num_of_documents = len(documents)
+                if num_of_documents > 1:
+                    ids_to_pass = [
+                        f"{user_ids[index]}_{i}" for i in range(1, num_of_documents + 1)
+                    ]
+                else:
+                    ids_to_pass = [user_ids[index]]
+                kwargs["ids"] = ids_to_pass
             ids = self._collection_impl.add_documents(documents, **kwargs)
             all_ids.extend(ids)
         return all_ids
 
-    def remove_itself_from_artifact(self, artifact: DocumentArtifact):
+    def remove_from_artifact(self, artifact: DocumentArtifact):
         """
         Remove the current object from the given artifact's collection and update the artifact.
 
         Args:
             artifact (DocumentArtifact): The artifact from which the current object should be removed.
         """
-        artifact.collection_remove(self.id)
-        self._mlrun_context.update_artifact(artifact)
+        artifact.collection_remove(self.collection_name)
+        if self._mlrun_context:
+            self._mlrun_context.update_artifact(artifact)
 
     def delete_artifacts(self, artifacts: list[DocumentArtifact]):
         """
@@ -162,13 +166,15 @@ class VectorStoreCollection:
         """
         store_class = self._collection_impl.__class__.__name__.lower()
         for artifact in artifacts:
-            artifact.collection_remove(self.id)
-            self._mlrun_context.update_artifact(artifact)
+            artifact.collection_remove(self.collection_name)
+            if self._mlrun_context:
+                self._mlrun_context.update_artifact(artifact)
+
             if store_class == "milvus":
-                expr = f"{DocumentArtifact.METADATA_SOURCE_KEY} == '{artifact.source}'"
+                expr = f"{DocumentArtifact.METADATA_SOURCE_KEY} == '{artifact.get_source()}'"
                 return self._collection_impl.delete(expr=expr)
             elif store_class == "chroma":
-                where = {DocumentArtifact.METADATA_SOURCE_KEY: artifact.source}
+                where = {DocumentArtifact.METADATA_SOURCE_KEY: artifact.get_source()}
                 return self._collection_impl.delete(where=where)
 
             elif (
@@ -177,7 +183,9 @@
                 in inspect.signature(self._collection_impl.delete).parameters
             ):
                 filter = {
-                    "metadata": {DocumentArtifact.METADATA_SOURCE_KEY: artifact.source}
+                    "metadata": {
+                        DocumentArtifact.METADATA_SOURCE_KEY: artifact.get_source()
+                    }
                 }
                 return self._collection_impl.delete(filter=filter)
             else:
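
Note: the reworked add_artifacts validates user-supplied IDs up front and derives chunk IDs when a splitter fans one artifact out into several documents. A hypothetical call showing the documented ID scheme (the splitter settings and artifact objects are illustrative):

    from langchain_text_splitters import RecursiveCharacterTextSplitter

    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    ids = collection.add_artifacts(
        [artifact_a, artifact_b],  # two DocumentArtifact objects
        splitter=splitter,
        ids=["a", "b"],
    )
    # If artifact_a splits into two chunks and artifact_b stays whole:
    # ids == ["a_1", "a_2", "b"]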