mlrun 1.8.0rc21__py3-none-any.whl → 1.8.0rc22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mlrun might be problematic.

mlrun/__init__.py CHANGED
@@ -213,7 +213,41 @@ def set_env_from_file(env_file: str, return_dict: bool = False) -> Optional[dict
     env_vars = dotenv.dotenv_values(env_file)
     if None in env_vars.values():
         raise MLRunInvalidArgumentError("env file lines must be in the form key=value")
-    for key, value in env_vars.items():
-        environ[key] = value  # Load to local environ
+
+    ordered_env_vars = order_env_vars(env_vars)
+    for key, value in ordered_env_vars.items():
+        environ[key] = value
+
     mlconf.reload()  # reload mlrun configuration
-    return env_vars if return_dict else None
+    return ordered_env_vars if return_dict else None
+
+
+def order_env_vars(env_vars: dict[str, str]) -> dict[str, str]:
+    """
+    Order and process environment variables by first handling specific ordered keys,
+    then processing the remaining keys in the given dictionary.
+
+    The function ensures that environment variables defined in the `ordered_keys` list
+    are added to the result dictionary first. Any other environment variables from
+    `env_vars` are then added in the order they appear in the input dictionary.
+
+    :param env_vars: A dictionary where each key is the name of an environment variable (str),
+        and each value is the corresponding environment variable value (str).
+    :return: A dictionary with the processed environment variables, ordered with the specific
+        keys first, followed by the rest in their original order.
+    """
+    ordered_keys = mlconf.get_ordered_keys()
+
+    ordered_env_vars: dict[str, str] = {}
+
+    # First, add the ordered keys to the dictionary
+    for key in ordered_keys:
+        if key in env_vars:
+            ordered_env_vars[key] = env_vars[key]
+
+    # Then, add the remaining keys (those not in ordered_keys)
+    for key, value in env_vars.items():
+        if key not in ordered_keys:
+            ordered_env_vars[key] = value
+
+    return ordered_env_vars
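For orientation, a standalone sketch (not part of the diff) of the ordering this introduces, assuming Config.get_ordered_keys() returns ["MLRUN_HTTPDB__HTTP__VERIFY"] as shown in the mlrun/config.py hunk further down; the .env contents are hypothetical:

# Illustrative only: replicates the order_env_vars() ordering without importing mlrun.
ordered_keys = ["MLRUN_HTTPDB__HTTP__VERIFY"]
env_vars = {
    "MLRUN_DBPATH": "https://mlrun-api.example.com",  # hypothetical .env entries
    "MLRUN_HTTPDB__HTTP__VERIFY": "false",
    "V3IO_ACCESS_KEY": "abc123",
}
ordered_env_vars = {k: env_vars[k] for k in ordered_keys if k in env_vars}
ordered_env_vars.update({k: v for k, v in env_vars.items() if k not in ordered_keys})
print(list(ordered_env_vars))
# ['MLRUN_HTTPDB__HTTP__VERIFY', 'MLRUN_DBPATH', 'V3IO_ACCESS_KEY']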
@@ -20,10 +20,12 @@ from importlib import import_module
 from typing import Optional, Union
 
 import mlrun
+import mlrun.artifacts
 from mlrun.artifacts import Artifact, ArtifactSpec
 from mlrun.model import ModelObj
 
 from ..utils import generate_artifact_uri
+from .base import ArtifactStatus
 
 
 class DocumentLoaderSpec(ModelObj):
@@ -191,6 +193,14 @@ class MLRunLoader:
             self.producer = mlrun.get_or_create_project(self.producer)
 
     def lazy_load(self) -> Iterator["Document"]:  # noqa: F821
+        collections = None
+        try:
+            artifact = self.producer.get_artifact(self.artifact_key, self.tag)
+            collections = (
+                artifact.status.collections if artifact else collections
+            )
+        except mlrun.MLRunNotFoundError:
+            pass
         artifact = self.producer.log_document(
             key=self.artifact_key,
             document_loader_spec=self.loader_spec,
@@ -198,6 +208,7 @@ class MLRunLoader:
             upload=self.upload,
             labels=self.labels,
             tag=self.tag,
+            collections=collections,
         )
         res = artifact.to_langchain_documents()
         return res
@@ -252,26 +263,32 @@ class DocumentArtifact(Artifact):
     class DocumentArtifactSpec(ArtifactSpec):
         _dict_fields = ArtifactSpec._dict_fields + [
            "document_loader",
-            "collections",
            "original_source",
         ]
-        _exclude_fields_from_uid_hash = ArtifactSpec._exclude_fields_from_uid_hash + [
-            "collections",
-        ]
 
         def __init__(
             self,
             *args,
             document_loader: Optional[DocumentLoaderSpec] = None,
-            collections: Optional[dict] = None,
             original_source: Optional[str] = None,
             **kwargs,
         ):
             super().__init__(*args, **kwargs)
             self.document_loader = document_loader
-            self.collections = collections if collections is not None else {}
             self.original_source = original_source
 
+    class DocumentArtifactStatus(ArtifactStatus):
+        _dict_fields = ArtifactStatus._dict_fields + ["collections"]
+
+        def __init__(
+            self,
+            *args,
+            collections: Optional[dict] = None,
+            **kwargs,
+        ):
+            super().__init__(*args, **kwargs)
+            self.collections = collections if collections is not None else {}
+
     kind = "document"
 
     METADATA_SOURCE_KEY = "source"
@@ -286,6 +303,7 @@ class DocumentArtifact(Artifact):
         self,
         original_source: Optional[str] = None,
         document_loader_spec: Optional[DocumentLoaderSpec] = None,
+        collections: Optional[dict] = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -295,6 +313,17 @@ class DocumentArtifact(Artifact):
             else self.spec.document_loader
         )
         self.spec.original_source = original_source or self.spec.original_source
+        self.status = DocumentArtifact.DocumentArtifactStatus(collections=collections)
+
+    @property
+    def status(self) -> DocumentArtifactStatus:
+        return self._status
+
+    @status.setter
+    def status(self, status):
+        self._status = self._verify_dict(
+            status, "status", DocumentArtifact.DocumentArtifactStatus
+        )
 
     @property
     def spec(self) -> DocumentArtifactSpec:
@@ -386,8 +415,8 @@ class DocumentArtifact(Artifact):
         Args:
             collection_id (str): The ID of the collection to add
         """
-        if collection_id not in self.spec.collections:
-            self.spec.collections[collection_id] = "1"
+        if collection_id not in self.status.collections:
+            self.status.collections[collection_id] = "1"
             return True
         return False
 
@@ -403,7 +432,7 @@ class DocumentArtifact(Artifact):
         Args:
             collection_id (str): The ID of the collection to remove
         """
-        if collection_id in self.spec.collections:
-            self.spec.collections.pop(collection_id)
+        if collection_id in self.status.collections:
+            self.status.collections.pop(collection_id)
             return True
         return False
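The net effect of this group of hunks is that collection membership moves from DocumentArtifactSpec to a new DocumentArtifactStatus object, and lazy_load() now carries an existing artifact's collections forward when re-logging it. A hedged sketch of the new access path (the artifact variable and collection id below are hypothetical):

# Assumption: `artifact` is a DocumentArtifact fetched via project.get_artifact(...).
collections = artifact.status.collections   # dict of collection_id -> "1", now on status, not spec
if "my-collection" not in collections:       # mirrors the add logic shown above
    collections["my-collection"] = "1"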
@@ -57,7 +57,6 @@ class ClientSpec(pydantic.v1.BaseModel):
     redis_url: typing.Optional[str]
     redis_type: typing.Optional[str]
     sql_url: typing.Optional[str]
-    model_monitoring_tsdb_connection: typing.Optional[str]
     ce: typing.Optional[dict]
     # not passing them as one object as it possible client user would like to override only one of the params
     calculate_artifact_hash: typing.Optional[str]
@@ -183,6 +183,25 @@ class WriterEventKind(MonitoringStrEnum):
     STATS = "stats"
 
 
+class ControllerEvent(MonitoringStrEnum):
+    KIND = "kind"
+    ENDPOINT_ID = "endpoint_id"
+    ENDPOINT_NAME = "endpoint_name"
+    PROJECT = "project"
+    TIMESTAMP = "timestamp"
+    FIRST_REQUEST = "first_request"
+    FEATURE_SET_URI = "feature_set_uri"
+    ENDPOINT_TYPE = "endpoint_type"
+    ENDPOINT_POLICY = "endpoint_policy"
+    # Note: currently under endpoint policy we will have a dictionary including the keys: "application_names"
+    # and "base_period"
+
+
+class ControllerEventKind(MonitoringStrEnum):
+    NOP_EVENT = "nop_event"
+    REGULAR_EVENT = "regular_event"
+
+
 class MetricData(MonitoringStrEnum):
     METRIC_NAME = "metric_name"
     METRIC_VALUE = "metric_value"
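A hedged illustration (not from the package) of the event payload these new keys imply for the monitoring controller stream; all values below are made up:

event = {
    "kind": "regular_event",  # ControllerEventKind.REGULAR_EVENT
    "endpoint_id": "ep-1234",
    "endpoint_name": "churn-model",
    "project": "my-project",
    "timestamp": "2025-01-01T00:00:00Z",
    "first_request": "2024-12-31T23:00:00Z",
    "feature_set_uri": "store://feature-sets/my-project/churn-set",
    "endpoint_type": "model",
    "endpoint_policy": {"application_names": ["my-app"], "base_period": 10},
}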
@@ -228,28 +247,26 @@ class ModelEndpointTarget(MonitoringStrEnum):
     SQL = "sql"
 
 
-class StreamKind(MonitoringStrEnum):
-    V3IO_STREAM = "v3io_stream"
-    KAFKA = "kafka"
-
-
 class TSDBTarget(MonitoringStrEnum):
     V3IO_TSDB = "v3io-tsdb"
     TDEngine = "tdengine"
 
 
+class DefaultProfileName(StrEnum):
+    STREAM = "mm-infra-stream"
+    TSDB = "mm-infra-tsdb"
+
+
 class ProjectSecretKeys:
     ACCESS_KEY = "MODEL_MONITORING_ACCESS_KEY"
-    STREAM_PATH = "STREAM_PATH"
-    TSDB_CONNECTION = "TSDB_CONNECTION"
     TSDB_PROFILE_NAME = "TSDB_PROFILE_NAME"
     STREAM_PROFILE_NAME = "STREAM_PROFILE_NAME"
 
     @classmethod
     def mandatory_secrets(cls):
         return [
-            cls.STREAM_PATH,
-            cls.TSDB_CONNECTION,
+            cls.STREAM_PROFILE_NAME,
+            cls.TSDB_PROFILE_NAME,
        ]
 
 
mlrun/config.py CHANGED
@@ -537,6 +537,8 @@ default_config = {
     },
     "pagination": {
         "default_page_size": 200,
+        "page_limit": 1000000,
+        "page_size_limit": 1000000,
        "pagination_cache": {
            "interval": 60,
            "ttl": 3600,
@@ -594,6 +596,22 @@
                 "max_replicas": 1,
             },
         },
+        "controller_stream_args": {
+            "v3io": {
+                "shard_count": 10,
+                "retention_period_hours": 24,
+                "num_workers": 10,
+                "min_replicas": 1,
+                "max_replicas": 1,
+            },
+            "kafka": {
+                "partition_count": 10,
+                "replication_factor": 1,
+                "num_workers": 10,
+                "min_replicas": 1,
+                "max_replicas": 1,
+            },
+        },
         # Store prefixes are used to handle model monitoring storing policies based on project and kind, such as events,
         # stream, and endpoints.
         "store_prefixes": {
@@ -606,10 +624,6 @@
         "offline_storage_path": "model-endpoints/{kind}",
         "parquet_batching_max_events": 10_000,
         "parquet_batching_timeout_secs": timedelta(minutes=1).total_seconds(),
-        # See mlrun.model_monitoring.db.tsdb.ObjectTSDBFactory for available options
-        "tsdb_connection": "",
-        # See mlrun.common.schemas.model_monitoring.constants.StreamKind for available options
-        "stream_connection": "",
         "tdengine": {
             "timeout": 10,
             "retries": 1,
@@ -727,6 +741,7 @@
     },
     "workflows": {
         "default_workflow_runner_name": "workflow-runner-{}",
+        "concurrent_delete_worker_count": 20,
         # Default timeout seconds for retrieving workflow id after execution
         # Remote workflow timeout is the maximum between remote and the inner engine timeout
         "timeouts": {"local": 120, "kfp": 60, "remote": 60 * 5},
@@ -799,7 +814,7 @@
         # maximum allowed value for count in criteria field inside AlertConfig
         "max_criteria_count": 100,
         # interval for periodic events generation job
-        "events_generation_interval": "30",
+        "events_generation_interval": 30,  # seconds
     },
     "auth_with_client_id": {
         "enabled": False,
@@ -1282,6 +1297,8 @@
             function_name
             and function_name
             != mlrun.common.schemas.model_monitoring.constants.MonitoringFunctionNames.STREAM
+            and function_name
+            != mlrun.common.schemas.model_monitoring.constants.MonitoringFunctionNames.APPLICATION_CONTROLLER
         ):
             return mlrun.mlconf.model_endpoint_monitoring.store_prefixes.user_space.format(
                 project=project,
@@ -1289,12 +1306,21 @@
                 if function_name is None
                 else f"{kind}-{function_name.lower()}",
             )
-        elif kind == "stream":
+        elif (
+            kind == "stream"
+            and function_name
+            != mlrun.common.schemas.model_monitoring.constants.MonitoringFunctionNames.APPLICATION_CONTROLLER
+        ):
             return mlrun.mlconf.model_endpoint_monitoring.store_prefixes.user_space.format(
                 project=project,
                 kind=kind,
             )
         else:
+            if (
+                function_name
+                == mlrun.common.schemas.model_monitoring.constants.MonitoringFunctionNames.APPLICATION_CONTROLLER
+            ):
+                kind = function_name
             return mlrun.mlconf.model_endpoint_monitoring.store_prefixes.default.format(
                 project=project,
                 kind=kind,
@@ -1363,6 +1389,13 @@
             >= semver.VersionInfo.parse("1.12.10")
         )
 
+    @staticmethod
+    def get_ordered_keys():
+        # Define the keys to process first
+        return [
+            "MLRUN_HTTPDB__HTTP__VERIFY"  # Ensure this key is processed first for proper connection setup
+        ]
+
 
 # Global configuration
 config = Config.from_dict(default_config)
@@ -17,7 +17,7 @@ import base64
 import json
 import typing
 import warnings
-from urllib.parse import ParseResult, urlparse, urlunparse
+from urllib.parse import ParseResult, urlparse
 
 import pydantic.v1
 from mergedeep import merge
@@ -211,9 +211,10 @@ class DatastoreProfileKafkaSource(DatastoreProfile):
             attributes["partitions"] = self.partitions
         sasl = attributes.pop("sasl", {})
         if self.sasl_user and self.sasl_pass:
-            sasl["enabled"] = True
+            sasl["enable"] = True
             sasl["user"] = self.sasl_user
             sasl["password"] = self.sasl_pass
+            sasl["mechanism"] = "PLAIN"
         if sasl:
             attributes["sasl"] = sasl
         return attributes
@@ -312,7 +313,7 @@ class DatastoreProfileRedis(DatastoreProfile):
             query=parsed_url.query,
             fragment=parsed_url.fragment,
         )
-        return urlunparse(new_parsed_url)
+        return new_parsed_url.geturl()
 
     def secrets(self) -> dict:
         res = {}
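The Redis profile change is behavior-preserving: ParseResult.geturl() rebuilds the URL exactly as urlunparse would, so the extra import can be dropped. A quick standalone check:

from urllib.parse import urlparse, urlunparse

parsed = urlparse("redis://:mypass@redis.example.com:6379/0")  # hypothetical URL
assert parsed.geturl() == urlunparse(parsed)  # identical reconstruction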
@@ -473,6 +474,59 @@ class DatastoreProfileHdfs(DatastoreProfile):
         return f"webhdfs://{self.host}:{self.http_port}{subpath}"
 
 
+class TDEngineDatastoreProfile(DatastoreProfile):
+    """
+    A profile that holds the required parameters for a TDEngine database, with the websocket scheme.
+    https://docs.tdengine.com/developer-guide/connecting-to-tdengine/#websocket-connection
+    """
+
+    type: str = pydantic.v1.Field("taosws")
+    _private_attributes = ["password"]
+    user: str
+    # The password cannot be empty in real world scenarios. It's here just because of the profiles completion design.
+    password: typing.Optional[str]
+    host: str
+    port: int
+
+    def dsn(self) -> str:
+        """Get the Data Source Name of the configured TDEngine profile."""
+        return f"{self.type}://{self.user}:{self.password}@{self.host}:{self.port}"
+
+    @classmethod
+    def from_dsn(cls, dsn: str, profile_name: str) -> "TDEngineDatastoreProfile":
+        """
+        Construct a TDEngine profile from DSN (connection string) and a name for the profile.
+
+        :param dsn: The DSN (Data Source Name) of the TDEngine database,
+            e.g.: ``"taosws://root:taosdata@localhost:6041"``.
+        :param profile_name: The new profile's name.
+        :return: The TDEngine profile.
+        """
+        parsed_url = urlparse(dsn)
+        return cls(
+            name=profile_name,
+            user=parsed_url.username,
+            password=parsed_url.password,
+            host=parsed_url.hostname,
+            port=parsed_url.port,
+        )
+
+
+_DATASTORE_TYPE_TO_PROFILE_CLASS: dict[str, type[DatastoreProfile]] = {
+    "v3io": DatastoreProfileV3io,
+    "s3": DatastoreProfileS3,
+    "redis": DatastoreProfileRedis,
+    "basic": DatastoreProfileBasic,
+    "kafka_target": DatastoreProfileKafkaTarget,
+    "kafka_source": DatastoreProfileKafkaSource,
+    "dbfs": DatastoreProfileDBFS,
+    "gcs": DatastoreProfileGCS,
+    "az": DatastoreProfileAzureBlob,
+    "hdfs": DatastoreProfileHdfs,
+    "taosws": TDEngineDatastoreProfile,
+    "config": ConfigProfile,
+}
+
+
 class DatastoreProfile2Json(pydantic.v1.BaseModel):
     @staticmethod
     def _to_json(attributes):
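A short usage sketch for the new profile; the import path is assumed (the diff does not show the module name), while the DSN round-trip follows directly from dsn() and from_dsn() above:

# Assumed module path; behavior taken from the class definition in this hunk.
from mlrun.datastore.datastore_profile import TDEngineDatastoreProfile

profile = TDEngineDatastoreProfile.from_dsn(
    "taosws://root:taosdata@localhost:6041", profile_name="mm-infra-tsdb"
)
assert profile.host == "localhost" and profile.port == 6041
assert profile.dsn() == "taosws://root:taosdata@localhost:6041"  # round-trips the DSN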
@@ -523,19 +577,7 @@ class DatastoreProfile2Json(pydantic.v1.BaseModel):
 
         decoded_dict = {k: safe_literal_eval(v) for k, v in decoded_dict.items()}
         datastore_type = decoded_dict.get("type")
-        ds_profile_factory = {
-            "v3io": DatastoreProfileV3io,
-            "s3": DatastoreProfileS3,
-            "redis": DatastoreProfileRedis,
-            "basic": DatastoreProfileBasic,
-            "kafka_target": DatastoreProfileKafkaTarget,
-            "kafka_source": DatastoreProfileKafkaSource,
-            "dbfs": DatastoreProfileDBFS,
-            "gcs": DatastoreProfileGCS,
-            "az": DatastoreProfileAzureBlob,
-            "hdfs": DatastoreProfileHdfs,
-            "config": ConfigProfile,
-        }
+        ds_profile_factory = _DATASTORE_TYPE_TO_PROFILE_CLASS
         if datastore_type in ds_profile_factory:
             return ds_profile_factory[datastore_type].parse_obj(decoded_dict)
         else:
@@ -1089,9 +1089,10 @@ class KafkaSource(OnlineSource):
             attributes["partitions"] = partitions
         sasl = attributes.pop("sasl", {})
         if sasl_user and sasl_pass:
-            sasl["enabled"] = True
+            sasl["enable"] = True
             sasl["user"] = sasl_user
             sasl["password"] = sasl_pass
+            sasl["mechanism"] = "PLAIN"
         if sasl:
             attributes["sasl"] = sasl
         super().__init__(attributes=attributes, **kwargs)
@@ -1127,8 +1128,13 @@
             extra_attributes["workerAllocationMode"] = extra_attributes.get(
                 "worker_allocation_mode", "static"
             )
+        else:
+            extra_attributes["workerAllocationMode"] = extra_attributes.get(
+                "worker_allocation_mode", "pool"
+            )
 
         trigger_kwargs = {}
+
         if "max_workers" in extra_attributes:
             trigger_kwargs = {"max_workers": extra_attributes.pop("max_workers")}
 
@@ -48,6 +48,14 @@ def _extract_collection_name(vectorstore: "VectorStore") -> str: # noqa: F821
         else:
             return getattr(obj, pattern, None)
 
+    if type(vectorstore).__name__ == "PineconeVectorStore":
+        try:
+            url = vectorstore._index._config.host
+            index_name = url.split("//")[1].split("-")[0]
+            return index_name
+        except Exception:
+            pass
+
     for pattern in patterns:
         try:
             value = resolve_attribute(vectorstore, pattern)
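Note that the Pinecone branch assumes the index name is the leading, dash-free segment of the index host. A standalone illustration with a hypothetical host:

url = "https://quickstart-abc1234.svc.us-east-1-aws.pinecone.io"  # hypothetical Pinecone host
index_name = url.split("//")[1].split("-")[0]
print(index_name)  # "quickstart" -- index names that themselves contain dashes would be truncated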
@@ -254,7 +262,11 @@ class VectorStoreCollection:
         elif store_class == "chroma":
             where = {DocumentArtifact.METADATA_SOURCE_KEY: artifact.get_source()}
             self._collection_impl.delete(where=where)
-
+        elif store_class == "pineconevectorstore":
+            filter = {
+                DocumentArtifact.METADATA_SOURCE_KEY: {"$eq": artifact.get_source()}
+            }
+            self._collection_impl.delete(filter=filter)
         elif (
             hasattr(self._collection_impl, "delete")
             and "filter"
mlrun/db/base.py CHANGED
@@ -68,6 +68,9 @@ class RunDBInterface(ABC):
     ):
         pass
 
+    def refresh_smtp_configuration(self):
+        pass
+
     def push_pipeline_notifications(
         self,
         pipeline_id,
mlrun/db/httpdb.py CHANGED
@@ -559,14 +559,6 @@ class HTTPRunDB(RunDBInterface):
             server_cfg.get("external_platform_tracking")
             or config.external_platform_tracking
         )
-        config.model_endpoint_monitoring.tsdb_connection = (
-            server_cfg.get("model_monitoring_tsdb_connection")
-            or config.model_endpoint_monitoring.tsdb_connection
-        )
-        config.model_endpoint_monitoring.stream_connection = (
-            server_cfg.get("stream_connection")
-            or config.model_endpoint_monitoring.stream_connection
-        )
         config.packagers = server_cfg.get("packagers") or config.packagers
         server_data_prefixes = server_cfg.get("feature_store_data_prefixes") or {}
         for prefix in ["default", "nosql", "redisnosql"]:
mlrun/db/nopdb.py CHANGED
@@ -84,6 +84,9 @@ class NopDB(RunDBInterface):
     ):
         pass
 
+    def refresh_smtp_configuration(self):
+        pass
+
     def push_pipeline_notifications(
         self,
         pipeline_id,
mlrun/errors.py CHANGED
@@ -174,6 +174,10 @@ class MLRunInvalidArgumentError(MLRunHTTPStatusError, ValueError):
     error_status_code = HTTPStatus.BAD_REQUEST.value
 
 
+class MLRunModelLimitExceededError(MLRunHTTPStatusError, ValueError):
+    error_status_code = HTTPStatus.BAD_REQUEST.value
+
+
 class MLRunInvalidArgumentTypeError(MLRunHTTPStatusError, TypeError):
     error_status_code = HTTPStatus.BAD_REQUEST.value
 
mlrun/execution.py CHANGED
@@ -936,6 +936,7 @@ class MLClientCtx:
             key=key,
             original_source=local_path or target_path,
             document_loader_spec=document_loader_spec,
+            collections=kwargs.pop("collections", None),
             **kwargs,
         )
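A hedged caller-side sketch: assuming the enclosing method is MLClientCtx.log_document and that it accepts local_path as the call site suggests, a run could now pass collections straight through to the document artifact; the key, path, and collection id below are hypothetical:

import mlrun

context = mlrun.get_or_create_ctx("rag-ingest")
artifact = context.log_document(
    key="my-doc",
    local_path="docs/guide.pdf",           # hypothetical local file
    collections={"my-collection": "1"},    # forwarded via kwargs.pop("collections", None)
)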