snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.1__py3-none-any.whl

This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
Files changed (74)
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +14 -0
  5. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  6. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  7. snowflake/ml/data/data_connector.py +59 -6
  8. snowflake/ml/data/data_ingestor.py +18 -1
  9. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  10. snowflake/ml/data/torch_dataset.py +33 -0
  11. snowflake/ml/dataset/dataset_metadata.py +3 -1
  12. snowflake/ml/dataset/dataset_reader.py +9 -3
  13. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  14. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  15. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  16. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  17. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  18. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  19. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  20. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  21. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  22. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  23. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  24. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  25. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  26. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  27. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  29. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  30. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  31. snowflake/ml/feature_store/feature_store.py +59 -24
  32. snowflake/ml/feature_store/feature_view.py +148 -4
  33. snowflake/ml/model/_client/model/model_impl.py +11 -2
  34. snowflake/ml/model/_client/model/model_version_impl.py +171 -20
  35. snowflake/ml/model/_client/ops/model_ops.py +105 -27
  36. snowflake/ml/model/_client/ops/service_ops.py +121 -0
  37. snowflake/ml/model/_client/service/model_deployment_spec.py +95 -0
  38. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +31 -0
  39. snowflake/ml/model/_client/sql/model_version.py +13 -4
  40. snowflake/ml/model/_client/sql/service.py +129 -0
  41. snowflake/ml/model/_model_composer/model_composer.py +3 -0
  42. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +10 -2
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  44. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  45. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +19 -12
  47. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  48. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +9 -5
  49. snowflake/ml/model/_packager/model_handlers/lightgbm.py +27 -18
  50. snowflake/ml/model/_packager/model_handlers/llm.py +7 -3
  51. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  52. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  53. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  54. snowflake/ml/model/_packager/model_handlers/sklearn.py +87 -4
  55. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  57. snowflake/ml/model/_packager/model_handlers/torchscript.py +8 -3
  58. snowflake/ml/model/_packager/model_handlers/xgboost.py +25 -16
  59. snowflake/ml/model/_packager/model_meta/model_meta.py +32 -2
  60. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +19 -0
  61. snowflake/ml/model/_packager/model_packager.py +2 -1
  62. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  63. snowflake/ml/model/type_hints.py +1 -3
  64. snowflake/ml/modeling/framework/base.py +28 -19
  65. snowflake/ml/modeling/pipeline/pipeline.py +3 -0
  66. snowflake/ml/registry/_manager/model_manager.py +16 -2
  67. snowflake/ml/utils/sql_client.py +22 -0
  68. snowflake/ml/version.py +1 -1
  69. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/METADATA +35 -2
  70. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/RECORD +73 -62
  71. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  72. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/LICENSE.txt +0 -0
  73. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/WHEEL +0 -0
  74. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,7 @@ import enum
2
2
  import pathlib
3
3
  import tempfile
4
4
  import warnings
5
- from typing import Any, Callable, Dict, List, Optional, Union
5
+ from typing import Any, Callable, Dict, List, Optional, Union, overload
6
6
 
7
7
  import pandas as pd
8
8
 
@@ -10,7 +10,7 @@ from snowflake.ml._internal import telemetry
10
10
  from snowflake.ml._internal.utils import sql_identifier
11
11
  from snowflake.ml.lineage import lineage_node
12
12
  from snowflake.ml.model import type_hints as model_types
13
- from snowflake.ml.model._client.ops import metadata_ops, model_ops
13
+ from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
14
14
  from snowflake.ml.model._model_composer import model_composer
15
15
  from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
16
16
  from snowflake.ml.model._packager.model_handlers import snowmlmodel
@@ -29,6 +29,7 @@ class ModelVersion(lineage_node.LineageNode):
29
29
  """Model Version Object representing a specific version of the model that could be run."""
30
30
 
31
31
  _model_ops: model_ops.ModelOperator
32
+ _service_ops: service_ops.ServiceOperator
32
33
  _model_name: sql_identifier.SqlIdentifier
