snowflake-ml-python 1.12.0__py3-none-any.whl → 1.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28):
  1. snowflake/ml/_internal/telemetry.py +3 -1
  2. snowflake/ml/experiment/experiment_tracking.py +24 -2
  3. snowflake/ml/jobs/_utils/constants.py +1 -1
  4. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +16 -4
  5. snowflake/ml/jobs/job.py +17 -6
  6. snowflake/ml/jobs/manager.py +60 -11
  7. snowflake/ml/lineage/lineage_node.py +0 -1
  8. snowflake/ml/model/_client/model/batch_inference_specs.py +3 -5
  9. snowflake/ml/model/_client/model/model_version_impl.py +6 -20
  10. snowflake/ml/model/_client/ops/model_ops.py +49 -9
  11. snowflake/ml/model/_client/ops/service_ops.py +66 -34
  12. snowflake/ml/model/_client/service/model_deployment_spec.py +3 -0
  13. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
  14. snowflake/ml/model/_client/sql/service.py +1 -0
  15. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +103 -21
  16. snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -0
  17. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
  18. snowflake/ml/model/models/huggingface_pipeline.py +23 -0
  19. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +47 -3
  20. snowflake/ml/monitoring/_manager/model_monitor_manager.py +3 -0
  21. snowflake/ml/monitoring/entities/model_monitor_config.py +3 -0
  22. snowflake/ml/monitoring/model_monitor.py +30 -0
  23. snowflake/ml/version.py +1 -1
  24. {snowflake_ml_python-1.12.0.dist-info → snowflake_ml_python-1.14.0.dist-info}/METADATA +27 -1
  25. {snowflake_ml_python-1.12.0.dist-info → snowflake_ml_python-1.14.0.dist-info}/RECORD +28 -28
  26. {snowflake_ml_python-1.12.0.dist-info → snowflake_ml_python-1.14.0.dist-info}/WHEEL +0 -0
  27. {snowflake_ml_python-1.12.0.dist-info → snowflake_ml_python-1.14.0.dist-info}/licenses/LICENSE.txt +0 -0
  28. {snowflake_ml_python-1.12.0.dist-info → snowflake_ml_python-1.14.0.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_client/ops/service_ops.py

@@ -323,17 +323,20 @@ class ServiceOperator:
                 statement_params=statement_params,
             )

-            # stream service logs in a thread
-            model_build_service_name = sql_identifier.SqlIdentifier(
-                self._get_service_id_from_deployment_step(query_id, DeploymentStep.MODEL_BUILD)
-            )
-            model_build_service = ServiceLogInfo(
-                database_name=service_database_name,
-                schema_name=service_schema_name,
-                service_name=model_build_service_name,
-                deployment_step=DeploymentStep.MODEL_BUILD,
-                log_color=service_logger.LogColor.GREEN,
-            )
+            model_build_service: Optional[ServiceLogInfo] = None
+            if is_enable_image_build:
+                # stream service logs in a thread
+                model_build_service_name = sql_identifier.SqlIdentifier(
+                    self._get_service_id_from_deployment_step(query_id, DeploymentStep.MODEL_BUILD)
+                )
+                model_build_service = ServiceLogInfo(
+                    database_name=service_database_name,
+                    schema_name=service_schema_name,
+                    service_name=model_build_service_name,
+                    deployment_step=DeploymentStep.MODEL_BUILD,
+                    log_color=service_logger.LogColor.GREEN,
+                )
+
             model_inference_service = ServiceLogInfo(
                 database_name=service_database_name,
                 schema_name=service_schema_name,
@@ -375,7 +378,7 @@ class ServiceOperator:
                 progress_status.increment()

             # Poll for model build to start if not using existing service
-            if not model_inference_service_exists:
+            if not model_inference_service_exists and model_build_service:
                 self._wait_for_service_status(
                     model_build_service_name,
                     service_sql.ServiceStatus.RUNNING,
@@ -390,7 +393,7 @@ class ServiceOperator:
                 progress_status.increment()

             # Poll for model build completion
-            if not model_inference_service_exists:
+            if not model_inference_service_exists and model_build_service:
                 self._wait_for_service_status(
                     model_build_service_name,
                     service_sql.ServiceStatus.DONE,
@@ -454,7 +457,7 @@ class ServiceOperator:
         self,
         async_job: snowpark.AsyncJob,
         model_logger_service: Optional[ServiceLogInfo],
-        model_build_service: ServiceLogInfo,
+        model_build_service: Optional[ServiceLogInfo],
         model_inference_service: ServiceLogInfo,
         model_inference_service_exists: bool,
         force_rebuild: bool,
@@ -483,7 +486,7 @@ class ServiceOperator:
         self,
         force_rebuild: bool,
         service_log_meta: ServiceLogMetadata,
-        model_build_service: ServiceLogInfo,
+        model_build_service: Optional[ServiceLogInfo],
         model_inference_service: ServiceLogInfo,
         operation_id: str,
         statement_params: Optional[dict[str, Any]] = None,
@@ -599,13 +602,24 @@ class ServiceOperator:
             # check if model logger service is done
             # and transition the service log metadata to the model image build service
             if service.deployment_step == DeploymentStep.MODEL_LOGGING:
-                service_log_meta.transition_service_log_metadata(
-                    model_build_service,
-                    f"Model Logger service {service.display_service_name} complete.",
-                    is_model_build_service_done=False,
-                    is_model_logger_service_done=service_log_meta.is_model_logger_service_done,
-                    operation_id=operation_id,
-                )
+                if model_build_service:
+                    # building the inference image, transition to the model build service
+                    service_log_meta.transition_service_log_metadata(
+                        model_build_service,
+                        f"Model Logger service {service.display_service_name} complete.",
+                        is_model_build_service_done=False,
+                        is_model_logger_service_done=service_log_meta.is_model_logger_service_done,
+                        operation_id=operation_id,
+                    )
+                else:
+                    # no model build service, transition to the model inference service
+                    service_log_meta.transition_service_log_metadata(
+                        model_inference_service,
+                        f"Model Logger service {service.display_service_name} complete.",
+                        is_model_build_service_done=True,
+                        is_model_logger_service_done=service_log_meta.is_model_logger_service_done,
+                        operation_id=operation_id,
+                    )
             # check if model build service is done
             # and transition the service log metadata to the model inference service
             elif service.deployment_step == DeploymentStep.MODEL_BUILD:
@@ -616,6 +630,8 @@ class ServiceOperator:
                     is_model_logger_service_done=service_log_meta.is_model_logger_service_done,
                     operation_id=operation_id,
                 )
+            elif service.deployment_step == DeploymentStep.MODEL_INFERENCE:
+                module_logger.info(f"Inference service {service.display_service_name} is deployed.")
             else:
                 module_logger.warning(f"Service {service.display_service_name} is done, but not transitioning.")

@@ -623,7 +639,7 @@ class ServiceOperator:
         self,
         async_job: snowpark.AsyncJob,
         model_logger_service: Optional[ServiceLogInfo],
-        model_build_service: ServiceLogInfo,
+        model_build_service: Optional[ServiceLogInfo],
         model_inference_service: ServiceLogInfo,
         model_inference_service_exists: bool,
         force_rebuild: bool,
@@ -632,14 +648,23 @@ class ServiceOperator:
     ) -> None:
         """Stream service logs while the async job is running."""

-        model_build_service_logger = service_logger.get_logger(  # BuildJobName
-            model_build_service.display_service_name,
-            model_build_service.log_color,
-            operation_id=operation_id,
-        )
-        if model_logger_service:
-            model_logger_service_logger = service_logger.get_logger(  # ModelLoggerName
-                model_logger_service.display_service_name,
+        if model_build_service:
+            model_build_service_logger = service_logger.get_logger(
+                model_build_service.display_service_name,  # BuildJobName
+                model_build_service.log_color,
+                operation_id=operation_id,
+            )
+            service_log_meta = ServiceLogMetadata(
+                service_logger=model_build_service_logger,
+                service=model_build_service,
+                service_status=None,
+                is_model_build_service_done=False,
+                is_model_logger_service_done=True,
+                log_offset=0,
+            )
+        elif model_logger_service:
+            model_logger_service_logger = service_logger.get_logger(
+                model_logger_service.display_service_name,  # ModelLoggerName
                 model_logger_service.log_color,
                 operation_id=operation_id,
             )
@@ -653,12 +678,17 @@ class ServiceOperator:
                 log_offset=0,
             )
         else:
+            model_inference_service_logger = service_logger.get_logger(
+                model_inference_service.display_service_name,  # ModelInferenceName
+                model_inference_service.log_color,
+                operation_id=operation_id,
+            )
             service_log_meta = ServiceLogMetadata(
-                service_logger=model_build_service_logger,
-                service=model_build_service,
+                service_logger=model_inference_service_logger,
+                service=model_inference_service,
                 service_status=None,
                 is_model_build_service_done=False,
-                is_model_logger_service_done=True,
+                is_model_logger_service_done=False,
                 log_offset=0,
             )

@@ -881,6 +911,7 @@ class ServiceOperator:
         max_batch_rows: Optional[int],
         cpu_requests: Optional[str],
         memory_requests: Optional[str],
+        replicas: Optional[int],
         statement_params: Optional[dict[str, Any]] = None,
     ) -> jobs.MLJob[Any]:
         database_name = self._database_name
@@ -914,6 +945,7 @@ class ServiceOperator:
             warehouse=warehouse,
             cpu=cpu_requests,
             memory=memory_requests,
+            replicas=replicas,
         )

         self._model_deployment_spec.add_image_build_spec(
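Net effect of the service_ops.py hunks: the model build service is now optional, and log streaming picks its first target accordingly. A minimal sketch of the new fallback order (plain strings stand in for the ServiceLogInfo objects the real code passes around; the helper name is ours, not the package's):

from typing import Optional

def pick_initial_log_target(
    model_build_service: Optional[str],
    model_logger_service: Optional[str],
    model_inference_service: str,
) -> str:
    # Mirrors the branch order in _stream_service_logs after this change:
    # build service (only when an image is being built) -> model logger
    # service (only when one was started) -> inference service.
    if model_build_service is not None:
        return model_build_service
    if model_logger_service is not None:
        return model_logger_service
    return model_inference_service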
snowflake/ml/model/_client/service/model_deployment_spec.py

@@ -207,6 +207,7 @@ class ModelDeploymentSpec:
         gpu: Optional[Union[str, int]] = None,
         num_workers: Optional[int] = None,
         max_batch_rows: Optional[int] = None,
+        replicas: Optional[int] = None,
     ) -> "ModelDeploymentSpec":
         """Add job specification to the deployment spec.

@@ -226,6 +227,7 @@ class ModelDeploymentSpec:
             gpu: GPU requirement.
             num_workers: Number of workers.
             max_batch_rows: Maximum batch rows for inference.
+            replicas: Number of replicas.

         Raises:
             ValueError: If a service spec already exists.
@@ -260,6 +262,7 @@ class ModelDeploymentSpec:
                 output_stage_location=output_stage_location,
                 completion_filename=completion_filename,
             ),
+            replicas=replicas,
             **self._inference_spec,
         )
         return self
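Downstream, the builder threads replicas into the job spec. A hedged usage sketch: add_job_spec accepts more parameters than these hunks show, and a bare ModelDeploymentSpec() constructor is assumed here.

from snowflake.ml.model._client.service.model_deployment_spec import ModelDeploymentSpec

spec = ModelDeploymentSpec()
spec.add_job_spec(
    num_workers=4,
    max_batch_rows=10_000,
    replicas=2,  # new: fan the batch inference job out over two replicas
)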
snowflake/ml/model/_client/service/model_deployment_spec_schema.py

@@ -57,6 +57,7 @@ class Job(BaseModel):
     function_name: str
     input: Input
     output: Output
+    replicas: Optional[int] = None


 class LogModelArgs(BaseModel):
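Because replicas defaults to None on the schema model, specs written by older clients still validate. A reduced stand-in (Input and Output are omitted since their fields are not part of this diff):

from typing import Optional
from pydantic import BaseModel

class Job(BaseModel):
    function_name: str
    replicas: Optional[int] = None  # new field; absent in pre-1.13 specs

old_spec = Job(function_name="predict")              # still validates
new_spec = Job(function_name="predict", replicas=3)
assert old_spec.replicas is None and new_spec.replicas == 3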
snowflake/ml/model/_client/sql/service.py

@@ -63,6 +63,7 @@ class ServiceStatusInfo:
 class ServiceSQLClient(_base._BaseSQLClient):
     MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME = "name"
     MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME = "ingress_url"
+    MODEL_INFERENCE_SERVICE_ENDPOINT_PRIVATELINK_INGRESS_URL_COL_NAME = "privatelink_ingress_url"
     SERVICE_STATUS = "service_status"
     INSTANCE_ID = "instance_id"
     INSTANCE_STATUS = "instance_status"
snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py

@@ -1,6 +1,7 @@
 import json
 import logging
 import os
+import shutil
 import time
 import uuid
 import warnings
@@ -88,6 +89,7 @@ class HuggingFacePipelineHandler(
     _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}

     MODEL_BLOB_FILE_OR_DIR = "model"
+    MODEL_PICKLE_FILE = "snowml_huggingface_pipeline.pkl"
     ADDITIONAL_CONFIG_FILE = "pipeline_config.pt"
     DEFAULT_TARGET_METHODS = ["__call__"]
     IS_AUTO_SIGNATURE = True
@@ -199,6 +201,7 @@ class HuggingFacePipelineHandler(
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)

+        is_repo_downloaded = False
         if type_utils.LazyType("transformers.Pipeline").isinstance(model):
             save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
             model.save_pretrained(  # type:ignore[attr-defined]
@@ -224,11 +227,22 @@ class HuggingFacePipelineHandler(
             ) as f:
                 cloudpickle.dump(pipeline_params, f)
         else:
+            model_blob_file_or_dir = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
+            model_blob_pickle_file = os.path.join(model_blob_file_or_dir, cls.MODEL_PICKLE_FILE)
+            os.makedirs(model_blob_file_or_dir, exist_ok=True)
             with open(
-                os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR),
+                model_blob_pickle_file,
                 "wb",
             ) as f:
                 cloudpickle.dump(model, f)
+            if model.repo_snapshot_dir:
+                logger.info("model's repo_snapshot_dir is available, copying snapshot")
+                shutil.copytree(
+                    model.repo_snapshot_dir,
+                    model_blob_file_or_dir,
+                    dirs_exist_ok=True,
+                )
+                is_repo_downloaded = True

         base_meta = model_blob_meta.ModelBlobMeta(
             name=name,
@@ -236,13 +250,12 @@ class HuggingFacePipelineHandler(
             handler_version=cls.HANDLER_VERSION,
             path=cls.MODEL_BLOB_FILE_OR_DIR,
             options=model_meta_schema.HuggingFacePipelineModelBlobOptions(
-                {
-                    "task": task,
-                    "batch_size": batch_size if batch_size is not None else 1,
-                    "has_tokenizer": has_tokenizer,
-                    "has_feature_extractor": has_feature_extractor,
-                    "has_image_preprocessor": has_image_preprocessor,
-                }
+                task=task,
+                batch_size=batch_size if batch_size is not None else 1,
+                has_tokenizer=has_tokenizer,
+                has_feature_extractor=has_feature_extractor,
+                has_image_preprocessor=has_image_preprocessor,
+                is_repo_downloaded=is_repo_downloaded,
             ),
         )
         model_meta.models[name] = base_meta
@@ -286,6 +299,27 @@ class HuggingFacePipelineHandler(

         return device_config

+    @staticmethod
+    def _load_pickle_model(
+        pickle_file: str,
+        **kwargs: Unpack[model_types.HuggingFaceLoadOptions],
+    ) -> huggingface_pipeline.HuggingFacePipelineModel:
+        with open(pickle_file, "rb") as f:
+            m = cloudpickle.load(f)
+        assert isinstance(m, huggingface_pipeline.HuggingFacePipelineModel)
+        torch_dtype: Optional[str] = None
+        device_config = None
+        if getattr(m, "device", None) is None and getattr(m, "device_map", None) is None:
+            device_config = HuggingFacePipelineHandler._get_device_config(**kwargs)
+            m.__dict__.update(device_config)
+
+        if getattr(m, "torch_dtype", None) is None and kwargs.get("use_gpu", False):
+            torch_dtype = "auto"
+            m.__dict__.update(torch_dtype=torch_dtype)
+        else:
+            m.__dict__.update(torch_dtype=None)
+        return m
+
     @classmethod
     def load_model(
         cls,
@@ -310,7 +344,13 @@ class HuggingFacePipelineHandler(
             raise ValueError("Missing field `batch_size` in model blob metadata for type `huggingface_pipeline`")

         model_blob_file_or_dir_path = os.path.join(model_blob_path, model_blob_filename)
-        if os.path.isdir(model_blob_file_or_dir_path):
+        is_repo_downloaded = model_blob_options.get("is_repo_downloaded", False)
+
+        def _create_pipeline_from_dir(
+            model_blob_file_or_dir_path: str,
+            model_blob_options: model_meta_schema.HuggingFacePipelineModelBlobOptions,
+            **kwargs: Unpack[model_types.HuggingFaceLoadOptions],
+        ) -> "transformers.Pipeline":
             import transformers

             additional_pipeline_params = {}
@@ -330,7 +370,7 @@ class HuggingFacePipelineHandler(
             ) as f:
                 pipeline_params = cloudpickle.load(f)

-            device_config = cls._get_device_config(**kwargs)
+            device_config = HuggingFacePipelineHandler._get_device_config(**kwargs)

             m = transformers.pipeline(
                 model_blob_options["task"],
@@ -359,18 +399,59 @@ class HuggingFacePipelineHandler(
                 m.tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE

             m.__dict__.update(pipeline_params)
+            return m

+        def _create_pipeline_from_model(
+            model_blob_file_or_dir_path: str,
+            m: huggingface_pipeline.HuggingFacePipelineModel,
+            **kwargs: Unpack[model_types.HuggingFaceLoadOptions],
+        ) -> "transformers.Pipeline":
+            import transformers
+
+            return transformers.pipeline(
+                m.task,
+                model=model_blob_file_or_dir_path,
+                trust_remote_code=m.trust_remote_code,
+                torch_dtype=getattr(m, "torch_dtype", None),
+                revision=m.revision,
+                # pass device or device_map when creating the pipeline
+                **HuggingFacePipelineHandler._get_device_config(**kwargs),
+                # pass other model_kwargs to transformers.pipeline.from_pretrained method
+                **m.model_kwargs,
+            )
+
+        if os.path.isdir(model_blob_file_or_dir_path) and not is_repo_downloaded:
+            # the logged model is a transformers.Pipeline object
+            # weights of the model are saved in the directory
+            return _create_pipeline_from_dir(model_blob_file_or_dir_path, model_blob_options, **kwargs)
         else:
-            assert os.path.isfile(model_blob_file_or_dir_path)
-            with open(model_blob_file_or_dir_path, "rb") as f:
-                m = cloudpickle.load(f)
-            assert isinstance(m, huggingface_pipeline.HuggingFacePipelineModel)
-            if getattr(m, "device", None) is None and getattr(m, "device_map", None) is None:
-                m.__dict__.update(cls._get_device_config(**kwargs))
-
-            if getattr(m, "torch_dtype", None) is None and kwargs.get("use_gpu", False):
-                m.__dict__.update(torch_dtype="auto")
-            return m
+            # case 1: LEGACY logging, repo snapshot is not logged
+            if os.path.isfile(model_blob_file_or_dir_path):
+                # LEGACY logging that had model as a pickle file in the model blob directory
+                # the logged model is a huggingface_pipeline.HuggingFacePipelineModel object
+                # the model_blob_file_or_dir_path is the pickle file that holds
+                # the huggingface_pipeline.HuggingFacePipelineModel object
+                # the snapshot of the repo is not logged
+                return cls._load_pickle_model(model_blob_file_or_dir_path)
+            else:
+                assert os.path.isdir(model_blob_file_or_dir_path)
+                # the logged model is a huggingface_pipeline.HuggingFacePipelineModel object
+                # the pickle_file holds the huggingface_pipeline.HuggingFacePipelineModel object
+                pickle_file = os.path.join(model_blob_file_or_dir_path, cls.MODEL_PICKLE_FILE)
+                m = cls._load_pickle_model(pickle_file)
+
+                # case 2: logging without the snapshot of the repo
+                if not is_repo_downloaded:
+                    # we return the huggingface_pipeline.HuggingFacePipelineModel object
+                    return m
+                # case 3: logging with the snapshot of the repo
+                else:
+                    # the model_blob_file_or_dir_path is the directory that holds
+                    # weights of the model from `huggingface_hub.snapshot_download`
+                    # the huggingface_pipeline.HuggingFacePipelineModel object is logged
+                    # with a snapshot of the repo, we create a transformers.Pipeline object
+                    # by reading the snapshot directory
+                    return _create_pipeline_from_model(model_blob_file_or_dir_path, m, **kwargs)

     @classmethod
     def convert_as_custom_model(
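load_model now dispatches over three on-disk layouts. A condensed, illustrative restatement of that dispatch (the helper name is ours, not the package's):

import os

def resolve_load_strategy(model_blob_file_or_dir_path: str, is_repo_downloaded: bool) -> str:
    if os.path.isdir(model_blob_file_or_dir_path) and not is_repo_downloaded:
        return "pipeline_dir"        # a saved transformers.Pipeline directory
    if os.path.isfile(model_blob_file_or_dir_path):
        return "legacy_pickle"       # case 1: pre-1.13 pickle, no snapshot
    if not is_repo_downloaded:
        return "pickle_no_snapshot"  # case 2: pickle in a dir, weights pulled at runtime
    return "pickle_with_snapshot"    # case 3: build a transformers.Pipeline from the snapshot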
@@ -665,7 +746,7 @@ class HuggingFaceOpenAICompatibleModel:
             prompt_text,
             return_tensors="pt",
             padding=True,
-        )
+        ).to(self.model.device)
         prompt_tokens = inputs.input_ids.shape[1]

         from transformers import GenerationConfig
@@ -683,6 +764,7 @@ class HuggingFaceOpenAICompatibleModel:
             num_return_sequences=n,
             num_beams=max(2, n),  # must be >1
             num_beam_groups=max(2, n) if presence_penalty else 1,
+            do_sample=False,
         )

         # Generate text
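The `.to(self.model.device)` fix addresses a standard transformers pitfall: tokenizers return CPU tensors, while generate() requires inputs on the same device as the model weights. A self-contained illustration (the tiny model name is only an example):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
tok.pad_token = tok.eos_token  # GPT-2-style tokenizers ship without a pad token
lm = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
if torch.cuda.is_available():
    lm = lm.cuda()

# Without .to(lm.device), this raises a device-mismatch error on GPU hosts.
batch = tok("hello", return_tensors="pt", padding=True).to(lm.device)
out = lm.generate(**batch, max_new_tokens=8, do_sample=False)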
snowflake/ml/model/_packager/model_handlers/xgboost.py

@@ -229,6 +229,11 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
         enable_categorical = False
         for col, d_type in X.dtypes.items():
             if pd.api.extensions.ExtensionDtype.is_dtype(d_type):
+                if pd.CategoricalDtype.is_dtype(d_type):
+                    enable_categorical = True
+                elif isinstance(d_type, pd.StringDtype):
+                    X[col] = X[col].astype("category")
+                    enable_categorical = True
                 continue
             if not np.issubdtype(d_type, np.number):
                 # categorical columns are converted to numpy's str dtype
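This teaches the handler to hand pandas categorical and string extension columns to XGBoost as native categoricals instead of skipping them. The same conversion outside the handler (illustrative data):

import pandas as pd
import xgboost as xgb

X = pd.DataFrame({
    "color": pd.Series(["red", "blue", "red"], dtype="string"),  # pd.StringDtype
    "weight": [1.0, 2.0, 3.0],
})
# Mirror of the new handler behavior: string extension columns become
# pandas categoricals so DMatrix can ingest them with enable_categorical.
X["color"] = X["color"].astype("category")
dtrain = xgb.DMatrix(X, label=[0, 1, 0], enable_categorical=True)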
snowflake/ml/model/_packager/model_meta/model_meta_schema.py

@@ -51,6 +51,7 @@ class HuggingFacePipelineModelBlobOptions(BaseModelBlobOptions):
     has_tokenizer: NotRequired[bool]
     has_feature_extractor: NotRequired[bool]
     has_image_preprocessor: NotRequired[bool]
+    is_repo_downloaded: NotRequired[Optional[bool]]


 class LightGBMModelBlobOptions(BaseModelBlobOptions):
snowflake/ml/model/models/huggingface_pipeline.py

@@ -28,6 +28,10 @@ class HuggingFacePipelineModel:
         token: Optional[str] = None,
         trust_remote_code: Optional[bool] = None,
         model_kwargs: Optional[dict[str, Any]] = None,
+        download_snapshot: bool = True,
+        # repo snapshot download args
+        allow_patterns: Optional[Union[list[str], str]] = None,
+        ignore_patterns: Optional[Union[list[str], str]] = None,
         **kwargs: Any,
     ) -> None:
         """
@@ -52,6 +56,9 @@ class HuggingFacePipelineModel:
                 Defaults to None.
             model_kwargs: Additional dictionary of keyword arguments passed along to the model's `from_pretrained(...,`.
                 Defaults to None.
+            download_snapshot: Whether to download the HuggingFace repository. Defaults to True.
+            allow_patterns: If provided, only files matching at least one pattern are downloaded.
+            ignore_patterns: If provided, files matching any of the patterns are not downloaded.
             kwargs: Additional keyword arguments passed along to the specific pipeline init (see the documentation for
                 the corresponding pipeline class for possible values).

@@ -220,6 +227,21 @@ class HuggingFacePipelineModel:
                 stacklevel=2,
             )

+        repo_snapshot_dir: Optional[str] = None
+        if download_snapshot:
+            try:
+                from huggingface_hub import snapshot_download
+
+                repo_snapshot_dir = snapshot_download(
+                    repo_id=model,
+                    revision=revision,
+                    token=token,
+                    allow_patterns=allow_patterns,
+                    ignore_patterns=ignore_patterns,
+                )
+            except ImportError:
+                logger.info("huggingface_hub package is not installed, skipping snapshot download")
+
         # ==== End pipeline logic from transformers ====

         self.task = normalized_task
@@ -229,6 +251,7 @@ class HuggingFacePipelineModel:
         self.trust_remote_code = trust_remote_code
         self.model_kwargs = model_kwargs
         self.tokenizer = tokenizer
+        self.repo_snapshot_dir = repo_snapshot_dir
         self.__dict__.update(kwargs)

     @telemetry.send_api_usage_telemetry(
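Usage sketch for the new snapshot options. The task/model keyword names follow this class's long-standing public signature rather than these hunks, so treat them as assumptions; the repo id is only an example:

from snowflake.ml.model.models.huggingface_pipeline import HuggingFacePipelineModel

pipe_ref = HuggingFacePipelineModel(
    task="text-classification",
    model="distilbert/distilbert-base-uncased-finetuned-sst-2-english",
    download_snapshot=True,                          # default: snapshot the repo at construction
    ignore_patterns=["*.bin", "*.h5", "*.msgpack"],  # keep only the safetensors weights
)
# pipe_ref.repo_snapshot_dir now points at the local huggingface_hub cache dir,
# and logging this object packages the snapshot alongside the pickle.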
snowflake/ml/monitoring/_client/model_monitor_sql_client.py

@@ -30,8 +30,8 @@ class MonitorOperation(Enum):
 _OPERATION_SUPPORTED_PROPS: dict[MonitorOperation, frozenset[str]] = {
     MonitorOperation.SUSPEND: frozenset(),
     MonitorOperation.RESUME: frozenset(),
-    MonitorOperation.ADD: frozenset({"SEGMENT_COLUMN"}),
-    MonitorOperation.DROP: frozenset({"SEGMENT_COLUMN"}),
+    MonitorOperation.ADD: frozenset({"SEGMENT_COLUMN", "CUSTOM_METRIC_COLUMN"}),
+    MonitorOperation.DROP: frozenset({"SEGMENT_COLUMN", "CUSTOM_METRIC_COLUMN"}),
 }


@@ -91,6 +91,7 @@ class ModelMonitorSQLClient:
         baseline_schema: Optional[sql_identifier.SqlIdentifier] = None,
         baseline: Optional[sql_identifier.SqlIdentifier] = None,
         segment_columns: Optional[list[sql_identifier.SqlIdentifier]] = None,
+        custom_metric_columns: Optional[list[sql_identifier.SqlIdentifier]] = None,
         statement_params: Optional[dict[str, Any]] = None,
     ) -> None:
         baseline_sql = ""
@@ -101,6 +102,10 @@ class ModelMonitorSQLClient:
         if segment_columns:
             segment_columns_sql = f"SEGMENT_COLUMNS={_build_sql_list_from_columns(segment_columns)}"

+        custom_metric_columns_sql = ""
+        if custom_metric_columns:
+            custom_metric_columns_sql = f"CUSTOM_METRIC_COLUMNS={_build_sql_list_from_columns(custom_metric_columns)}"
+
         query_result_checker.SqlResultValidator(
             self._sql_client._session,
             f"""
@@ -120,6 +125,7 @@ class ModelMonitorSQLClient:
             REFRESH_INTERVAL='{refresh_interval}'
             AGGREGATION_WINDOW='{aggregation_window}'
             {segment_columns_sql}
+            {custom_metric_columns_sql}
             {baseline_sql}""",
             statement_params=statement_params,
         ).has_column("status").has_dimensions(1, 1).validate()
@@ -210,6 +216,7 @@ class ModelMonitorSQLClient:
         actual_class_columns: list[sql_identifier.SqlIdentifier],
         id_columns: list[sql_identifier.SqlIdentifier],
         segment_columns: Optional[list[sql_identifier.SqlIdentifier]] = None,
+        custom_metric_columns: Optional[list[sql_identifier.SqlIdentifier]] = None,
     ) -> None:
         """Ensures all columns exist in the source table.

@@ -222,12 +229,14 @@ class ModelMonitorSQLClient:
             actual_class_columns: List of actual class column names.
             id_columns: List of id column names.
             segment_columns: List of segment column names.
+            custom_metric_columns: List of custom metric column names.

         Raises:
             ValueError: If any of the columns do not exist in the source.
         """

         segment_columns = [] if segment_columns is None else segment_columns
+        custom_metric_columns = [] if custom_metric_columns is None else custom_metric_columns

         if timestamp_column not in source_column_schema:
             raise ValueError(f"Timestamp column {timestamp_column} does not exist in source.")
@@ -248,6 +257,9 @@ class ModelMonitorSQLClient:
         if not all([column_name in source_column_schema for column_name in segment_columns]):
             raise ValueError(f"Segment column(s): {segment_columns} do not exist in source.")

+        if not all([column_name in source_column_schema for column_name in custom_metric_columns]):
+            raise ValueError(f"Custom Metric column(s): {custom_metric_columns} do not exist in source.")
+
     def validate_source(
         self,
         *,
@@ -261,6 +273,7 @@ class ModelMonitorSQLClient:
         actual_class_columns: list[sql_identifier.SqlIdentifier],
         id_columns: list[sql_identifier.SqlIdentifier],
         segment_columns: Optional[list[sql_identifier.SqlIdentifier]] = None,
+        custom_metric_columns: Optional[list[sql_identifier.SqlIdentifier]] = None,
     ) -> None:

         source_database = source_database or self._database_name
@@ -281,6 +294,7 @@ class ModelMonitorSQLClient:
             actual_class_columns=actual_class_columns,
             id_columns=id_columns,
             segment_columns=segment_columns,
+            custom_metric_columns=custom_metric_columns,
         )

     def _alter_monitor(
@@ -299,7 +313,7 @@ class ModelMonitorSQLClient:

         if target_property not in supported_target_properties:
             raise ValueError(
-                f"Only {', '.join(supported_target_properties)} supported as target property "
+                f"Only {', '.join(sorted(supported_target_properties))} supported as target property "
                 f"for {operation.name} operation"
             )

@@ -366,3 +380,33 @@ class ModelMonitorSQLClient:
             target_value=segment_column,
             statement_params=statement_params,
         )
+
+    def add_custom_metric_column(
+        self,
+        monitor_name: sql_identifier.SqlIdentifier,
+        custom_metric_column: sql_identifier.SqlIdentifier,
+        statement_params: Optional[dict[str, Any]] = None,
+    ) -> None:
+        """Add a custom metric column to the Model Monitor"""
+        self._alter_monitor(
+            operation=MonitorOperation.ADD,
+            monitor_name=monitor_name,
+            target_property="CUSTOM_METRIC_COLUMN",
+            target_value=custom_metric_column,
+            statement_params=statement_params,
+        )
+
+    def drop_custom_metric_column(
+        self,
+        monitor_name: sql_identifier.SqlIdentifier,
+        custom_metric_column: sql_identifier.SqlIdentifier,
+        statement_params: Optional[dict[str, Any]] = None,
+    ) -> None:
+        """Drop a custom metric column from the Model Monitor"""
+        self._alter_monitor(
+            operation=MonitorOperation.DROP,
+            monitor_name=monitor_name,
+            target_property="CUSTOM_METRIC_COLUMN",
+            target_value=custom_metric_column,
+            statement_params=statement_params,
+        )
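At the SQL-client level the new column kind is managed exactly like segment columns. A hedged usage sketch; monitor_client stands for an already-constructed ModelMonitorSQLClient (construction requires a Snowpark session, omitted here), and the identifiers are examples:

from snowflake.ml._internal.utils import sql_identifier

monitor_client.add_custom_metric_column(
    monitor_name=sql_identifier.SqlIdentifier("MY_MONITOR"),
    custom_metric_column=sql_identifier.SqlIdentifier("LATENCY_MS"),
)
monitor_client.drop_custom_metric_column(
    monitor_name=sql_identifier.SqlIdentifier("MY_MONITOR"),
    custom_metric_column=sql_identifier.SqlIdentifier("LATENCY_MS"),
)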
snowflake/ml/monitoring/_manager/model_monitor_manager.py

@@ -109,6 +109,7 @@ class ModelMonitorManager:
         actual_score_columns = self._build_column_list_from_input(source_config.actual_score_columns)
         actual_class_columns = self._build_column_list_from_input(source_config.actual_class_columns)
         segment_columns = self._build_column_list_from_input(source_config.segment_columns)
+        custom_metric_columns = self._build_column_list_from_input(source_config.custom_metric_columns)

         id_columns = [sql_identifier.SqlIdentifier(column_name) for column_name in source_config.id_columns]
         ts_column = sql_identifier.SqlIdentifier(source_config.timestamp_column)
@@ -125,6 +126,7 @@ class ModelMonitorManager:
             actual_class_columns=actual_class_columns,
             id_columns=id_columns,
             segment_columns=segment_columns,
+            custom_metric_columns=custom_metric_columns,
         )

         self._model_monitor_client.create_model_monitor(
@@ -147,6 +149,7 @@ class ModelMonitorManager:
             actual_score_columns=actual_score_columns,
             actual_class_columns=actual_class_columns,
             segment_columns=segment_columns,
+            custom_metric_columns=custom_metric_columns,
             refresh_interval=model_monitor_config.refresh_interval,
             aggregation_window=model_monitor_config.aggregation_window,
             baseline_database=baseline_database_name_id,
snowflake/ml/monitoring/entities/model_monitor_config.py

@@ -36,6 +36,9 @@ class ModelMonitorSourceConfig:
     segment_columns: Optional[list[str]] = None
     """List of columns in the source containing segment information for grouped monitoring."""

+    custom_metric_columns: Optional[list[str]] = None
+    """List of columns in the source containing custom metrics."""
+

 @dataclass
 class ModelMonitorConfig:
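At the configuration surface, custom metric columns ride along with the other source columns. Sketch only: fields other than segment_columns and custom_metric_columns follow the pre-existing public dataclass and are assumptions here.

from snowflake.ml.monitoring.entities.model_monitor_config import ModelMonitorSourceConfig

source_config = ModelMonitorSourceConfig(
    source="INFERENCE_LOG",                # assumed pre-existing field names
    timestamp_column="TS",
    id_columns=["ID"],
    prediction_score_columns=["PREDICTION"],
    actual_score_columns=["LABEL"],
    segment_columns=["REGION"],
    custom_metric_columns=["LATENCY_MS"],  # new in this release
)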