snowflake-ml-python 1.19.0__py3-none-any.whl → 1.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. snowflake/ml/_internal/env_utils.py +16 -0
  2. snowflake/ml/_internal/platform_capabilities.py +36 -0
  3. snowflake/ml/_internal/telemetry.py +56 -7
  4. snowflake/ml/data/_internal/arrow_ingestor.py +67 -2
  5. snowflake/ml/data/data_connector.py +103 -1
  6. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +8 -2
  7. snowflake/ml/experiment/_entities/run.py +15 -0
  8. snowflake/ml/experiment/callback/keras.py +25 -2
  9. snowflake/ml/experiment/callback/lightgbm.py +27 -2
  10. snowflake/ml/experiment/callback/xgboost.py +25 -2
  11. snowflake/ml/experiment/experiment_tracking.py +123 -13
  12. snowflake/ml/experiment/utils.py +6 -0
  13. snowflake/ml/feature_store/access_manager.py +1 -0
  14. snowflake/ml/feature_store/feature_store.py +1 -1
  15. snowflake/ml/feature_store/feature_view.py +34 -24
  16. snowflake/ml/jobs/_interop/protocols.py +3 -0
  17. snowflake/ml/jobs/_utils/feature_flags.py +1 -0
  18. snowflake/ml/jobs/_utils/payload_utils.py +360 -357
  19. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +95 -8
  20. snowflake/ml/jobs/_utils/scripts/start_mlruntime.sh +92 -0
  21. snowflake/ml/jobs/_utils/scripts/startup.sh +112 -0
  22. snowflake/ml/jobs/_utils/spec_utils.py +2 -406
  23. snowflake/ml/jobs/_utils/stage_utils.py +22 -1
  24. snowflake/ml/jobs/_utils/types.py +14 -7
  25. snowflake/ml/jobs/job.py +8 -9
  26. snowflake/ml/jobs/manager.py +64 -129
  27. snowflake/ml/model/_client/model/inference_engine_utils.py +8 -4
  28. snowflake/ml/model/_client/model/model_version_impl.py +109 -28
  29. snowflake/ml/model/_client/ops/model_ops.py +32 -6
  30. snowflake/ml/model/_client/ops/service_ops.py +9 -4
  31. snowflake/ml/model/_client/sql/service.py +69 -2
  32. snowflake/ml/model/_packager/model_handler.py +8 -2
  33. snowflake/ml/model/_packager/model_handlers/{huggingface_pipeline.py → huggingface.py} +203 -76
  34. snowflake/ml/model/_packager/model_handlers/mlflow.py +6 -1
  35. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  36. snowflake/ml/model/_signatures/core.py +305 -8
  37. snowflake/ml/model/_signatures/utils.py +13 -4
  38. snowflake/ml/model/compute_pool.py +2 -0
  39. snowflake/ml/model/models/huggingface.py +285 -0
  40. snowflake/ml/model/models/huggingface_pipeline.py +25 -215
  41. snowflake/ml/model/type_hints.py +5 -1
  42. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  43. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +12 -0
  44. snowflake/ml/monitoring/_manager/model_monitor_manager.py +12 -0
  45. snowflake/ml/monitoring/entities/model_monitor_config.py +5 -0
  46. snowflake/ml/utils/html_utils.py +67 -1
  47. snowflake/ml/version.py +1 -1
  48. {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/METADATA +94 -7
  49. {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/RECORD +52 -48
  50. {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/WHEEL +0 -0
  51. {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/licenses/LICENSE.txt +0 -0
  52. {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/manager.py

@@ -1,27 +1,25 @@
 import json
 import logging
+import os
 import pathlib
 import sys
-import textwrap
 from pathlib import PurePath
 from typing import Any, Callable, Optional, TypeVar, Union, cast, overload
 from uuid import uuid4
 
 import pandas as pd
-import yaml
 
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import identifier
 from snowflake.ml.jobs import job as jb
 from snowflake.ml.jobs._utils import (
+    constants,
     feature_flags,
     payload_utils,
     query_helper,
-    spec_utils,
     types,
 )
-from snowflake.snowpark._internal import utils as sp_utils
 from snowflake.snowpark.context import get_active_session
 from snowflake.snowpark.exceptions import SnowparkSQLException
 from snowflake.snowpark.functions import coalesce, col, lit, when
@@ -259,7 +257,7 @@ def submit_directory(
     dir_path: str,
     compute_pool: str,
     *,
-    entrypoint: str,
+    entrypoint: Union[str, list[str]],
     stage_name: str,
     args: Optional[list[str]] = None,
     target_instances: int = 1,
@@ -274,7 +272,11 @@ def submit_directory(
     Args:
         dir_path: The path to the directory containing the job payload.
         compute_pool: The compute pool to use for the job.
-        entrypoint: The relative path to the entry point script inside the source directory.
+        entrypoint: The entry point for job execution. Can be:
+            - A string path to the entry point script inside the source directory.
+            - A list of strings representing a custom command (e.g., ["arctic_training"])
+              which is passed through as-is without local resolution or validation.
+              This is useful for entrypoints that are installed via pip_requirements.
         stage_name: The name of the stage where the job payload will be uploaded.
         args: A list of arguments to pass to the job.
         target_instances: The number of nodes in the job. If none specified, create a single node job.
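A usage sketch of the two accepted entrypoint forms (the pool, stage, and package names below are placeholders, not values taken from this diff):

```python
from snowflake.ml import jobs

# 1. String entrypoint: a script path resolved relative to the source directory.
job1 = jobs.submit_directory(
    "./train_src",
    "MY_COMPUTE_POOL",
    entrypoint="train.py",
    stage_name="payload_stage",
)

# 2. List entrypoint: a custom command passed through as-is, e.g. a console
#    script installed via pip_requirements (no local resolution or validation).
job2 = jobs.submit_directory(
    "./train_src",
    "MY_COMPUTE_POOL",
    entrypoint=["arctic_training"],
    stage_name="payload_stage",
    pip_requirements=["arctic_training"],
)
```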
@@ -315,7 +317,7 @@ def submit_from_stage(
     source: str,
     compute_pool: str,
     *,
-    entrypoint: str,
+    entrypoint: Union[str, list[str]],
     stage_name: str,
     args: Optional[list[str]] = None,
     target_instances: int = 1,
@@ -330,7 +332,11 @@ def submit_from_stage(
     Args:
         source: a stage path or a stage containing the job payload.
        compute_pool: The compute pool to use for the job.
-        entrypoint: a stage path containing the entry point script inside the source directory.
+        entrypoint: The entry point for job execution. Can be:
+            - A string path to the entry point script inside the source directory.
+            - A list of strings representing a custom command (e.g., ["arctic_training"])
+              which is passed through as-is without local resolution or validation.
+              This is useful for entrypoints that are installed via pip_requirements.
         stage_name: The name of the stage where the job payload will be uploaded.
         args: A list of arguments to pass to the job.
         target_instances: The number of nodes in the job. If none specified, create a single node job.
@@ -375,7 +381,7 @@ def _submit_job(
     compute_pool: str,
     *,
     stage_name: str,
-    entrypoint: Optional[str] = None,
+    entrypoint: Optional[Union[str, list[str]]] = None,
     args: Optional[list[str]] = None,
     target_instances: int = 1,
     pip_requirements: Optional[list[str]] = None,
@@ -392,7 +398,7 @@ def _submit_job(
     compute_pool: str,
     *,
     stage_name: str,
-    entrypoint: Optional[str] = None,
+    entrypoint: Optional[Union[str, list[str]]] = None,
     args: Optional[list[str]] = None,
     target_instances: int = 1,
     pip_requirements: Optional[list[str]] = None,
@@ -424,7 +430,7 @@ def _submit_job(
     compute_pool: str,
     *,
     stage_name: str,
-    entrypoint: Optional[str] = None,
+    entrypoint: Optional[Union[str, list[str]]] = None,
     args: Optional[list[str]] = None,
     target_instances: int = 1,
     session: Optional[snowpark.Session] = None,
@@ -437,7 +443,11 @@ def _submit_job(
         source: The file/directory path containing payload source code or a serializable Python callable.
         compute_pool: The compute pool to use for the job.
         stage_name: The name of the stage where the job payload will be uploaded.
-        entrypoint: The entry point for the job execution. Required if source is a directory.
+        entrypoint: The entry point for the job execution. Can be:
+            - A string path to a Python script (required if source is a directory).
+            - A list of strings representing a custom command (e.g., ["arctic_training"])
+              which is passed through as-is without local resolution or validation.
+              This is useful for entrypoints that are installed via pip_requirements.
         args: A list of arguments to pass to the job.
         target_instances: The number of instances to use for the job. If none specified, single node job is created.
         session: The Snowpark session to use. If none specified, uses active session.
@@ -449,7 +459,6 @@ def _submit_job(
     Raises:
         ValueError: If database or schema value(s) are invalid
         RuntimeError: If schema is not specified in session context or job submission
-        SnowparkSQLException: if failed to upload payload
     """
     session = _ensure_session(session)
 
@@ -481,7 +490,8 @@ def _submit_job(
     enable_metrics = kwargs.pop("enable_metrics", True)
     query_warehouse = kwargs.pop("query_warehouse", session.get_current_warehouse())
     imports = kwargs.pop("imports", None) or imports
-    runtime_environment = kwargs.pop("runtime_environment", None)
+    # if the ML job is submitted from a notebook, we use the same image tag as the notebook
+    runtime_environment = kwargs.pop("runtime_environment", os.environ.get(constants.RUNTIME_IMAGE_TAG_ENV_VAR, None))
 
     # Warn if there are unknown kwargs
     if kwargs:
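The net effect: a job submitted from a Container Runtime notebook now inherits the notebook's runtime image tag by default, while an explicit `runtime_environment` kwarg still wins. A hedged sketch (the tag value is illustrative, and the env var name behind `constants.RUNTIME_IMAGE_TAG_ENV_VAR` is internal):

```python
from snowflake.ml import jobs

# Default: submitted from a notebook, the job reuses the notebook's image tag,
# read from the environment variable behind constants.RUNTIME_IMAGE_TAG_ENV_VAR.
job = jobs.submit_directory(
    "./src",
    "MY_COMPUTE_POOL",
    entrypoint="train.py",
    stage_name="payload_stage",
)

# Explicit override: an image tag or full image URL passed as a kwarg takes
# precedence over the notebook-inherited value.
job = jobs.submit_directory(
    "./src",
    "MY_COMPUTE_POOL",
    entrypoint="train.py",
    stage_name="payload_stage",
    runtime_environment="1.8.0",  # illustrative tag, not from this diff
)
```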
@@ -513,128 +523,51 @@ def _submit_job(
     try:
         # Upload payload
         uploaded_payload = payload_utils.JobPayload(
-            source, entrypoint=entrypoint, pip_requirements=pip_requirements, additional_payloads=imports
+            source, entrypoint=entrypoint, pip_requirements=pip_requirements, imports=imports
         ).upload(session, stage_path)
     except SnowparkSQLException as e:
         if e.sql_error_code == 90106:
             raise RuntimeError(
                 "Please specify a schema, either in the session context or as a parameter in the job submission"
             )
+        elif e.sql_error_code == 3001 and "schema" in str(e).lower():
+            raise RuntimeError(
+                "Please grant privileges on the schema before submitting a job, see "
+                "https://docs.snowflake.com/en/developer-guide/snowflake-ml/ml-jobs/access-control-requirements"
+                " for more details"
+            ) from e
         raise
 
-    if feature_flags.FeatureFlags.USE_SUBMIT_JOB_V2.is_enabled(default=True):
-        # Add default env vars (extracted from spec_utils.generate_service_spec)
-        combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})}
-
-        try:
-            return _do_submit_job_v2(
-                session=session,
-                payload=uploaded_payload,
-                args=args,
-                env_vars=combined_env_vars,
-                spec_overrides=spec_overrides,
-                compute_pool=compute_pool,
-                job_id=job_id,
-                external_access_integrations=external_access_integrations,
-                query_warehouse=query_warehouse,
-                target_instances=target_instances,
-                min_instances=min_instances,
-                enable_metrics=enable_metrics,
-                use_async=True,
-                runtime_environment=runtime_environment,
-            )
-        except SnowparkSQLException as e:
-            if not (e.sql_error_code == 90237 and sp_utils.is_in_stored_procedure()):  # type: ignore[no-untyped-call]
-                raise
-            # SNOW-2390287: SYSTEM$EXECUTE_ML_JOB() is erroneously blocked in owner's rights
-            # stored procedures. This will be fixed in an upcoming release.
-            logger.warning(
-                "Job submission using V2 failed with error {}. Falling back to V1.".format(
-                    str(e).split("\n", 1)[0],
-                )
-            )
-
-    # Fall back to v1
-    # Generate service spec
-    spec = spec_utils.generate_service_spec(
-        session,
-        compute_pool=compute_pool,
-        payload=uploaded_payload,
-        args=args,
-        target_instances=target_instances,
-        min_instances=min_instances,
-        enable_metrics=enable_metrics,
-        runtime_environment=runtime_environment,
-    )
-
-    # Generate spec overrides
-    spec_overrides = spec_utils.generate_spec_overrides(
-        environment_vars=env_vars,
-        custom_overrides=spec_overrides,
-    )
-    if spec_overrides:
-        spec = spec_utils.merge_patch(spec, spec_overrides, display_name="spec_overrides")
-
-    return _do_submit_job_v1(
-        session, spec, external_access_integrations, query_warehouse, target_instances, compute_pool, job_id
-    )
-
-
-def _do_submit_job_v1(
-    session: snowpark.Session,
-    spec: dict[str, Any],
-    external_access_integrations: list[str],
-    query_warehouse: Optional[str],
-    target_instances: int,
-    compute_pool: str,
-    job_id: str,
-) -> jb.MLJob[Any]:
-    """
-    Generate the SQL query for job submission.
-
-    Args:
-        session: The Snowpark session to use.
-        spec: The service spec for the job.
-        external_access_integrations: The external access integrations for the job.
-        query_warehouse: The query warehouse for the job.
-        target_instances: The number of instances for the job.
-        session: The Snowpark session to use.
-        compute_pool: The compute pool to use for the job.
-        job_id: The ID of the job.
-
-    Returns:
-        The job object.
-    """
-    query_template = textwrap.dedent(
-        """\
-        EXECUTE JOB SERVICE
-        IN COMPUTE POOL IDENTIFIER(?)
-        FROM SPECIFICATION $$
-        {}
-        $$
-        NAME = IDENTIFIER(?)
-        ASYNC = TRUE
-        """
-    )
-    params: list[Any] = [compute_pool, job_id]
-    query = query_template.format(yaml.dump(spec)).splitlines()
-    if external_access_integrations:
-        external_access_integration_list = ",".join(f"{e}" for e in external_access_integrations)
-        query.append(f"EXTERNAL_ACCESS_INTEGRATIONS = ({external_access_integration_list})")
-    if query_warehouse:
-        query.append("QUERY_WAREHOUSE = IDENTIFIER(?)")
-        params.append(query_warehouse)
-    if target_instances > 1:
-        query.append("REPLICAS = ?")
-        params.append(target_instances)
-
-    query_text = "\n".join(line for line in query if line)
-    _ = query_helper.run_query(session, query_text, params=params)
+    combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})}
 
-    return get_job(job_id, session=session)
+    try:
+        return _do_submit_job(
+            session=session,
+            payload=uploaded_payload,
+            args=args,
+            env_vars=combined_env_vars,
+            spec_overrides=spec_overrides,
+            compute_pool=compute_pool,
+            job_id=job_id,
+            external_access_integrations=external_access_integrations,
+            query_warehouse=query_warehouse,
+            target_instances=target_instances,
+            min_instances=min_instances,
+            enable_metrics=enable_metrics,
+            use_async=True,
+            runtime_environment=runtime_environment,
+        )
+    except SnowparkSQLException as e:
+        if e.sql_error_code == 3001 and "schema" in str(e).lower():
+            raise RuntimeError(
+                "Please grant privileges on the schema before submitting a job, see "
+                "https://docs.snowflake.com/en/developer-guide/snowflake-ml/ml-jobs/access-control-requirements"
+                " for more details"
+            ) from e
+        raise
 
 
-def _do_submit_job_v2(
+def _do_submit_job(
     session: snowpark.Session,
     payload: types.UploadedPayload,
     args: Optional[list[str]],
@@ -672,9 +605,7 @@ def _do_submit_job_v2(
     Returns:
         The job object.
     """
-    args = [
-        (payload.stage_path.joinpath(v).as_posix() if isinstance(v, PurePath) else v) for v in payload.entrypoint
-    ] + (args or [])
+    args = [(v.as_posix() if isinstance(v, PurePath) else v) for v in payload.entrypoint] + (args or [])
     spec_options = {
         "STAGE_PATH": payload.stage_path.as_posix(),
         "ENTRYPOINT": ["/usr/local/bin/_entrypoint.sh"],
@@ -683,13 +614,14 @@ def _do_submit_job_v2(
         "ENABLE_METRICS": enable_metrics,
         "SPEC_OVERRIDES": spec_overrides,
     }
-    # for the image tag or full image URL, we use that directly
     if runtime_environment:
+        # for the image tag or full image URL, we use that directly
        spec_options["RUNTIME"] = runtime_environment
     elif feature_flags.FeatureFlags.ENABLE_RUNTIME_VERSIONS.is_enabled():
         # when feature flag is enabled, we get the local python version and wrap it in a dict
         # in system function, we can know whether it is python version or image tag or full image URL through the format
         spec_options["RUNTIME"] = json.dumps({"pythonVersion": f"{sys.version_info.major}.{sys.version_info.minor}"})
+
     job_options = {
         "EXTERNAL_ACCESS_INTEGRATIONS": external_access_integrations,
         "QUERY_WAREHOUSE": query_warehouse,
@@ -697,6 +629,9 @@ def _do_submit_job_v2(
         "MIN_INSTANCES": min_instances,
         "ASYNC": use_async,
     }
+
+    if feature_flags.FeatureFlags.ENABLE_STAGE_MOUNT_V2.is_enabled(default=True):
+        spec_options["ENABLE_STAGE_MOUNT_V2"] = True
     if payload.payload_name:
         job_options["GENERATE_SUFFIX"] = True
     job_options = {k: v for k, v in job_options.items() if v is not None}
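The `args` change above is what makes list entrypoints work end to end: `PurePath` elements (script payloads) are rendered as stage-relative POSIX paths, while plain strings (custom commands) pass through untouched. A self-contained sketch of just that expression:

```python
from pathlib import PurePath

def build_args(entrypoint: list, extra_args: list[str]) -> list[str]:
    # Mirrors the list comprehension in _do_submit_job: PurePath entries become
    # POSIX-style relative paths; string entries (custom commands) stay as-is.
    return [(v.as_posix() if isinstance(v, PurePath) else v) for v in entrypoint] + extra_args

print(build_args([PurePath("src/train.py")], ["--epochs", "10"]))
# ['src/train.py', '--epochs', '10']
print(build_args(["arctic_training"], []))
# ['arctic_training']
```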
snowflake/ml/model/_client/model/inference_engine_utils.py

@@ -4,14 +4,18 @@ from snowflake.ml.model._client.ops import service_ops
 
 
 def _get_inference_engine_args(
-    experimental_options: Optional[dict[str, Any]],
+    inference_engine_options: Optional[dict[str, Any]],
 ) -> Optional[service_ops.InferenceEngineArgs]:
-    if not experimental_options or "inference_engine" not in experimental_options:
+
+    if not inference_engine_options:
         return None
 
+    if "engine" not in inference_engine_options:
+        raise ValueError("'engine' field is required in inference_engine_options")
+
     return service_ops.InferenceEngineArgs(
-        inference_engine=experimental_options["inference_engine"],
-        inference_engine_args_override=experimental_options.get("inference_engine_args_override"),
+        inference_engine=inference_engine_options["engine"],
+        inference_engine_args_override=inference_engine_options.get("engine_args_override"),
     )
 
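In short, the option keys were renamed (`inference_engine` → `engine`, `inference_engine_args_override` → `engine_args_override`) and a missing `engine` key is now a hard error rather than a silent no-op. A sketch of the new contract (the `"vllm"` engine name and override flags are illustrative, not values confirmed by this diff):

```python
# Valid: "engine" is present; "engine_args_override" is optional.
inference_engine_options = {
    "engine": "vllm",                                     # illustrative engine name
    "engine_args_override": ["--max-model-len", "4096"],  # illustrative flags
}

# Invalid: omitting the required "engine" key now raises
#   ValueError: 'engine' field is required in inference_engine_options
# instead of silently returning None as the old experimental_options path did.
bad_options = {"engine_args_override": ["--foo"]}
```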
snowflake/ml/model/_client/model/model_version_impl.py

@@ -8,11 +8,11 @@ from typing import Any, Callable, Optional, Union, overload
 import pandas as pd
 
 from snowflake import snowpark
-from snowflake.ml import jobs
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import sql_identifier
+from snowflake.ml.jobs import job
 from snowflake.ml.lineage import lineage_node
-from snowflake.ml.model import task, type_hints
+from snowflake.ml.model import openai_signatures, task, type_hints
 from snowflake.ml.model._client.model import (
     batch_inference_specs,
     inference_engine_utils,
@@ -23,6 +23,7 @@ from snowflake.ml.model._model_composer.model_manifest import model_manifest_sch
 from snowflake.ml.model._model_composer.model_method import utils as model_method_utils
 from snowflake.ml.model._packager.model_handlers import snowmlmodel
 from snowflake.ml.model._packager.model_meta import model_meta_schema
+from snowflake.ml.model._signatures import core
 from snowflake.snowpark import Session, async_job, dataframe
 
 _TELEMETRY_PROJECT = "MLOps"
@@ -45,6 +46,7 @@ class ModelVersion(lineage_node.LineageNode):
     _version_name: sql_identifier.SqlIdentifier
     _functions: list[model_manifest_schema.ModelFunctionInfo]
     _model_spec: Optional[model_meta_schema.ModelMetadataDict]
+    _model_manifest: Optional[model_manifest_schema.ModelManifestDict]
 
     def __init__(self) -> None:
         raise RuntimeError("ModelVersion's initializer is not meant to be used. Use `version` from model instead.")
@@ -155,6 +157,7 @@ class ModelVersion(lineage_node.LineageNode):
         self._version_name = version_name
         self._functions = self._get_functions()
         self._model_spec = None
+        self._model_manifest = None
         super(cls, cls).__init__(
             self,
             session=model_ops._session,
@@ -462,6 +465,28 @@ class ModelVersion(lineage_node.LineageNode):
         )
         return self._model_spec
 
+    def _get_model_manifest(
+        self, statement_params: Optional[dict[str, Any]] = None
+    ) -> model_manifest_schema.ModelManifestDict:
+        """Fetch and cache the model manifest for this model version.
+
+        Args:
+            statement_params: Optional dictionary of statement parameters to include
+                in the SQL command to fetch the model manifest.
+
+        Returns:
+            The model manifest as a dictionary for this model version.
+        """
+        if self._model_manifest is None:
+            self._model_manifest = self._model_ops.get_model_version_manifest(
+                database_name=None,
+                schema_name=None,
+                model_name=self._model_name,
+                version_name=self._version_name,
+                statement_params=statement_params,
+            )
+        return self._model_manifest
+
     @overload
     def run(
         self,
@@ -530,6 +555,9 @@ class ModelVersion(lineage_node.LineageNode):
 
         Returns:
             The prediction data. It would be the same type dataframe as your input.
+
+        Raises:
+            ValueError: When the model does not support running on warehouse and no service name is provided.
         """
         statement_params = telemetry.get_statement_params(
             project=_TELEMETRY_PROJECT,
@@ -556,6 +584,27 @@ class ModelVersion(lineage_node.LineageNode):
                 statement_params=statement_params,
             )
         else:
+            manifest = self._get_model_manifest(statement_params=statement_params)
+            target_platforms = manifest.get("target_platforms", None)
+            if (
+                target_platforms is not None
+                and len(target_platforms) > 0
+                and type_hints.TargetPlatform.WAREHOUSE.value not in target_platforms
+            ):
+                raise ValueError(
+                    f"The model {self.fully_qualified_model_name} version {self.version_name} "
+                    "is not logged for inference in Warehouse. "
+                    "To run the model in Warehouse, please log the model again using `log_model` API with "
+                    '`target_platforms=["WAREHOUSE"]` or '
+                    '`target_platforms=["WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"]` and rerun the command. '
+                    "To run the model in Snowpark Container Services, the `service_name` argument must be provided. "
+                    "You can create a service using the `create_service` API. "
+                    "For inference in Warehouse, see https://docs.snowflake.com/en/developer-guide/"
+                    "snowflake-ml/model-registry/warehouse#inference-from-python. "
+                    "For inference in Snowpark Container Services, see https://docs.snowflake.com/en/developer-guide/"
+                    "snowflake-ml/model-registry/container#python."
+                )
+
             explain_case_sensitive = self._determine_explain_case_sensitivity(target_function_info, statement_params)
 
             return self._model_ops.invoke_method(
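This check means warehouse inference now fails fast when a model version was logged for SPCS only. A hedged sketch of logging so the warehouse path stays available (the registry handle, model object, and input dataframe are placeholders):

```python
from snowflake.ml.registry import Registry

reg = Registry(session=session)  # assumes an existing Snowpark session
mv = reg.log_model(
    my_model,  # placeholder model object
    model_name="MY_MODEL",
    version_name="V1",
    # Include WAREHOUSE so mv.run(...) without service_name keeps working.
    target_platforms=["WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"],
)
mv.run(input_df, function_name="PREDICT")  # warehouse path; no service_name
```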
@@ -602,7 +651,7 @@ class ModelVersion(lineage_node.LineageNode):
         input_spec: dataframe.DataFrame,
         output_spec: batch_inference_specs.OutputSpec,
         job_spec: Optional[batch_inference_specs.JobSpec] = None,
-    ) -> jobs.MLJob[Any]:
+    ) -> job.MLJob[Any]:
         """Execute batch inference on datasets as an SPCS job.
 
         Args:
@@ -617,7 +666,7 @@ class ModelVersion(lineage_node.LineageNode):
                 If None, default values will be used.
 
         Returns:
-            jobs.MLJob[Any]: A batch inference job object that can be used to monitor progress and manage the job
+            job.MLJob[Any]: A batch inference job object that can be used to monitor progress and manage the job
                 lifecycle.
 
         Raises:
@@ -940,14 +989,16 @@ class ModelVersion(lineage_node.LineageNode):
         self,
         statement_params: Optional[dict[str, Any]] = None,
     ) -> None:
-        """Check if the model is a HuggingFace pipeline with text-generation task.
+        """Check if the model is a HuggingFace pipeline with text-generation task
+        and is logged with OPENAI_CHAT_SIGNATURE.
 
         Args:
             statement_params: Optional dictionary of statement parameters to include
                 in the SQL command to fetch model spec.
 
         Raises:
-            ValueError: If the model is not a HuggingFace text-generation model.
+            ValueError: If the model is not a HuggingFace text-generation model or
+                if the model is not logged with OPENAI_CHAT_SIGNATURE.
         """
         # Fetch model spec
         model_spec = self._get_model_spec(statement_params)
@@ -983,6 +1034,21 @@ class ModelVersion(lineage_node.LineageNode):
             )
             raise ValueError(f"Inference engine is only supported for task 'text-generation'. {found_tasks_str}")
 
+        # Check if the model is logged with OPENAI_CHAT_SIGNATURE
+        signatures_dict = model_spec.get("signatures", {})
+
+        # Deserialize signatures from model spec to ModelSignature objects for proper semantic comparison.
+        deserialized_signatures = {
+            func_name: core.ModelSignature.from_dict(sig_dict) for func_name, sig_dict in signatures_dict.items()
+        }
+
+        if deserialized_signatures != openai_signatures.OPENAI_CHAT_SIGNATURE:
+            raise ValueError(
+                "Inference engine requires the model to be logged with OPENAI_CHAT_SIGNATURE. "
+                f"Found signatures: {signatures_dict}. "
+                "Please log the model with: signatures=openai_signatures.OPENAI_CHAT_SIGNATURE"
+            )
+
     @overload
     def create_service(
         self,
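Practically, a model destined for a custom inference engine must be logged with the OpenAI-style chat signature up front, as the error text above instructs. A minimal sketch (the pipeline object and registry handle are placeholders):

```python
from snowflake.ml.model import openai_signatures

mv = reg.log_model(
    text_generation_pipeline,  # placeholder HuggingFace text-generation pipeline
    model_name="MY_LLM",
    version_name="V1",
    signatures=openai_signatures.OPENAI_CHAT_SIGNATURE,
)
```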
@@ -1001,6 +1067,8 @@ class ModelVersion(lineage_node.LineageNode):
         force_rebuild: bool = False,
         build_external_access_integration: Optional[str] = None,
         block: bool = True,
+        autocapture: bool = False,
+        inference_engine_options: Optional[dict[str, Any]] = None,
         experimental_options: Optional[dict[str, Any]] = None,
     ) -> Union[str, async_job.AsyncJob]:
         """Create an inference service with the given spec.
@@ -1034,11 +1102,13 @@ class ModelVersion(lineage_node.LineageNode):
             block: A bool value indicating whether this function will wait until the service is available.
                 When it is ``False``, this function executes the underlying service creation asynchronously
                 and returns an :class:`AsyncJob`.
-            experimental_options: Experimental options for the service creation with custom inference engine.
-                Currently, `inference_engine`, `inference_engine_args_override`, and `autocapture` are supported.
-                `inference_engine` is the name of the inference engine to use.
-                `inference_engine_args_override` is a list of string arguments to pass to the inference engine.
-                `autocapture` is a boolean to enable/disable inference table.
+            autocapture: Whether inference autocapture is enabled on the service. If true, inference data will be
+                captured in the model inference table.
+            inference_engine_options: Options for the service creation with custom inference engine.
+                Supports `engine` and `engine_args_override`.
+                `engine` is the type of the inference engine to use.
+                `engine_args_override` is a list of string arguments to pass to the inference engine.
+            experimental_options: Experimental options for the service creation.
         """
         ...
 
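Putting the new keywords together, a hedged `create_service` sketch (service, pool, and repo names are placeholders; the engine name is illustrative):

```python
service = mv.create_service(
    service_name="MY_LLM_SERVICE",
    service_compute_pool="MY_GPU_POOL",
    image_repo="MY_IMAGE_REPO",
    gpu_requests="1",
    autocapture=True,  # first-class flag, no longer under experimental_options
    inference_engine_options={
        "engine": "vllm",  # illustrative engine name
        "engine_args_override": ["--max-model-len", "4096"],
    },
)
```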
@@ -1060,6 +1130,8 @@ class ModelVersion(lineage_node.LineageNode):
         force_rebuild: bool = False,
         build_external_access_integrations: Optional[list[str]] = None,
         block: bool = True,
+        autocapture: bool = False,
+        inference_engine_options: Optional[dict[str, Any]] = None,
         experimental_options: Optional[dict[str, Any]] = None,
     ) -> Union[str, async_job.AsyncJob]:
         """Create an inference service with the given spec.
@@ -1093,11 +1165,13 @@ class ModelVersion(lineage_node.LineageNode):
             block: A bool value indicating whether this function will wait until the service is available.
                 When it is ``False``, this function executes the underlying service creation asynchronously
                 and returns an :class:`AsyncJob`.
-            experimental_options: Experimental options for the service creation with custom inference engine.
-                Currently, `inference_engine`, `inference_engine_args_override`, and `autocapture` are supported.
-                `inference_engine` is the name of the inference engine to use.
-                `inference_engine_args_override` is a list of string arguments to pass to the inference engine.
-                `autocapture` is a boolean to enable/disable inference table.
+            autocapture: Whether inference autocapture is enabled on the service. If true, inference data will be
+                captured in the model inference table.
+            inference_engine_options: Options for the service creation with custom inference engine.
+                Supports `engine` and `engine_args_override`.
+                `engine` is the type of the inference engine to use.
+                `engine_args_override` is a list of string arguments to pass to the inference engine.
+            experimental_options: Experimental options for the service creation.
         """
         ...
 
@@ -1134,6 +1208,8 @@ class ModelVersion(lineage_node.LineageNode):
         build_external_access_integration: Optional[str] = None,
         build_external_access_integrations: Optional[list[str]] = None,
         block: bool = True,
+        autocapture: bool = False,
+        inference_engine_options: Optional[dict[str, Any]] = None,
         experimental_options: Optional[dict[str, Any]] = None,
     ) -> Union[str, async_job.AsyncJob]:
         """Create an inference service with the given spec.
@@ -1169,11 +1245,13 @@ class ModelVersion(lineage_node.LineageNode):
             block: A bool value indicating whether this function will wait until the service is available.
                 When it is False, this function executes the underlying service creation asynchronously
                 and returns an AsyncJob.
-            experimental_options: Experimental options for the service creation with custom inference engine.
-                Currently, `inference_engine`, `inference_engine_args_override`, and `autocapture` are supported.
-                `inference_engine` is the name of the inference engine to use.
-                `inference_engine_args_override` is a list of string arguments to pass to the inference engine.
-                `autocapture` is a boolean to enable/disable inference table.
+            autocapture: Whether inference autocapture is enabled on the service. If true, inference data will be
+                captured in the model inference table.
+            inference_engine_options: Options for the service creation with custom inference engine.
+                Supports `engine` and `engine_args_override`.
+                `engine` is the type of the inference engine to use.
+                `engine_args_override` is a list of string arguments to pass to the inference engine.
+            experimental_options: Experimental options for the service creation.
 
 
         Raises:
@@ -1209,9 +1287,10 @@ class ModelVersion(lineage_node.LineageNode):
         # Validate GPU support if GPU resources are requested
         self._throw_error_if_gpu_is_not_supported(gpu_requests, statement_params)
 
-        inference_engine_args = inference_engine_utils._get_inference_engine_args(experimental_options)
+        inference_engine_args = inference_engine_utils._get_inference_engine_args(inference_engine_options)
 
-        # Check if model is HuggingFace text-generation before doing inference engine checks
+        # Check if model is HuggingFace text-generation and is logged with
+        # OPENAI_CHAT_SIGNATURE before doing inference engine checks
         # Only validate if inference engine is actually specified
         if inference_engine_args is not None:
             self._check_huggingface_text_generation_model(statement_params)
@@ -1223,9 +1302,6 @@ class ModelVersion(lineage_node.LineageNode):
             gpu_requests,
         )
 
-        # Extract autocapture from experimental_options
-        autocapture = experimental_options.get("autocapture") if experimental_options else None
-
         from snowflake.ml.model import event_handler
         from snowflake.snowpark import exceptions
 
@@ -1292,8 +1368,13 @@ class ModelVersion(lineage_node.LineageNode):
         """List all the service names using this model version.
 
         Returns:
-            List of service_names: The name of the service, can be fully qualified. If not fully qualified, the database
-                or schema of the model will be used.
+            List of details about all the services associated with this model version. The details include:
+                name: The name of the service.
+                status: The status of the service.
+                inference_endpoint: The public endpoint of the service, if enabled and the service is not PENDING.
+                    This will be the privatelink endpoint if the session is created with a privatelink connection.
+                internal_endpoint: The internal endpoint of the service, if the service is not in PENDING state.
+                autocapture_enabled: Whether the service has autocapture enabled, if set in the service proxy spec.
         """
         statement_params = telemetry.get_statement_params(
             project=_TELEMETRY_PROJECT,
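A consumption sketch for the richer return value (this assumes the returned rows are dict-like; the exact container type is not shown in this diff):

```python
for svc in mv.list_services():
    # Each row now carries status and endpoint details, not just a name.
    print(svc["name"], svc["status"], svc.get("inference_endpoint"), svc.get("autocapture_enabled"))
```

If `list_services()` instead returns a DataFrame, iterate with `.iterrows()` and index columns by the same field names.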