33
34
  _version_name: sql_identifier.SqlIdentifier
34
35
  _functions: List[model_manifest_schema.ModelFunctionInfo]
@@ -41,11 +42,13 @@ class ModelVersion(lineage_node.LineageNode):
41
42
  cls,
42
43
  model_ops: model_ops.ModelOperator,
43
44
  *,
45
+ service_ops: service_ops.ServiceOperator,
44
46
  model_name: sql_identifier.SqlIdentifier,
45
47
  version_name: sql_identifier.SqlIdentifier,
46
48
  ) -> "ModelVersion":
47
49
  self: "ModelVersion" = object.__new__(cls)
48
50
  self._model_ops = model_ops
51
+ self._service_ops = service_ops
49
52
  self._model_name = model_name
50
53
  self._version_name = version_name
51
54
  self._functions = self._get_functions()
@@ -65,6 +68,7 @@ class ModelVersion(lineage_node.LineageNode):
65
68
  return False
66
69
  return (
67
70
  self._model_ops == __value._model_ops
71
+ and self._service_ops == __value._service_ops
68
72
  and self._model_name == __value._model_name
69
73
  and self._version_name == __value._version_name
70
74
  )
@@ -318,10 +322,7 @@ class ModelVersion(lineage_node.LineageNode):
318
322
  """
319
323
  return self._functions
320
324
 
321
- @telemetry.send_api_usage_telemetry(
322
- project=_TELEMETRY_PROJECT,
323
- subproject=_TELEMETRY_SUBPROJECT,
324
- )
325
+ @overload
325
326
  def run(
326
327
  self,
327
328
  X: Union[pd.DataFrame, dataframe.DataFrame],
@@ -339,6 +340,53 @@ class ModelVersion(lineage_node.LineageNode):
339
340
  partition_column: The partition column name to partition by.
340
341
  strict_input_validation: Enable stricter validation for the input data. This will result value range based
341
342
  type validation to make sure your input data won't overflow when providing to the model.
343
+ """
344
+ ...
345
+
346
+ @overload
347
+ def run(
348
+ self,
349
+ X: Union[pd.DataFrame, dataframe.DataFrame],
350
+ *,
351
+ service_name: str,
352
+ function_name: Optional[str] = None,
353
+ strict_input_validation: bool = False,
354
+ ) -> Union[pd.DataFrame, dataframe.DataFrame]:
355
+ """Invoke a method in a model version object via a service.
356
+
357
+ Args:
358
+ X: The input data, which could be a pandas DataFrame or Snowpark DataFrame.
359
+ service_name: The service name.
360
+ function_name: The function name to run. It is the name used to call a function in SQL.
361
+ strict_input_validation: Enable stricter validation for the input data. This will result value range based
362
+ type validation to make sure your input data won't overflow when providing to the model.
363
+ """
364
+ ...
365
+
366
+ @telemetry.send_api_usage_telemetry(
367
+ project=_TELEMETRY_PROJECT,
368
+ subproject=_TELEMETRY_SUBPROJECT,
369
+ func_params_to_log=["function_name", "service_name"],
370
+ )
371
+ def run(
372
+ self,
373
+ X: Union[pd.DataFrame, "dataframe.DataFrame"],
374
+ *,
375
+ service_name: Optional[str] = None,
376
+ function_name: Optional[str] = None,
377
+ partition_column: Optional[str] = None,
378
+ strict_input_validation: bool = False,
379
+ ) -> Union[pd.DataFrame, "dataframe.DataFrame"]:
380
+ """Invoke a method in a model version object via the warehouse or a service.
381
+
382
+ Args:
383
+ X: The input data, which could be a pandas DataFrame or Snowpark DataFrame.
384
+ service_name: The service name. If None, the function is invoked via the warehouse. Otherwise, the function
385
+ is invoked via the given service.
386
+ function_name: The function name to run. It is the name used to call a function in SQL.
387
+ partition_column: The partition column name to partition by.
388
+ strict_input_validation: Enable stricter validation for the input data. This will result value range based
389
+ type validation to make sure your input data won't overflow when providing to the model.
342
390
 
343
391
  Raises:
344
392
  ValueError: When no method with the corresponding name is available.
@@ -375,23 +423,37 @@ class ModelVersion(lineage_node.LineageNode):
375
423
  elif len(functions) != 1:
376
424
  raise ValueError(
377
425
  f"There are more than 1 target methods available in the model {self.fully_qualified_model_name}"
378
- f" version {self.version_name}. Please specify a `method_name` when calling the `run` method."
426
+ f" version {self.version_name}. Please specify a `function_name` when calling the `run` method."
379
427
  )
380
428
  else:
381
429
  target_function_info = functions[0]
382
- return self._model_ops.invoke_method(
383
- method_name=sql_identifier.SqlIdentifier(target_function_info["name"]),
384
- method_function_type=target_function_info["target_method_function_type"],
385
- signature=target_function_info["signature"],
386
- X=X,
387
- database_name=None,
388
- schema_name=None,
389
- model_name=self._model_name,
390
- version_name=self._version_name,
391
- strict_input_validation=strict_input_validation,
392
- partition_column=partition_column,
393
- statement_params=statement_params,
394
- )
430
+
431
+ if service_name:
432
+ return self._model_ops.invoke_method(
433
+ method_name=sql_identifier.SqlIdentifier(target_function_info["name"]),
434
+ signature=target_function_info["signature"],
435
+ X=X,
436
+ database_name=None,
437
+ schema_name=None,
438
+ service_name=sql_identifier.SqlIdentifier(service_name),
439
+ strict_input_validation=strict_input_validation,
440
+ statement_params=statement_params,
441
+ )
442
+ else:
443
+ return self._model_ops.invoke_method(
444
+ method_name=sql_identifier.SqlIdentifier(target_function_info["name"]),
445
+ method_function_type=target_function_info["target_method_function_type"],
446
+ signature=target_function_info["signature"],
447
+ X=X,
448
+ database_name=None,
449
+ schema_name=None,
450
+ model_name=self._model_name,
451
+ version_name=self._version_name,
452
+ strict_input_validation=strict_input_validation,
453
+ partition_column=partition_column,
454
+ statement_params=statement_params,
455
+ is_partitioned=target_function_info["is_partitioned"],
456
+ )
395
457
 
396
458
  @telemetry.send_api_usage_telemetry(
397
459
  project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, func_params_to_log=["export_mode"]
@@ -525,9 +587,98 @@ class ModelVersion(lineage_node.LineageNode):
525
587
  database_name=database_name_id,
526
588
  schema_name=schema_name_id,
527
589
  ),
590
+ service_ops=service_ops.ServiceOperator(
591
+ session,
592
+ database_name=database_name_id,
593
+ schema_name=schema_name_id,
594
+ ),
528
595
  model_name=model_name_id,
529
596
  version_name=sql_identifier.SqlIdentifier(version),
530
597
  )
531
598
 
599
+ @telemetry.send_api_usage_telemetry(
600
+ project=_TELEMETRY_PROJECT,
601
+ subproject=_TELEMETRY_SUBPROJECT,
602
+ func_params_to_log=[
603
+ "service_name",
604
+ "image_build_compute_pool",
605
+ "service_compute_pool",
606
+ "image_repo_database",
607
+ "image_repo_schema",
608
+ "image_repo",
609
+ "image_name",
610
+ "gpu_requests",
611
+ ],
612
+ )
613
+ def create_service(
614
+ self,
615
+ *,
616
+ service_name: str,
617
+ image_build_compute_pool: Optional[str] = None,
618
+ service_compute_pool: str,
619
+ image_repo: str,
620
+ image_name: Optional[str] = None,
621
+ ingress_enabled: bool = False,
622
+ min_instances: int = 1,
623
+ max_instances: int = 1,
624
+ gpu_requests: Optional[str] = None,
625
+ force_rebuild: bool = False,
626
+ build_external_access_integration: str,
627
+ ) -> str:
628
+ """Create an inference service with the given spec.
629
+
630
+ Args:
631
+ service_name: The name of the service, can be fully qualified. If not fully qualified, the database or
632
+ schema of the model will be used.
633
+ image_build_compute_pool: The name of the compute pool used to build the model inference image. Use
634
+ the service compute pool if None.
635
+ service_compute_pool: The name of the compute pool used to run the inference service.
636
+ image_repo: The name of the image repository, can be fully qualified. If not fully qualified, the database
637
+ or schema of the model will be used.
638
+ image_name: The name of the model inference image. Use a generated name if None.
639
+ ingress_enabled: Whether to enable ingress.
640
+ min_instances: The minimum number of inference service instances to run.
641
+ max_instances: The maximum number of inference service instances to run.
642
+ gpu_requests: The gpu limit for GPU based inference. Can be integer, fractional or string values. Use CPU
643
+ if None.
644
+ force_rebuild: Whether to force a model inference image rebuild.
645
+ build_external_access_integration: The external access integration for image build.
646
+
647
+ Returns:
648
+ The service name.
649
+ """
650
+ statement_params = telemetry.get_statement_params(
651
+ project=_TELEMETRY_PROJECT,
652
+ subproject=_TELEMETRY_SUBPROJECT,
653
+ )
654
+ service_db_id, service_schema_id, service_id = sql_identifier.parse_fully_qualified_name(service_name)
655
+ image_repo_db_id, image_repo_schema_id, image_repo_id = sql_identifier.parse_fully_qualified_name(image_repo)
656
+ return self._service_ops.create_service(
657
+ database_name=None,
658
+ schema_name=None,
659
+ model_name=self._model_name,
660
+ version_name=self._version_name,
661
+ service_database_name=service_db_id,
662
+ service_schema_name=service_schema_id,
663
+ service_name=service_id,
664
+ image_build_compute_pool_name=(
665
+ sql_identifier.SqlIdentifier(image_build_compute_pool)
666
+ if image_build_compute_pool
667
+ else sql_identifier.SqlIdentifier(service_compute_pool)
668
+ ),
669
+ service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
670
+ image_repo_database_name=image_repo_db_id,
671
+ image_repo_schema_name=image_repo_schema_id,
672
+ image_repo_name=image_repo_id,
673
+ image_name=sql_identifier.SqlIdentifier(image_name) if image_name else None,
674
+ ingress_enabled=ingress_enabled,
675
+ min_instances=min_instances,
676
+ max_instances=max_instances,
677
+ gpu_requests=gpu_requests,
678
+ force_rebuild=force_rebuild,
679
+ build_external_access_integration=sql_identifier.SqlIdentifier(build_external_access_integration),
680
+ statement_params=statement_params,
681
+ )
682
+
532
683
 
533
684
  lineage_node.DOMAIN_LINEAGE_REGISTRY["model"] = ModelVersion
@@ -2,7 +2,7 @@ import os
2
2
  import pathlib
3
3
  import tempfile
4
4
  import warnings
5
- from typing import Any, Dict, List, Literal, Optional, Union, cast
5
+ from typing import Any, Dict, List, Literal, Optional, Union, cast, overload
6
6
 
7
7
  import yaml
8
8
 
@@ -12,6 +12,7 @@ from snowflake.ml.model._client.ops import metadata_ops
12
12
  from snowflake.ml.model._client.sql import (
13
13
  model as model_sql,
14
14
  model_version as model_version_sql,
15
+ service as service_sql,
15
16
  stage as stage_sql,
16
17
  tag as tag_sql,
17
18
  )
@@ -21,7 +22,7 @@ from snowflake.ml.model._model_composer.model_manifest import (
21
22
  model_manifest_schema,
22
23
  )
23
24
  from snowflake.ml.model._packager.model_env import model_env
24
- from snowflake.ml.model._packager.model_meta import model_meta
25
+ from snowflake.ml.model._packager.model_meta import model_meta, model_meta_schema
25
26
  from snowflake.ml.model._packager.model_runtime import model_runtime
26
27
  from snowflake.ml.model._signatures import snowpark_handler
27
28
  from snowflake.snowpark import dataframe, row, session
@@ -60,6 +61,11 @@ class ModelOperator:
60
61
  database_name=database_name,
61
62
  schema_name=schema_name,
62
63
  )
64
+ self._service_client = service_sql.ServiceSQLClient(
65
+ session,
66
+ database_name=database_name,
67
+ schema_name=schema_name,
68
+ )
63
69
  self._metadata_ops = metadata_ops.MetadataOperator(
64
70
  session,
65
71
  database_name=database_name,
@@ -597,16 +603,38 @@ class ModelOperator:
597
603
  function_names, list(signatures.keys())
598
604
  )
599
605
 
600
- return [
601
- model_manifest_schema.ModelFunctionInfo(
602
- name=function_name.identifier(),
603
- target_method=function_name_mapping[function_name],
604
- target_method_function_type=function_type,
605
- signature=model_signature.ModelSignature.from_dict(signatures[function_name_mapping[function_name]]),
606
+ model_func_info = []
607
+
608
+ for function_name, function_type in function_names_and_types:
609
+
610
+ target_method = function_name_mapping[function_name]
611
+
612
+ is_partitioned = False
613
+ if function_type == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
614
+ # better to set default True here because worse case it will be slow but not error out
615
+ is_partitioned = (
616
+ (
617
+ model_spec["function_properties"]
618
+ .get(target_method, {})
619
+ .get(model_meta_schema.FunctionProperties.PARTITIONED.value, True)
620
+ )
621
+ if "function_properties" in model_spec
622
+ else True
623
+ )
624
+
625
+ model_func_info.append(
626
+ model_manifest_schema.ModelFunctionInfo(
627
+ name=function_name.identifier(),
628
+ target_method=target_method,
629
+ target_method_function_type=function_type,
630
+ signature=model_signature.ModelSignature.from_dict(signatures[target_method]),
631
+ is_partitioned=is_partitioned,
632
+ )
606
633
  )
607
- for function_name, function_type in function_names_and_types
608
- ]
609
634
 
635
+ return model_func_info
636
+
637
+ @overload
610
638
  def invoke_method(
611
639
  self,
612
640
  *,
@@ -621,6 +649,41 @@ class ModelOperator:
621
649
  strict_input_validation: bool = False,
622
650
  partition_column: Optional[sql_identifier.SqlIdentifier] = None,
623
651
  statement_params: Optional[Dict[str, str]] = None,
652
+ is_partitioned: Optional[bool] = None,
653
+ ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
654
+ ...
655
+
656
+ @overload
657
+ def invoke_method(
658
+ self,
659
+ *,
660
+ method_name: sql_identifier.SqlIdentifier,
661
+ signature: model_signature.ModelSignature,
662
+ X: Union[type_hints.SupportedDataType, dataframe.DataFrame],
663
+ database_name: Optional[sql_identifier.SqlIdentifier],
664
+ schema_name: Optional[sql_identifier.SqlIdentifier],
665
+ service_name: sql_identifier.SqlIdentifier,
666
+ strict_input_validation: bool = False,
667
+ statement_params: Optional[Dict[str, str]] = None,
668
+ ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
669
+ ...
670
+
671
+ def invoke_method(
672
+ self,
673
+ *,
674
+ method_name: sql_identifier.SqlIdentifier,
675
+ method_function_type: Optional[str] = None,
676
+ signature: model_signature.ModelSignature,
677
+ X: Union[type_hints.SupportedDataType, dataframe.DataFrame],
678
+ database_name: Optional[sql_identifier.SqlIdentifier],
679
+ schema_name: Optional[sql_identifier.SqlIdentifier],
680
+ model_name: Optional[sql_identifier.SqlIdentifier] = None,
681
+ version_name: Optional[sql_identifier.SqlIdentifier] = None,
682
+ service_name: Optional[sql_identifier.SqlIdentifier] = None,
683
+ strict_input_validation: bool = False,
684
+ partition_column: Optional[sql_identifier.SqlIdentifier] = None,
685
+ statement_params: Optional[Dict[str, str]] = None,
686
+ is_partitioned: Optional[bool] = None,
624
687
  ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
625
688
  identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED
626
689
 
@@ -657,31 +720,46 @@ class ModelOperator:
657
720
  if output_name in original_cols:
658
721
  original_cols.remove(output_name)
659
722
 
660
- if method_function_type == model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value:
661
- df_res = self._model_version_client.invoke_function_method(
723
+ if service_name:
724
+ df_res = self._service_client.invoke_function_method(
662
725
  method_name=method_name,
663
726
  input_df=s_df,
664
727
  input_args=input_args,
665
728
  returns=returns,
666
729
  database_name=database_name,
667
730
  schema_name=schema_name,
668
- model_name=model_name,
669
- version_name=version_name,
670
- statement_params=statement_params,
671
- )
672
- elif method_function_type == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
673
- df_res = self._model_version_client.invoke_table_function_method(
674
- method_name=method_name,
675
- input_df=s_df,
676
- input_args=input_args,
677
- partition_column=partition_column,
678
- returns=returns,
679
- database_name=database_name,
680
- schema_name=schema_name,
681
- model_name=model_name,
682
- version_name=version_name,
731
+ service_name=service_name,
683
732
  statement_params=statement_params,
684
733
  )
734
+ else:
735
+ assert model_name is not None
736
+ assert version_name is not None
737
+ if method_function_type == model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value:
738
+ df_res = self._model_version_client.invoke_function_method(
739
+ method_name=method_name,
740
+ input_df=s_df,
741
+ input_args=input_args,
742
+ returns=returns,
743
+ database_name=database_name,
744
+ schema_name=schema_name,
745
+ model_name=model_name,
746
+ version_name=version_name,
747
+ statement_params=statement_params,
748
+ )
749
+ elif method_function_type == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
750
+ df_res = self._model_version_client.invoke_table_function_method(
751
+ method_name=method_name,
752
+ input_df=s_df,
753
+ input_args=input_args,
754
+ partition_column=partition_column,
755
+ returns=returns,
756
+ database_name=database_name,
757
+ schema_name=schema_name,
758
+ model_name=model_name,
759
+ version_name=version_name,
760
+ statement_params=statement_params,
761
+ is_partitioned=is_partitioned or False,
762
+ )
685
763
 
686
764
  if keep_order:
687
765
  # if it's a partitioned table function, _ID will be null and we won't be able to sort.
@@ -0,0 +1,121 @@
1
+ import pathlib
2
+ import tempfile
3
+ from typing import Any, Dict, Optional
4
+
5
+ from snowflake.ml._internal import file_utils
6
+ from snowflake.ml._internal.utils import sql_identifier
7
+ from snowflake.ml.model._client.service import model_deployment_spec
8
+ from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
9
+ from snowflake.snowpark import session
10
+ from snowflake.snowpark._internal import utils as snowpark_utils
11
+
12
+
13
+ class ServiceOperator:
14
+ """Service operator for container services logic."""
15
+
16
+ def __init__(
17
+ self,
18
+ session: session.Session,
19
+ *,
20
+ database_name: sql_identifier.SqlIdentifier,
21
+ schema_name: sql_identifier.SqlIdentifier,
22
+ ) -> None:
23
+ self._session = session
24
+ self._database_name = database_name
25
+ self._schema_name = schema_name
26
+ self._workspace = tempfile.TemporaryDirectory()
27
+ self._service_client = service_sql.ServiceSQLClient(
28
+ session,
29
+ database_name=database_name,
30
+ schema_name=schema_name,
31
+ )
32
+ self._stage_client = stage_sql.StageSQLClient(
33
+ session,
34
+ database_name=database_name,
35
+ schema_name=schema_name,
36
+ )
37
+ self._model_deployment_spec = model_deployment_spec.ModelDeploymentSpec(
38
+ workspace_path=pathlib.Path(self._workspace.name)
39
+ )
40
+
41
+ def __eq__(self, __value: object) -> bool:
42
+ if not isinstance(__value, ServiceOperator):
43
+ return False
44
+ return self._service_client == __value._service_client
45
+
46
+ @property
47
+ def workspace_path(self) -> pathlib.Path:
48
+ return pathlib.Path(self._workspace.name)
49
+
50
+ def create_service(
51
+ self,
52
+ *,
53
+ database_name: Optional[sql_identifier.SqlIdentifier],
54
+ schema_name: Optional[sql_identifier.SqlIdentifier],
55
+ model_name: sql_identifier.SqlIdentifier,
56
+ version_name: sql_identifier.SqlIdentifier,
57
+ service_database_name: Optional[sql_identifier.SqlIdentifier],
58
+ service_schema_name: Optional[sql_identifier.SqlIdentifier],
59
+ service_name: sql_identifier.SqlIdentifier,
60
+ image_build_compute_pool_name: sql_identifier.SqlIdentifier,
61
+ service_compute_pool_name: sql_identifier.SqlIdentifier,
62
+ image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
63
+ image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
64
+ image_repo_name: sql_identifier.SqlIdentifier,
65
+ image_name: Optional[sql_identifier.SqlIdentifier],
66
+ ingress_enabled: bool,
67
+ min_instances: int,
68
+ max_instances: int,
69
+ gpu_requests: Optional[str],
70
+ force_rebuild: bool,
71
+ build_external_access_integration: sql_identifier.SqlIdentifier,
72
+ statement_params: Optional[Dict[str, Any]] = None,
73
+ ) -> str:
74
+ # create a temp stage
75
+ stage_name = sql_identifier.SqlIdentifier(
76
+ snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
77
+ )
78
+ self._stage_client.create_tmp_stage(
79
+ database_name=database_name,
80
+ schema_name=schema_name,
81
+ stage_name=stage_name,
82
+ statement_params=statement_params,
83
+ )
84
+ stage_path = self._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name)
85
+
86
+ self._model_deployment_spec.save(
87
+ database_name=database_name or self._database_name,
88
+ schema_name=schema_name or self._schema_name,
89
+ model_name=model_name,
90
+ version_name=version_name,
91
+ service_database_name=service_database_name,
92
+ service_schema_name=service_schema_name,
93
+ service_name=service_name,
94
+ image_build_compute_pool_name=image_build_compute_pool_name,
95
+ service_compute_pool_name=service_compute_pool_name,
96
+ image_repo_database_name=image_repo_database_name,
97
+ image_repo_schema_name=image_repo_schema_name,
98
+ image_repo_name=image_repo_name,
99
+ image_name=image_name,
100
+ ingress_enabled=ingress_enabled,
101
+ min_instances=min_instances,
102
+ max_instances=max_instances,
103
+ gpu=gpu_requests,
104
+ force_rebuild=force_rebuild,
105
+ external_access_integration=build_external_access_integration,
106
+ )
107
+ file_utils.upload_directory_to_stage(
108
+ self._session,
109
+ local_path=self.workspace_path,
110
+ stage_path=pathlib.PurePosixPath(stage_path),
111
+ statement_params=statement_params,
112
+ )
113
+
114
+ # deploy the model service
115
+ self._service_client.deploy_model(
116
+ stage_path=stage_path,
117
+ model_deployment_spec_file_rel_path=model_deployment_spec.ModelDeploymentSpec.DEPLOY_SPEC_FILE_REL_PATH,
118
+ statement_params=statement_params,
119
+ )
120
+
121
+ return service_name
@@ -0,0 +1,95 @@
1
+ import pathlib
2
+ from typing import Optional
3
+
4
+ import yaml
5
+
6
+ from snowflake.ml._internal.utils import identifier, sql_identifier
7
+ from snowflake.ml.model._client.service import model_deployment_spec_schema
8
+
9
+
10
+ class ModelDeploymentSpec:
11
+ """Class to construct deploy.yml file for Model container services deployment.
12
+
13
+ Attributes:
14
+ workspace_path: A local path where model related files should be dumped to.
15
+ """
16
+
17
+ DEPLOY_SPEC_FILE_REL_PATH = "deploy.yml"
18
+
19
+ def __init__(self, workspace_path: pathlib.Path) -> None:
20
+ self.workspace_path = workspace_path
21
+
22
+ def save(
23
+ self,
24
+ *,
25
+ database_name: sql_identifier.SqlIdentifier,
26
+ schema_name: sql_identifier.SqlIdentifier,
27
+ model_name: sql_identifier.SqlIdentifier,
28
+ version_name: sql_identifier.SqlIdentifier,
29
+ service_database_name: Optional[sql_identifier.SqlIdentifier],
30
+ service_schema_name: Optional[sql_identifier.SqlIdentifier],
31
+ service_name: sql_identifier.SqlIdentifier,
32
+ image_build_compute_pool_name: sql_identifier.SqlIdentifier,
33
+ service_compute_pool_name: sql_identifier.SqlIdentifier,
34
+ image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
35
+ image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
36
+ image_repo_name: sql_identifier.SqlIdentifier,
37
+ image_name: Optional[sql_identifier.SqlIdentifier],
38
+ ingress_enabled: bool,
39
+ min_instances: int,
40
+ max_instances: int,
41
+ gpu: Optional[str],
42
+ force_rebuild: bool,
43
+ external_access_integration: sql_identifier.SqlIdentifier,
44
+ ) -> None:
45
+ # create the deployment spec
46
+ # models spec
47
+ fq_model_name = identifier.get_schema_level_object_identifier(
48
+ database_name.identifier(), schema_name.identifier(), model_name.identifier()
49
+ )
50
+ model_dict = model_deployment_spec_schema.ModelDict(name=fq_model_name, version=version_name.identifier())
51
+
52
+ # image_build spec
53
+ saved_image_repo_database = image_repo_database_name or database_name
54
+ saved_image_repo_schema = image_repo_schema_name or schema_name
55
+ fq_image_repo_name = identifier.get_schema_level_object_identifier(
56
+ saved_image_repo_database.identifier(), saved_image_repo_schema.identifier(), image_repo_name.identifier()
57
+ )
58
+ image_build_dict = model_deployment_spec_schema.ImageBuildDict(
59
+ compute_pool=image_build_compute_pool_name.identifier(),
60
+ image_repo=fq_image_repo_name,
61
+ force_rebuild=force_rebuild,
62
+ external_access_integrations=[external_access_integration.identifier()],
63
+ )
64
+ if image_name:
65
+ image_build_dict["image_name"] = image_name.identifier()
66
+
67
+ # service spec
68
+ saved_service_database = service_database_name or database_name
69
+ saved_service_schema = service_schema_name or schema_name
70
+ fq_service_name = identifier.get_schema_level_object_identifier(
71
+ saved_service_database.identifier(), saved_service_schema.identifier(), service_name.identifier()
72
+ )
73
+ service_dict = model_deployment_spec_schema.ServiceDict(
74
+ name=fq_service_name,
75
+ compute_pool=service_compute_pool_name.identifier(),
76
+ ingress_enabled=ingress_enabled,
77
+ min_instances=min_instances,
78
+ max_instances=max_instances,
79
+ )
80
+ if gpu:
81
+ service_dict["gpu"] = gpu
82
+
83
+ # model deployment spec
84
+ model_deployment_spec_dict = model_deployment_spec_schema.ModelDeploymentSpecDict(
85
+ models=[model_dict],
86
+ image_build=image_build_dict,
87
+ service=service_dict,
88
+ )
89
+
90
+ # save the yaml
91
+ file_path = self.workspace_path / self.DEPLOY_SPEC_FILE_REL_PATH
92
+ with file_path.open("w", encoding="utf-8") as f:
93
+ # Anchors are not supported in the server, avoid that.
94
+ yaml.SafeDumper.ignore_aliases = lambda *args: True # type: ignore[method-assign]
95
+ yaml.safe_dump(model_deployment_spec_dict, f)