snowflake-ml-python 1.22.0__py3-none-any.whl → 1.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. snowflake/ml/_internal/platform_capabilities.py +0 -4
  2. snowflake/ml/feature_store/__init__.py +2 -0
  3. snowflake/ml/feature_store/aggregation.py +367 -0
  4. snowflake/ml/feature_store/feature.py +366 -0
  5. snowflake/ml/feature_store/feature_store.py +234 -20
  6. snowflake/ml/feature_store/feature_view.py +189 -4
  7. snowflake/ml/feature_store/metadata_manager.py +425 -0
  8. snowflake/ml/feature_store/tile_sql_generator.py +1079 -0
  9. snowflake/ml/jobs/__init__.py +2 -0
  10. snowflake/ml/jobs/_utils/constants.py +1 -0
  11. snowflake/ml/jobs/_utils/payload_utils.py +38 -18
  12. snowflake/ml/jobs/_utils/query_helper.py +8 -1
  13. snowflake/ml/jobs/_utils/runtime_env_utils.py +117 -0
  14. snowflake/ml/jobs/_utils/stage_utils.py +2 -2
  15. snowflake/ml/jobs/_utils/types.py +22 -2
  16. snowflake/ml/jobs/job_definition.py +232 -0
  17. snowflake/ml/jobs/manager.py +16 -177
  18. snowflake/ml/model/__init__.py +4 -0
  19. snowflake/ml/model/_client/model/batch_inference_specs.py +38 -2
  20. snowflake/ml/model/_client/model/model_version_impl.py +120 -89
  21. snowflake/ml/model/_client/ops/model_ops.py +4 -26
  22. snowflake/ml/model/_client/ops/param_utils.py +124 -0
  23. snowflake/ml/model/_client/ops/service_ops.py +63 -23
  24. snowflake/ml/model/_client/service/model_deployment_spec.py +12 -5
  25. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
  26. snowflake/ml/model/_client/sql/service.py +25 -54
  27. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +21 -3
  28. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +21 -3
  29. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +21 -3
  30. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -1
  31. snowflake/ml/model/_packager/model_handlers/huggingface.py +74 -10
  32. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +121 -29
  33. snowflake/ml/model/_signatures/utils.py +130 -0
  34. snowflake/ml/model/openai_signatures.py +97 -0
  35. snowflake/ml/registry/_manager/model_parameter_reconciler.py +1 -1
  36. snowflake/ml/version.py +1 -1
  37. {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/METADATA +105 -1
  38. {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/RECORD +41 -35
  39. {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/WHEEL +1 -1
  40. snowflake/ml/experiment/callback/__init__.py +0 -0
  41. {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/licenses/LICENSE.txt +0 -0
  42. {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/manager.py
@@ -1,11 +1,5 @@
-import json
 import logging
-import os
-import pathlib
-import sys
-from pathlib import PurePath
 from typing import Any, Callable, Optional, TypeVar, Union, cast, overload
-from uuid import uuid4
 
 import pandas as pd
 
@@ -13,13 +7,8 @@ from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import identifier
 from snowflake.ml.jobs import job as jb
-from snowflake.ml.jobs._utils import (
-    constants,
-    feature_flags,
-    payload_utils,
-    query_helper,
-    types,
-)
+from snowflake.ml.jobs._utils import query_helper
+from snowflake.ml.jobs.job_definition import MLJobDefinition
 from snowflake.snowpark.context import get_active_session
 from snowflake.snowpark.exceptions import SnowparkSQLException
 from snowflake.snowpark.functions import coalesce, col, lit, when
@@ -457,7 +446,6 @@ def _submit_job(
         An object representing the submitted job.
 
     Raises:
-        ValueError: If database or schema value(s) are invalid
         RuntimeError: If schema is not specified in session context or job submission
     """
     session = _ensure_session(session)
@@ -469,94 +457,30 @@ def _submit_job(
     )
     target_instances = max(target_instances, kwargs.pop("num_instances"))
 
-    imports = None
     if "additional_payloads" in kwargs:
         logger.warning(
             "'additional_payloads' is deprecated and will be removed in a future release. Use 'imports' instead."
         )
-        imports = kwargs.pop("additional_payloads")
+        if "imports" not in kwargs:
+            imports = kwargs.pop("additional_payloads", None)
+            kwargs.update({"imports": imports})
 
     if "runtime_environment" in kwargs:
         logger.warning("'runtime_environment' is in private preview since 1.15.0, do not use it in production.")
 
-    # Use kwargs for less common optional parameters
-    database = kwargs.pop("database", None)
-    schema = kwargs.pop("schema", None)
-    min_instances = kwargs.pop("min_instances", target_instances)
-    pip_requirements = kwargs.pop("pip_requirements", None)
-    external_access_integrations = kwargs.pop("external_access_integrations", None)
-    env_vars = kwargs.pop("env_vars", None)
-    spec_overrides = kwargs.pop("spec_overrides", None)
-    enable_metrics = kwargs.pop("enable_metrics", True)
-    query_warehouse = kwargs.pop("query_warehouse", session.get_current_warehouse())
-    imports = kwargs.pop("imports", None) or imports
-    # if the mljob is submitted from a notebook, we use the same image tag as the notebook
-    runtime_environment = kwargs.pop("runtime_environment", os.environ.get(constants.RUNTIME_IMAGE_TAG_ENV_VAR, None))
-
-    # Warn if there are unknown kwargs
-    if kwargs:
-        logger.warning(f"Ignoring unknown kwargs: {kwargs.keys()}")
-
-    # Validate parameters
-    if database and not schema:
-        raise ValueError("Schema must be specified if database is specified.")
-    if target_instances < 1:
-        raise ValueError("target_instances must be greater than 0.")
-    if not (0 < min_instances <= target_instances):
-        raise ValueError("min_instances must be greater than 0 and less than or equal to target_instances.")
-    if min_instances > 1:
-        # Validate min_instances against compute pool max_nodes
-        pool_info = jb._get_compute_pool_info(session, compute_pool)
-        max_nodes = int(pool_info["max_nodes"])
-        if min_instances > max_nodes:
-            raise ValueError(
-                f"The requested min_instances ({min_instances}) exceeds the max_nodes ({max_nodes}) "
-                f"of compute pool '{compute_pool}'. Reduce min_instances or increase max_nodes."
-            )
-
-    job_name = f"{JOB_ID_PREFIX}{str(uuid4()).replace('-', '_').upper()}"
-    job_id = identifier.get_schema_level_object_identifier(database, schema, job_name)
-    stage_path_parts = identifier.parse_snowflake_stage_path(stage_name.lstrip("@"))
-    stage_name = f"@{'.'.join(filter(None, stage_path_parts[:3]))}"
-    stage_path = pathlib.PurePosixPath(f"{stage_name}{stage_path_parts[-1].rstrip('/')}/{job_name}")
-
-    try:
-        # Upload payload
-        uploaded_payload = payload_utils.JobPayload(
-            source, entrypoint=entrypoint, pip_requirements=pip_requirements, imports=imports
-        ).upload(session, stage_path)
-    except SnowparkSQLException as e:
-        if e.sql_error_code == 90106:
-            raise RuntimeError(
-                "Please specify a schema, either in the session context or as a parameter in the job submission"
-            )
-        elif e.sql_error_code == 3001 and "schema" in str(e).lower():
-            raise RuntimeError(
-                "please grant privileges on schema before submitting a job, see",
-                "https://docs.snowflake.com/en/developer-guide/snowflake-ml/ml-jobs/access-control-requirements",
-                " for more details",
-            ) from e
-        raise
-
-    combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})}
+    job_definition = MLJobDefinition.register(
+        source,
+        compute_pool,
+        stage_name,
+        session or get_active_session(),
+        entrypoint,
+        target_instances,
+        generate_suffix=True,
+        **kwargs,
+    )
 
     try:
-        return _do_submit_job(
-            session=session,
-            payload=uploaded_payload,
-            args=args,
-            env_vars=combined_env_vars,
-            spec_overrides=spec_overrides,
-            compute_pool=compute_pool,
-            job_id=job_id,
-            external_access_integrations=external_access_integrations,
-            query_warehouse=query_warehouse,
-            target_instances=target_instances,
-            min_instances=min_instances,
-            enable_metrics=enable_metrics,
-            use_async=True,
-            runtime_environment=runtime_environment,
-        )
+        return job_definition(*(args or []))
     except SnowparkSQLException as e:
         if e.sql_error_code == 3001 and "schema" in str(e).lower():
             raise RuntimeError(
@@ -567,91 +491,6 @@ def _submit_job(
         raise
 
 
-def _do_submit_job(
-    session: snowpark.Session,
-    payload: types.UploadedPayload,
-    args: Optional[list[str]],
-    env_vars: dict[str, str],
-    spec_overrides: dict[str, Any],
-    compute_pool: str,
-    job_id: Optional[str] = None,
-    external_access_integrations: Optional[list[str]] = None,
-    query_warehouse: Optional[str] = None,
-    target_instances: int = 1,
-    min_instances: int = 1,
-    enable_metrics: bool = True,
-    use_async: bool = True,
-    runtime_environment: Optional[str] = None,
-) -> jb.MLJob[Any]:
-    """
-    Generate the SQL query for job submission.
-
-    Args:
-        session: The Snowpark session to use.
-        payload: The uploaded job payload.
-        args: Arguments to pass to the entrypoint script.
-        env_vars: Environment variables to set in the job container.
-        spec_overrides: Custom service specification overrides.
-        compute_pool: The compute pool to use for job execution.
-        job_id: The ID of the job.
-        external_access_integrations: Optional list of external access integrations.
-        query_warehouse: Optional query warehouse to use.
-        target_instances: Number of instances for multi-node job.
-        min_instances: Minimum number of instances required to start the job.
-        enable_metrics: Whether to enable platform metrics for the job.
-        use_async: Whether to run the job asynchronously.
-        runtime_environment: image tag or full image URL to use for the job.
-
-    Returns:
-        The job object.
-    """
-    args = [(v.as_posix() if isinstance(v, PurePath) else v) for v in payload.entrypoint] + (args or [])
-    spec_options = {
-        "STAGE_PATH": payload.stage_path.as_posix(),
-        "ENTRYPOINT": ["/usr/local/bin/_entrypoint.sh"],
-        "ARGS": args,
-        "ENV_VARS": env_vars,
-        "ENABLE_METRICS": enable_metrics,
-        "SPEC_OVERRIDES": spec_overrides,
-    }
-    if runtime_environment:
-        # for the image tag or full image URL, we use that directly
-        spec_options["RUNTIME"] = runtime_environment
-    elif feature_flags.FeatureFlags.ENABLE_RUNTIME_VERSIONS.is_enabled():
-        # when feature flag is enabled, we get the local python version and wrap it in a dict
-        # in system function, we can know whether it is python version or image tag or full image URL through the format
-        spec_options["RUNTIME"] = json.dumps({"pythonVersion": f"{sys.version_info.major}.{sys.version_info.minor}"})
-
-    job_options = {
-        "EXTERNAL_ACCESS_INTEGRATIONS": external_access_integrations,
-        "QUERY_WAREHOUSE": query_warehouse,
-        "TARGET_INSTANCES": target_instances,
-        "MIN_INSTANCES": min_instances,
-        "ASYNC": use_async,
-    }
-
-    if feature_flags.FeatureFlags.ENABLE_STAGE_MOUNT_V2.is_enabled(default=True):
-        spec_options["ENABLE_STAGE_MOUNT_V2"] = True
-    if payload.payload_name:
-        job_options["GENERATE_SUFFIX"] = True
-    job_options = {k: v for k, v in job_options.items() if v is not None}
-
-    query_template = "CALL SYSTEM$EXECUTE_ML_JOB(?, ?, ?, ?)"
-    if job_id:
-        database, schema, _ = identifier.parse_schema_level_object_identifier(job_id)
-    params = [
-        job_id
-        if payload.payload_name is None
-        else identifier.get_schema_level_object_identifier(database, schema, payload.payload_name) + "_",
-        compute_pool,
-        json.dumps(spec_options),
-        json.dumps(job_options),
-    ]
-    actual_job_id = query_helper.run_query(session, query_template, params=params)[0][0]
-
-    return get_job(actual_job_id, session=session)
-
-
 def _ensure_session(session: Optional[snowpark.Session]) -> snowpark.Session:
     try:
         session = session or get_active_session()
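
The manager.py hunks above collapse the inline payload upload, validation, and SYSTEM$EXECUTE_ML_JOB call into a single registration on the new MLJobDefinition class. Below is a minimal sketch of the resulting submission flow, using only the calls visible in this diff; the session handling and the pool, stage, and entrypoint values are illustrative placeholders rather than part of the release.

# Illustrative sketch only: mirrors the refactored _submit_job body shown above.
# The pool, stage, and entrypoint values are placeholders; the actual keyword
# surface of MLJobDefinition.register is defined in snowflake/ml/jobs/job_definition.py.
from snowflake.ml.jobs.job_definition import MLJobDefinition
from snowflake.snowpark.context import get_active_session

session = get_active_session()  # assumes an active Snowpark session

job_definition = MLJobDefinition.register(
    "./src",              # source payload (placeholder)
    "MY_COMPUTE_POOL",    # compute pool (placeholder)
    "@MY_STAGE/ml_jobs",  # stage for the uploaded payload (placeholder)
    session,
    "main.py",            # entrypoint (placeholder)
    1,                    # target_instances
    generate_suffix=True,
)
job = job_definition()    # calling the definition submits the job, as in `return job_definition(*(args or []))`
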
snowflake/ml/model/__init__.py
@@ -4,6 +4,8 @@ import warnings
 from snowflake.ml.model._client.model.batch_inference_specs import (
     ColumnHandlingOptions,
     FileEncoding,
+    InputFormat,
+    InputSpec,
     JobSpec,
     OutputSpec,
     SaveMode,
@@ -20,6 +22,8 @@ __all__ = [
     "ModelVersion",
     "ExportMode",
     "HuggingFacePipelineModel",
+    "InputSpec",
+    "InputFormat",
     "JobSpec",
     "OutputSpec",
     "SaveMode",
snowflake/ml/model/_client/model/batch_inference_specs.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Optional
+from typing import Any, Optional
 
 from pydantic import BaseModel
 from typing_extensions import TypedDict
@@ -19,6 +19,12 @@ class SaveMode(str, Enum):
     ERROR = "error"
 
 
+class InputFormat(str, Enum):
+    """The format of the input column data."""
+
+    FULL_STAGE_PATH = "full_stage_path"
+
+
 class FileEncoding(str, Enum):
     """The encoding of the file content that will be passed to the custom model."""
 
@@ -30,7 +36,37 @@ class FileEncoding(str, Enum):
 class ColumnHandlingOptions(TypedDict):
     """Options for handling specific columns during run_batch for file I/O."""
 
-    encoding: FileEncoding
+    input_format: InputFormat
+    convert_to: FileEncoding
+
+
+class InputSpec(BaseModel):
+    """Specification for batch inference input options.
+
+    Defines optional configuration for processing input data during batch inference.
+
+    Attributes:
+        params (Optional[dict[str, Any]]): Optional dictionary of model inference parameters
+            (e.g., temperature, top_k for LLMs). These are passed as keyword arguments to the
+            model's inference method. Defaults to None.
+        column_handling (Optional[dict[str, ColumnHandlingOptions]]): Optional dictionary
+            specifying how to handle specific columns during file I/O. Maps column names to their
+            input format and file encoding configuration.
+
+    Example:
+        >>> input_spec = InputSpec(
+        ...     params={"temperature": 0.7, "top_k": 50},
+        ...     column_handling={
+        ...         "image_col": {
+        ...             "input_format": InputFormat.FULL_STAGE_PATH,
+        ...             "convert_to": FileEncoding.BASE64
+        ...         }
+        ...     }
+        ... )
+    """
+
+    params: Optional[dict[str, Any]] = None
+    column_handling: Optional[dict[str, ColumnHandlingOptions]] = None
 
 
 class OutputSpec(BaseModel):
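
The batch_inference_specs.py hunks replace the single encoding key of ColumnHandlingOptions with an input_format/convert_to pair and introduce InputSpec to carry inference params together with column handling. A hedged before/after sketch, built only from the removed field and the docstring example above; the image_col column name and the params values are illustrative.

from snowflake.ml.model import FileEncoding, InputFormat, InputSpec

# 1.22.x shape of a ColumnHandlingOptions value (the "encoding" field is removed in this diff):
old_options = {"encoding": FileEncoding.BASE64}

# 1.24.x shape: name the input format and the encoding to convert to, and pass it
# together with optional inference params through the new InputSpec model
# (values taken from the docstring example above).
input_spec = InputSpec(
    params={"temperature": 0.7, "top_k": 50},
    column_handling={
        "image_col": {
            "input_format": InputFormat.FULL_STAGE_PATH,
            "convert_to": FileEncoding.BASE64,
        }
    },
)
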