snowflake-ml-python 1.20.0__py3-none-any.whl → 1.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. snowflake/ml/_internal/platform_capabilities.py +36 -0
  2. snowflake/ml/_internal/utils/url.py +42 -0
  3. snowflake/ml/data/_internal/arrow_ingestor.py +67 -2
  4. snowflake/ml/data/data_connector.py +103 -1
  5. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +8 -2
  6. snowflake/ml/experiment/callback/__init__.py +0 -0
  7. snowflake/ml/experiment/callback/keras.py +25 -2
  8. snowflake/ml/experiment/callback/lightgbm.py +27 -2
  9. snowflake/ml/experiment/callback/xgboost.py +25 -2
  10. snowflake/ml/experiment/experiment_tracking.py +93 -3
  11. snowflake/ml/experiment/utils.py +6 -0
  12. snowflake/ml/feature_store/feature_view.py +34 -24
  13. snowflake/ml/jobs/_interop/protocols.py +3 -0
  14. snowflake/ml/jobs/_utils/constants.py +1 -0
  15. snowflake/ml/jobs/_utils/payload_utils.py +354 -356
  16. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +95 -8
  17. snowflake/ml/jobs/_utils/scripts/start_mlruntime.sh +92 -0
  18. snowflake/ml/jobs/_utils/scripts/startup.sh +112 -0
  19. snowflake/ml/jobs/_utils/spec_utils.py +1 -445
  20. snowflake/ml/jobs/_utils/stage_utils.py +22 -1
  21. snowflake/ml/jobs/_utils/types.py +14 -7
  22. snowflake/ml/jobs/job.py +2 -8
  23. snowflake/ml/jobs/manager.py +57 -135
  24. snowflake/ml/lineage/lineage_node.py +1 -1
  25. snowflake/ml/model/__init__.py +6 -0
  26. snowflake/ml/model/_client/model/batch_inference_specs.py +16 -1
  27. snowflake/ml/model/_client/model/model_version_impl.py +130 -14
  28. snowflake/ml/model/_client/ops/deployment_step.py +36 -0
  29. snowflake/ml/model/_client/ops/model_ops.py +93 -8
  30. snowflake/ml/model/_client/ops/service_ops.py +32 -52
  31. snowflake/ml/model/_client/service/import_model_spec_schema.py +23 -0
  32. snowflake/ml/model/_client/service/model_deployment_spec.py +12 -4
  33. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -0
  34. snowflake/ml/model/_client/sql/model_version.py +30 -6
  35. snowflake/ml/model/_client/sql/service.py +94 -5
  36. snowflake/ml/model/_model_composer/model_composer.py +1 -1
  37. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -0
  38. snowflake/ml/model/_model_composer/model_method/model_method.py +61 -2
  39. snowflake/ml/model/_packager/model_handler.py +8 -2
  40. snowflake/ml/model/_packager/model_handlers/custom.py +52 -0
  41. snowflake/ml/model/_packager/model_handlers/{huggingface_pipeline.py → huggingface.py} +203 -76
  42. snowflake/ml/model/_packager/model_handlers/mlflow.py +6 -1
  43. snowflake/ml/model/_packager/model_handlers/xgboost.py +26 -1
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +40 -7
  45. snowflake/ml/model/_packager/model_packager.py +1 -1
  46. snowflake/ml/model/_signatures/core.py +390 -8
  47. snowflake/ml/model/_signatures/utils.py +13 -4
  48. snowflake/ml/model/code_path.py +104 -0
  49. snowflake/ml/model/compute_pool.py +2 -0
  50. snowflake/ml/model/custom_model.py +55 -13
  51. snowflake/ml/model/model_signature.py +13 -1
  52. snowflake/ml/model/models/huggingface.py +285 -0
  53. snowflake/ml/model/models/huggingface_pipeline.py +19 -208
  54. snowflake/ml/model/type_hints.py +7 -1
  55. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  56. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +12 -0
  57. snowflake/ml/monitoring/_manager/model_monitor_manager.py +12 -0
  58. snowflake/ml/monitoring/entities/model_monitor_config.py +5 -0
  59. snowflake/ml/registry/_manager/model_manager.py +230 -15
  60. snowflake/ml/registry/registry.py +4 -4
  61. snowflake/ml/utils/html_utils.py +67 -1
  62. snowflake/ml/version.py +1 -1
  63. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/METADATA +81 -7
  64. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/RECORD +67 -59
  65. snowflake/ml/jobs/_utils/runtime_env_utils.py +0 -63
  66. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/WHEEL +0 -0
  67. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/licenses/LICENSE.txt +0 -0
  68. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/manager.py
@@ -1,27 +1,25 @@
 import json
 import logging
+import os
 import pathlib
 import sys
-import textwrap
 from pathlib import PurePath
 from typing import Any, Callable, Optional, TypeVar, Union, cast, overload
 from uuid import uuid4

 import pandas as pd
-import yaml

 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import identifier
 from snowflake.ml.jobs import job as jb
 from snowflake.ml.jobs._utils import (
+    constants,
     feature_flags,
     payload_utils,
     query_helper,
-    spec_utils,
     types,
 )
-from snowflake.snowpark._internal import utils as sp_utils
 from snowflake.snowpark.context import get_active_session
 from snowflake.snowpark.exceptions import SnowparkSQLException
 from snowflake.snowpark.functions import coalesce, col, lit, when
@@ -259,7 +257,7 @@ def submit_directory(
     dir_path: str,
     compute_pool: str,
     *,
-    entrypoint: str,
+    entrypoint: Union[str, list[str]],
     stage_name: str,
     args: Optional[list[str]] = None,
     target_instances: int = 1,
@@ -274,7 +272,11 @@ def submit_directory(
     Args:
         dir_path: The path to the directory containing the job payload.
         compute_pool: The compute pool to use for the job.
-        entrypoint: The relative path to the entry point script inside the source directory.
+        entrypoint: The entry point for job execution. Can be:
+            - A string path to the entry point script inside the source directory.
+            - A list of strings representing a custom command (e.g., ["arctic_training"])
+              which is passed through as-is without local resolution or validation.
+              This is useful for entrypoints that are installed via pip_requirements.
         stage_name: The name of the stage where the job payload will be uploaded.
         args: A list of arguments to pass to the job.
         target_instances: The number of nodes in the job. If none specified, create a single node job.
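For context, a minimal usage sketch of the list-valued entrypoint documented above. All names (directory, compute pool, stage, command) are illustrative, and the full set of keyword arguments accepted by submit_directory is not shown in this diff; this also assumes the manager-level helpers remain re-exported from snowflake.ml.jobs as in earlier releases.

from snowflake.ml import jobs

# Hypothetical example: run a console command that is installed via pip_requirements.
# A list entrypoint is passed through as-is, so "arctic_training" only needs to exist
# inside the job container, not on the local machine.
job = jobs.submit_directory(
    "./my_training_project",         # local payload directory (illustrative)
    "MY_COMPUTE_POOL",               # compute pool name (illustrative)
    entrypoint=["arctic_training"],  # custom command, no local resolution or validation
    stage_name="payload_stage",      # stage for the uploaded payload (illustrative)
    args=["--config", "config.yaml"],
    pip_requirements=["arctic_training"],
)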
@@ -315,7 +317,7 @@ def submit_from_stage(
     source: str,
     compute_pool: str,
     *,
-    entrypoint: str,
+    entrypoint: Union[str, list[str]],
     stage_name: str,
     args: Optional[list[str]] = None,
     target_instances: int = 1,
@@ -330,7 +332,11 @@ def submit_from_stage(
     Args:
         source: a stage path or a stage containing the job payload.
        compute_pool: The compute pool to use for the job.
-        entrypoint: a stage path containing the entry point script inside the source directory.
+        entrypoint: The entry point for job execution. Can be:
+            - A string path to the entry point script inside the source directory.
+            - A list of strings representing a custom command (e.g., ["arctic_training"])
+              which is passed through as-is without local resolution or validation.
+              This is useful for entrypoints that are installed via pip_requirements.
         stage_name: The name of the stage where the job payload will be uploaded.
         args: A list of arguments to pass to the job.
         target_instances: The number of nodes in the job. If none specified, create a single node job.
@@ -375,7 +381,7 @@ def _submit_job(
     compute_pool: str,
     *,
     stage_name: str,
-    entrypoint: Optional[str] = None,
+    entrypoint: Optional[Union[str, list[str]]] = None,
     args: Optional[list[str]] = None,
     target_instances: int = 1,
     pip_requirements: Optional[list[str]] = None,
@@ -392,7 +398,7 @@ def _submit_job(
     compute_pool: str,
     *,
     stage_name: str,
-    entrypoint: Optional[str] = None,
+    entrypoint: Optional[Union[str, list[str]]] = None,
     args: Optional[list[str]] = None,
     target_instances: int = 1,
     pip_requirements: Optional[list[str]] = None,
@@ -424,7 +430,7 @@ def _submit_job(
     compute_pool: str,
     *,
     stage_name: str,
-    entrypoint: Optional[str] = None,
+    entrypoint: Optional[Union[str, list[str]]] = None,
     args: Optional[list[str]] = None,
     target_instances: int = 1,
     session: Optional[snowpark.Session] = None,
@@ -437,7 +443,11 @@ def _submit_job(
         source: The file/directory path containing payload source code or a serializable Python callable.
         compute_pool: The compute pool to use for the job.
         stage_name: The name of the stage where the job payload will be uploaded.
-        entrypoint: The entry point for the job execution. Required if source is a directory.
+        entrypoint: The entry point for the job execution. Can be:
+            - A string path to a Python script (required if source is a directory).
+            - A list of strings representing a custom command (e.g., ["arctic_training"])
+              which is passed through as-is without local resolution or validation.
+              This is useful for entrypoints that are installed via pip_requirements.
         args: A list of arguments to pass to the job.
         target_instances: The number of instances to use for the job. If none specified, single node job is created.
         session: The Snowpark session to use. If none specified, uses active session.
@@ -449,7 +459,6 @@ def _submit_job(
     Raises:
         ValueError: If database or schema value(s) are invalid
         RuntimeError: If schema is not specified in session context or job submission
-        SnowparkSQLException: if failed to upload payload
     """
     session = _ensure_session(session)

@@ -481,7 +490,8 @@ def _submit_job(
     enable_metrics = kwargs.pop("enable_metrics", True)
     query_warehouse = kwargs.pop("query_warehouse", session.get_current_warehouse())
     imports = kwargs.pop("imports", None) or imports
-    runtime_environment = kwargs.pop("runtime_environment", None)
+    # if the mljob is submitted from a notebook, we use the same image tag as the notebook
+    runtime_environment = kwargs.pop("runtime_environment", os.environ.get(constants.RUNTIME_IMAGE_TAG_ENV_VAR, None))

     # Warn if there are unknown kwargs
     if kwargs:
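In other words, a job submitted from a Container Runtime notebook now defaults to the notebook's own runtime image tag instead of None, while an explicitly passed runtime_environment still wins. A minimal sketch of that resolution order; the helper name below is hypothetical and only constants.RUNTIME_IMAGE_TAG_ENV_VAR comes from this diff.

import os
from typing import Optional

from snowflake.ml.jobs._utils import constants


def _resolve_runtime_environment(explicit: Optional[str]) -> Optional[str]:
    # An explicit runtime_environment argument always takes precedence; otherwise
    # reuse the image tag exported by the notebook environment, if it is set at all.
    if explicit is not None:
        return explicit
    return os.environ.get(constants.RUNTIME_IMAGE_TAG_ENV_VAR)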
@@ -513,7 +523,7 @@ def _submit_job(
     try:
         # Upload payload
         uploaded_payload = payload_utils.JobPayload(
-            source, entrypoint=entrypoint, pip_requirements=pip_requirements, additional_payloads=imports
+            source, entrypoint=entrypoint, pip_requirements=pip_requirements, imports=imports
         ).upload(session, stage_path)
     except SnowparkSQLException as e:
         if e.sql_error_code == 90106:
@@ -528,125 +538,36 @@ def _submit_job(
             ) from e
         raise

-    if feature_flags.FeatureFlags.USE_SUBMIT_JOB_V2.is_enabled(default=True):
-        # Add default env vars (extracted from spec_utils.generate_service_spec)
-        combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})}
+    combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})}

-        try:
-            return _do_submit_job_v2(
-                session=session,
-                payload=uploaded_payload,
-                args=args,
-                env_vars=combined_env_vars,
-                spec_overrides=spec_overrides,
-                compute_pool=compute_pool,
-                job_id=job_id,
-                external_access_integrations=external_access_integrations,
-                query_warehouse=query_warehouse,
-                target_instances=target_instances,
-                min_instances=min_instances,
-                enable_metrics=enable_metrics,
-                use_async=True,
-                runtime_environment=runtime_environment,
-            )
-        except SnowparkSQLException as e:
-            if not (e.sql_error_code == 90237 and sp_utils.is_in_stored_procedure()):  # type: ignore[no-untyped-call]
-                raise
-            elif e.sql_error_code == 3001 and "schema" in str(e).lower():
-                raise RuntimeError(
-                    "please grant privileges on schema before submitting a job, see",
-                    "https://docs.snowflake.com/en/developer-guide/snowflake-ml/ml-jobs/access-control-requirements"
-                    " for more details",
-                ) from e
-            # SNOW-2390287: SYSTEM$EXECUTE_ML_JOB() is erroneously blocked in owner's rights
-            # stored procedures. This will be fixed in an upcoming release.
-            logger.warning(
-                "Job submission using V2 failed with error {}. Falling back to V1.".format(
-                    str(e).split("\n", 1)[0],
-                )
-            )
-
-    # Fall back to v1
-    # Generate service spec
-    spec = spec_utils.generate_service_spec(
-        session,
-        compute_pool=compute_pool,
-        payload=uploaded_payload,
-        args=args,
-        target_instances=target_instances,
-        min_instances=min_instances,
-        enable_metrics=enable_metrics,
-        runtime_environment=runtime_environment,
-    )
-
-    # Generate spec overrides
-    spec_overrides = spec_utils.generate_spec_overrides(
-        environment_vars=env_vars,
-        custom_overrides=spec_overrides,
-    )
-    if spec_overrides:
-        spec = spec_utils.merge_patch(spec, spec_overrides, display_name="spec_overrides")
-
-    return _do_submit_job_v1(
-        session, spec, external_access_integrations, query_warehouse, target_instances, compute_pool, job_id
-    )
-
-
-def _do_submit_job_v1(
-    session: snowpark.Session,
-    spec: dict[str, Any],
-    external_access_integrations: list[str],
-    query_warehouse: Optional[str],
-    target_instances: int,
-    compute_pool: str,
-    job_id: str,
-) -> jb.MLJob[Any]:
-    """
-    Generate the SQL query for job submission.
-
-    Args:
-        session: The Snowpark session to use.
-        spec: The service spec for the job.
-        external_access_integrations: The external access integrations for the job.
-        query_warehouse: The query warehouse for the job.
-        target_instances: The number of instances for the job.
-        session: The Snowpark session to use.
-        compute_pool: The compute pool to use for the job.
-        job_id: The ID of the job.
-
-    Returns:
-        The job object.
-    """
-    query_template = textwrap.dedent(
-        """\
-        EXECUTE JOB SERVICE
-        IN COMPUTE POOL IDENTIFIER(?)
-        FROM SPECIFICATION $$
-        {}
-        $$
-        NAME = IDENTIFIER(?)
-        ASYNC = TRUE
-        """
-    )
-    params: list[Any] = [compute_pool, job_id]
-    query = query_template.format(yaml.dump(spec)).splitlines()
-    if external_access_integrations:
-        external_access_integration_list = ",".join(f"{e}" for e in external_access_integrations)
-        query.append(f"EXTERNAL_ACCESS_INTEGRATIONS = ({external_access_integration_list})")
-    if query_warehouse:
-        query.append("QUERY_WAREHOUSE = IDENTIFIER(?)")
-        params.append(query_warehouse)
-    if target_instances > 1:
-        query.append("REPLICAS = ?")
-        params.append(target_instances)
-
-    query_text = "\n".join(line for line in query if line)
-    _ = query_helper.run_query(session, query_text, params=params)
-
-    return get_job(job_id, session=session)
+    try:
+        return _do_submit_job(
+            session=session,
+            payload=uploaded_payload,
+            args=args,
+            env_vars=combined_env_vars,
+            spec_overrides=spec_overrides,
+            compute_pool=compute_pool,
+            job_id=job_id,
+            external_access_integrations=external_access_integrations,
+            query_warehouse=query_warehouse,
+            target_instances=target_instances,
+            min_instances=min_instances,
+            enable_metrics=enable_metrics,
+            use_async=True,
+            runtime_environment=runtime_environment,
+        )
+    except SnowparkSQLException as e:
+        if e.sql_error_code == 3001 and "schema" in str(e).lower():
+            raise RuntimeError(
+                "please grant privileges on schema before submitting a job, see",
+                "https://docs.snowflake.com/en/developer-guide/snowflake-ml/ml-jobs/access-control-requirements"
+                " for more details",
+            ) from e
+        raise


-def _do_submit_job_v2(
+def _do_submit_job(
     session: snowpark.Session,
     payload: types.UploadedPayload,
     args: Optional[list[str]],
@@ -684,9 +605,7 @@ def _do_submit_job_v2(
     Returns:
         The job object.
     """
-    args = [
-        (payload.stage_path.joinpath(v).as_posix() if isinstance(v, PurePath) else v) for v in payload.entrypoint
-    ] + (args or [])
+    args = [(v.as_posix() if isinstance(v, PurePath) else v) for v in payload.entrypoint] + (args or [])
     spec_options = {
         "STAGE_PATH": payload.stage_path.as_posix(),
         "ENTRYPOINT": ["/usr/local/bin/_entrypoint.sh"],
@@ -695,8 +614,8 @@
         "ENABLE_METRICS": enable_metrics,
         "SPEC_OVERRIDES": spec_overrides,
     }
-    # for the image tag or full image URL, we use that directly
     if runtime_environment:
+        # for the image tag or full image URL, we use that directly
         spec_options["RUNTIME"] = runtime_environment
     elif feature_flags.FeatureFlags.ENABLE_RUNTIME_VERSIONS.is_enabled():
         # when feature flag is enabled, we get the local python version and wrap it in a dict
@@ -710,6 +629,9 @@
         "MIN_INSTANCES": min_instances,
         "ASYNC": use_async,
     }
+
+    if feature_flags.FeatureFlags.ENABLE_STAGE_MOUNT_V2.is_enabled(default=True):
+        spec_options["ENABLE_STAGE_MOUNT_V2"] = True
     if payload.payload_name:
         job_options["GENERATE_SUFFIX"] = True
     job_options = {k: v for k, v in job_options.items() if v is not None}
snowflake/ml/lineage/lineage_node.py
@@ -7,7 +7,7 @@ from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import identifier, mixins

 if TYPE_CHECKING:
-    from snowflake.ml import dataset
+    from snowflake.ml.dataset import dataset
     from snowflake.ml.feature_store import feature_view
     from snowflake.ml.model._client.model import model_version_impl

snowflake/ml/model/__init__.py
@@ -2,16 +2,20 @@ import sys
 import warnings

 from snowflake.ml.model._client.model.batch_inference_specs import (
+    ColumnHandlingOptions,
+    FileEncoding,
     JobSpec,
     OutputSpec,
     SaveMode,
 )
 from snowflake.ml.model._client.model.model_impl import Model
 from snowflake.ml.model._client.model.model_version_impl import ExportMode, ModelVersion
+from snowflake.ml.model.code_path import CodePath
 from snowflake.ml.model.models.huggingface_pipeline import HuggingFacePipelineModel
 from snowflake.ml.model.volatility import Volatility

 __all__ = [
+    "CodePath",
     "Model",
     "ModelVersion",
     "ExportMode",
@@ -20,6 +24,8 @@ __all__ = [
     "OutputSpec",
     "SaveMode",
     "Volatility",
+    "FileEncoding",
+    "ColumnHandlingOptions",
 ]

 _deprecation_warning_msg_for_3_9 = (
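With the export list above, the new helpers should be importable from the package root alongside the existing names; a minimal sketch:

# New in this release: CodePath, FileEncoding, and ColumnHandlingOptions are
# exported from snowflake.ml.model next to the existing JobSpec / OutputSpec.
from snowflake.ml.model import (
    CodePath,
    ColumnHandlingOptions,
    FileEncoding,
    JobSpec,
    OutputSpec,
)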
snowflake/ml/model/_client/model/batch_inference_specs.py
@@ -2,6 +2,7 @@ from enum import Enum
 from typing import Optional

 from pydantic import BaseModel
+from typing_extensions import TypedDict


 class SaveMode(str, Enum):
@@ -18,6 +19,20 @@ class SaveMode(str, Enum):
     ERROR = "error"


+class FileEncoding(str, Enum):
+    """The encoding of the file content that will be passed to the custom model."""
+
+    RAW_BYTES = "raw_bytes"
+    BASE64 = "base64"
+    BASE64_DATA_URL = "base64_data_url"
+
+
+class ColumnHandlingOptions(TypedDict):
+    """Options for handling specific columns during run_batch for file I/O."""
+
+    encoding: FileEncoding
+
+
 class OutputSpec(BaseModel):
     """Specification for batch inference output.

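A short sketch of how these two definitions fit together, based only on what is shown above. The run_batch parameter that ultimately consumes the mapping is not part of this hunk, so the column_handling name below is an assumption.

from snowflake.ml.model import ColumnHandlingOptions, FileEncoding

# Each handled column gets an options dict whose "encoding" field controls how
# the file content is presented to the custom model.
image_options: ColumnHandlingOptions = {"encoding": FileEncoding.BASE64_DATA_URL}
document_options: ColumnHandlingOptions = {"encoding": FileEncoding.RAW_BYTES}

# Hypothetical column-name -> options mapping passed to batch inference.
column_handling = {
    "IMAGE_FILE": image_options,
    "DOCUMENT_FILE": document_options,
}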
@@ -74,7 +89,7 @@ class JobSpec(BaseModel):
            the memory of the node.
        gpu_requests (Optional[str]): The gpu limit for GPU based inference. Can be integer or
            string values. Use CPU if None.
-       replicas (Optional[int]): Number of job replicas to run for high availability.
+       replicas (Optional[int]): Number of SPCS job nodes used for distributed inference.
            If not specified, defaults to 1 replica.

    Example:
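To illustrate the reworded replicas semantics, a hedged JobSpec sketch using only the attributes documented in this hunk; other fields and their defaults are not shown here.

from snowflake.ml.model import JobSpec

# replicas now counts SPCS job nodes used for distributed batch inference
# (default 1), rather than high-availability copies of the same job.
spec = JobSpec(
    gpu_requests="1",  # integer-like or string values; CPU is used when None
    replicas=2,        # spread the batch inference job across two SPCS nodes
)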