snowflake-ml-python 1.5.4__py3-none-any.whl → 1.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100)
  1. snowflake/cortex/__init__.py +2 -0
  2. snowflake/cortex/_classify_text.py +36 -0
  3. snowflake/cortex/_complete.py +66 -35
  4. snowflake/cortex/_util.py +4 -4
  5. snowflake/ml/_internal/env_utils.py +11 -5
  6. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  7. snowflake/ml/_internal/lineage/lineage_utils.py +4 -4
  8. snowflake/ml/_internal/telemetry.py +26 -2
  9. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  10. snowflake/ml/data/_internal/arrow_ingestor.py +284 -0
  11. snowflake/ml/data/data_connector.py +186 -0
  12. snowflake/ml/data/data_ingestor.py +45 -0
  13. snowflake/ml/data/data_source.py +23 -0
  14. snowflake/ml/data/ingestor_utils.py +62 -0
  15. snowflake/ml/data/torch_dataset.py +33 -0
  16. snowflake/ml/dataset/dataset.py +1 -13
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +23 -117
  19. snowflake/ml/feature_store/access_manager.py +7 -1
  20. snowflake/ml/feature_store/entity.py +19 -2
  21. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  22. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  23. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  24. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +20 -0
  26. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +37 -0
  27. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +30 -0
  28. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +7 -0
  29. snowflake/ml/feature_store/examples/example_helper.py +278 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +12 -0
  31. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +44 -0
  32. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  33. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +9 -0
  34. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  35. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +36 -0
  36. snowflake/ml/feature_store/examples/source_data/fraud_transactions.yaml +29 -0
  37. snowflake/ml/feature_store/examples/source_data/nyc_yellow_trips.yaml +4 -0
  38. snowflake/ml/feature_store/examples/source_data/winequality_red.yaml +32 -0
  39. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +14 -0
  40. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +36 -0
  41. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +24 -0
  42. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +8 -0
  43. snowflake/ml/feature_store/feature_store.py +637 -76
  44. snowflake/ml/feature_store/feature_view.py +316 -9
  45. snowflake/ml/fileset/stage_fs.py +18 -10
  46. snowflake/ml/lineage/lineage_node.py +1 -1
  47. snowflake/ml/model/_client/model/model_impl.py +11 -2
  48. snowflake/ml/model/_client/model/model_version_impl.py +171 -20
  49. snowflake/ml/model/_client/ops/model_ops.py +105 -27
  50. snowflake/ml/model/_client/ops/service_ops.py +121 -0
  51. snowflake/ml/model/_client/service/model_deployment_spec.py +95 -0
  52. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +31 -0
  53. snowflake/ml/model/_client/sql/model_version.py +13 -4
  54. snowflake/ml/model/_client/sql/service.py +129 -0
  55. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
  56. snowflake/ml/model/_model_composer/model_composer.py +14 -14
  57. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +33 -17
  58. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -1
  59. snowflake/ml/model/_model_composer/model_method/function_generator.py +3 -3
  60. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +3 -32
  61. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +3 -27
  62. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +3 -32
  63. snowflake/ml/model/_model_composer/model_method/model_method.py +5 -2
  64. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  65. snowflake/ml/model/_packager/model_handlers/_base.py +30 -3
  66. snowflake/ml/model/_packager/model_handlers/_utils.py +58 -1
  67. snowflake/ml/model/_packager/model_handlers/catboost.py +52 -3
  68. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  69. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +9 -5
  70. snowflake/ml/model/_packager/model_handlers/lightgbm.py +80 -3
  71. snowflake/ml/model/_packager/model_handlers/llm.py +7 -3
  72. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  73. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  74. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  75. snowflake/ml/model/_packager/model_handlers/sklearn.py +87 -4
  76. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +7 -2
  77. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  78. snowflake/ml/model/_packager/model_handlers/torchscript.py +8 -3
  79. snowflake/ml/model/_packager/model_handlers/xgboost.py +71 -3
  80. snowflake/ml/model/_packager/model_meta/model_meta.py +32 -2
  81. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +19 -0
  82. snowflake/ml/model/_packager/model_packager.py +2 -1
  83. snowflake/ml/model/_packager/model_runtime/model_runtime.py +7 -7
  84. snowflake/ml/model/model_signature.py +4 -4
  85. snowflake/ml/model/type_hints.py +2 -0
  86. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +1 -1
  87. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +13 -1
  88. snowflake/ml/modeling/framework/base.py +28 -19
  89. snowflake/ml/modeling/impute/simple_imputer.py +26 -0
  90. snowflake/ml/modeling/pipeline/pipeline.py +7 -4
  91. snowflake/ml/registry/_manager/model_manager.py +16 -2
  92. snowflake/ml/registry/registry.py +100 -13
  93. snowflake/ml/utils/sql_client.py +22 -0
  94. snowflake/ml/version.py +1 -1
  95. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/METADATA +81 -2
  96. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/RECORD +99 -66
  97. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/WHEEL +1 -1
  98. snowflake/ml/_internal/lineage/data_source.py +0 -10
  99. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/LICENSE.txt +0 -0
  100. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,31 @@
1
+ from typing import List, TypedDict
2
+
3
+ from typing_extensions import NotRequired, Required
4
+
5
+
6
+ class ModelDict(TypedDict):
7
+ name: Required[str]
8
+ version: Required[str]
9
+
10
+
11
+ class ImageBuildDict(TypedDict):
12
+ compute_pool: Required[str]
13
+ image_repo: Required[str]
14
+ image_name: NotRequired[str]
15
+ force_rebuild: Required[bool]
16
+ external_access_integrations: Required[List[str]]
17
+
18
+
19
+ class ServiceDict(TypedDict):
20
+ name: Required[str]
21
+ compute_pool: Required[str]
22
+ ingress_enabled: Required[bool]
23
+ min_instances: Required[int]
24
+ max_instances: Required[int]
25
+ gpu: NotRequired[str]
26
+
27
+
28
+ class ModelDeploymentSpecDict(TypedDict):
29
+ models: Required[List[ModelDict]]
30
+ image_build: Required[ImageBuildDict]
31
+ service: Required[ServiceDict]
@@ -371,6 +371,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
371
371
  returns: List[Tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
372
372
  partition_column: Optional[sql_identifier.SqlIdentifier],
373
373
  statement_params: Optional[Dict[str, Any]] = None,
374
+ is_partitioned: bool = True,
374
375
  ) -> dataframe.DataFrame:
375
376
  with_statements = []
376
377
  if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
@@ -409,12 +410,20 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
409
410
 
410
411
  sql = textwrap.dedent(
411
412
  f"""WITH {','.join(with_statements)}
412
- SELECT *,
413
- FROM {INTERMEDIATE_TABLE_NAME},
414
- TABLE({module_version_alias}!{method_name.identifier()}({args_sql})
415
- OVER (PARTITION BY {partition_by}))"""
413
+ SELECT *,
414
+ FROM {INTERMEDIATE_TABLE_NAME},
415
+ TABLE({module_version_alias}!{method_name.identifier()}({args_sql}))"""
416
416
  )
417
417
 
418
+ if is_partitioned or partition_column is not None:
419
+ sql = textwrap.dedent(
420
+ f"""WITH {','.join(with_statements)}
421
+ SELECT *,
422
+ FROM {INTERMEDIATE_TABLE_NAME},
423
+ TABLE({module_version_alias}!{method_name.identifier()}({args_sql})
424
+ OVER (PARTITION BY {partition_by}))"""
425
+ )
426
+
418
427
  output_df = self._session.sql(sql)
419
428
 
420
429
  # Prepare the output
@@ -0,0 +1,129 @@
1
+ import textwrap
2
+ from typing import Any, Dict, List, Optional, Tuple
3
+
4
+ from snowflake.ml._internal.utils import (
5
+ identifier,
6
+ query_result_checker,
7
+ sql_identifier,
8
+ )
9
+ from snowflake.ml.model._client.sql import _base
10
+ from snowflake.snowpark import dataframe, functions as F, types as spt
11
+ from snowflake.snowpark._internal import utils as snowpark_utils
12
+
13
+
14
+ class ServiceSQLClient(_base._BaseSQLClient):
15
+ def build_model_container(
16
+ self,
17
+ *,
18
+ database_name: Optional[sql_identifier.SqlIdentifier],
19
+ schema_name: Optional[sql_identifier.SqlIdentifier],
20
+ model_name: sql_identifier.SqlIdentifier,
21
+ version_name: sql_identifier.SqlIdentifier,
22
+ compute_pool_name: sql_identifier.SqlIdentifier,
23
+ image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
24
+ image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
25
+ image_repo_name: sql_identifier.SqlIdentifier,
26
+ gpu: Optional[str],
27
+ force_rebuild: bool,
28
+ external_access_integration: sql_identifier.SqlIdentifier,
29
+ statement_params: Optional[Dict[str, Any]] = None,
30
+ ) -> None:
31
+ actual_image_repo_database = image_repo_database_name or self._database_name
32
+ actual_image_repo_schema = image_repo_schema_name or self._schema_name
33
+ fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
34
+ fq_image_repo_name = "/" + "/".join(
35
+ [
36
+ actual_image_repo_database.identifier(),
37
+ actual_image_repo_schema.identifier(),
38
+ image_repo_name.identifier(),
39
+ ]
40
+ )
41
+ is_gpu = gpu is not None
42
+ query_result_checker.SqlResultValidator(
43
+ self._session,
44
+ (
45
+ f"CALL SYSTEM$BUILD_MODEL_CONTAINER('{fq_model_name}', '{version_name}', '{compute_pool_name}',"
46
+ f" '{fq_image_repo_name}', '{is_gpu}', '{force_rebuild}', '', '{external_access_integration}')"
47
+ ),
48
+ statement_params=statement_params,
49
+ ).has_dimensions(expected_rows=1, expected_cols=1).validate()
50
+
51
+ def deploy_model(
52
+ self,
53
+ *,
54
+ stage_path: str,
55
+ model_deployment_spec_file_rel_path: str,
56
+ statement_params: Optional[Dict[str, Any]] = None,
57
+ ) -> None:
58
+ query_result_checker.SqlResultValidator(
59
+ self._session,
60
+ f"CALL SYSTEM$DEPLOY_MODEL('@{stage_path}/{model_deployment_spec_file_rel_path}')",
61
+ statement_params=statement_params,
62
+ ).has_dimensions(expected_rows=1, expected_cols=1).validate()
63
+
64
+ def invoke_function_method(
65
+ self,
66
+ *,
67
+ database_name: Optional[sql_identifier.SqlIdentifier],
68
+ schema_name: Optional[sql_identifier.SqlIdentifier],
69
+ service_name: sql_identifier.SqlIdentifier,
70
+ method_name: sql_identifier.SqlIdentifier,
71
+ input_df: dataframe.DataFrame,
72
+ input_args: List[sql_identifier.SqlIdentifier],
73
+ returns: List[Tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
74
+ statement_params: Optional[Dict[str, Any]] = None,
75
+ ) -> dataframe.DataFrame:
76
+ with_statements = []
77
+ if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
78
+ INTERMEDIATE_TABLE_NAME = "SNOWPARK_ML_MODEL_INFERENCE_INPUT"
79
+ with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})")
80
+ else:
81
+ actual_database_name = database_name or self._database_name
82
+ actual_schema_name = schema_name or self._schema_name
83
+ tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
84
+ INTERMEDIATE_TABLE_NAME = identifier.get_schema_level_object_identifier(
85
+ actual_database_name.identifier(),
86
+ actual_schema_name.identifier(),
87
+ tmp_table_name,
88
+ )
89
+ input_df.write.save_as_table(
90
+ table_name=INTERMEDIATE_TABLE_NAME,
91
+ mode="errorifexists",
92
+ table_type="temporary",
93
+ statement_params=statement_params,
94
+ )
95
+
96
+ INTERMEDIATE_OBJ_NAME = "TMP_RESULT"
97
+
98
+ with_sql = f"WITH {','.join(with_statements)}" if with_statements else ""
99
+ args_sql_list = []
100
+ for input_arg_value in input_args:
101
+ args_sql_list.append(input_arg_value)
102
+ args_sql = ", ".join(args_sql_list)
103
+
104
+ sql = textwrap.dedent(
105
+ f"""{with_sql}
106
+ SELECT *,
107
+ {service_name.identifier()}_{method_name.identifier()}({args_sql}) AS {INTERMEDIATE_OBJ_NAME}
108
+ FROM {INTERMEDIATE_TABLE_NAME}"""
109
+ )
110
+
111
+ output_df = self._session.sql(sql)
112
+
113
+ # Prepare the output
114
+ output_cols = []
115
+ output_names = []
116
+
117
+ for output_name, output_type, output_col_name in returns:
118
+ output_cols.append(F.col(INTERMEDIATE_OBJ_NAME)[output_name].astype(output_type))
119
+ output_names.append(output_col_name)
120
+
121
+ output_df = output_df.with_columns(
122
+ col_names=output_names,
123
+ values=output_cols,
124
+ ).drop(INTERMEDIATE_OBJ_NAME)
125
+
126
+ if statement_params:
127
+ output_df._statement_params = statement_params # type: ignore[assignment]
128
+
129
+ return output_df
@@ -85,9 +85,8 @@ def _run_setup() -> None:
85
85
 
86
86
  TARGET_METHOD = os.getenv("TARGET_METHOD")
87
87
 
88
- _concurrent_requests_max_env = os.getenv("_CONCURRENT_REQUESTS_MAX", None)
89
-
90
- _CONCURRENT_REQUESTS_MAX = int(_concurrent_requests_max_env) if _concurrent_requests_max_env else None
88
+ _concurrent_requests_max_env = os.getenv("_CONCURRENT_REQUESTS_MAX", "1")
89
+ _CONCURRENT_REQUESTS_MAX = int(_concurrent_requests_max_env)
91
90
 
92
91
  with tempfile.TemporaryDirectory() as tmp_dir:
93
92
  if zipfile.is_zipfile(model_zip_stage_path):
@@ -10,8 +10,10 @@ from absl import logging
10
10
  from packaging import requirements
11
11
  from typing_extensions import deprecated
12
12
 
13
+ from snowflake import snowpark
13
14
  from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
14
- from snowflake.ml._internal.lineage import data_source, lineage_utils
15
+ from snowflake.ml._internal.lineage import lineage_utils
16
+ from snowflake.ml.data import data_source
15
17
  from snowflake.ml.model import model_signature, type_hints as model_types
16
18
  from snowflake.ml.model._model_composer.model_manifest import model_manifest
17
19
  from snowflake.ml.model._packager import model_packager
@@ -128,16 +130,14 @@ class ModelComposer:
128
130
  file_utils.copytree(
129
131
  str(self._packager_workspace_path), str(self.workspace_path / ModelComposer.MODEL_DIR_REL_PATH)
130
132
  )
131
-
132
- file_utils.make_archive(self.model_local_path, str(self._packager_workspace_path))
133
-
134
- self.manifest.save(
135
- session=self.session,
136
- model_meta=model_metadata,
137
- model_file_rel_path=pathlib.PurePosixPath(self.model_file_rel_path),
138
- options=options,
139
- data_sources=self._get_data_sources(model, sample_input_data),
140
- )
133
+ self.manifest.save(
134
+ model_meta=self.packager.meta,
135
+ model_rel_path=pathlib.PurePosixPath(ModelComposer.MODEL_DIR_REL_PATH),
136
+ options=options,
137
+ data_sources=self._get_data_sources(model, sample_input_data),
138
+ )
139
+ else:
140
+ file_utils.make_archive(self.model_local_path, str(self._packager_workspace_path))
141
141
 
142
142
  file_utils.upload_directory_to_stage(
143
143
  self.session,
@@ -186,6 +186,6 @@ class ModelComposer:
186
186
  data_sources = lineage_utils.get_data_sources(model)
187
187
  if not data_sources and sample_input_data is not None:
188
188
  data_sources = lineage_utils.get_data_sources(sample_input_data)
189
- if isinstance(data_sources, list) and all(isinstance(item, data_source.DataSource) for item in data_sources):
190
- return data_sources
191
- return None
189
+ if not data_sources and isinstance(sample_input_data, snowpark.DataFrame):
190
+ data_sources = [data_source.DataFrameInfo(sample_input_data.queries["queries"][-1])]
191
+ return data_sources
@@ -1,11 +1,13 @@
1
1
  import collections
2
2
  import copy
3
3
  import pathlib
4
+ import warnings
4
5
  from typing import List, Optional, cast
5
6
 
6
7
  import yaml
7
8
 
8
- from snowflake.ml._internal.lineage import data_source
9
+ from snowflake.ml._internal import env_utils
10
+ from snowflake.ml.data import data_source
9
11
  from snowflake.ml.model import type_hints
10
12
  from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
11
13
  from snowflake.ml.model._model_composer.model_method import (
@@ -16,7 +18,6 @@ from snowflake.ml.model._packager.model_meta import (
16
18
  model_meta as model_meta_api,
17
19
  model_meta_schema,
18
20
  )
19
- from snowflake.snowpark import Session
20
21
 
21
22
 
22
23
  class ModelManifest:
@@ -36,9 +37,8 @@ class ModelManifest:
36
37
 
37
38
  def save(
38
39
  self,
39
- session: Session,
40
40
  model_meta: model_meta_api.ModelMetadata,
41
- model_file_rel_path: pathlib.PurePosixPath,
41
+ model_rel_path: pathlib.PurePosixPath,
42
42
  options: Optional[type_hints.ModelSaveOption] = None,
43
43
  data_sources: Optional[List[data_source.DataSource]] = None,
44
44
  ) -> None:
@@ -47,10 +47,12 @@ class ModelManifest:
47
47
 
48
48
  runtime_to_use = copy.deepcopy(model_meta.runtimes["cpu"])
49
49
  runtime_to_use.name = self._DEFAULT_RUNTIME_NAME
50
- runtime_to_use.imports.append(model_file_rel_path)
51
- runtime_dict = runtime_to_use.save(self.workspace_path)
50
+ runtime_to_use.imports.append(str(model_rel_path) + "/")
51
+ runtime_dict = runtime_to_use.save(
52
+ self.workspace_path, default_channel_override=env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
53
+ )
52
54
 
53
- self.function_generator = function_generator.FunctionGenerator(model_file_rel_path=model_file_rel_path)
55
+ self.function_generator = function_generator.FunctionGenerator(model_dir_rel_path=model_rel_path)
54
56
  self.methods: List[model_method.ModelMethod] = []
55
57
  for target_method in model_meta.signatures.keys():
56
58
  method = model_method.ModelMethod(
@@ -75,6 +77,16 @@ class ModelManifest:
75
77
  "In this case, set case_sensitive as True for those methods to distinguish them."
76
78
  )
77
79
 
80
+ dependencies = model_manifest_schema.ModelRuntimeDependenciesDict(conda=runtime_dict["dependencies"]["conda"])
81
+ if options.get("include_pip_dependencies"):
82
+ warnings.warn(
83
+ "`include_pip_dependencies` specified as True: pip dependencies will be included and may not"
84
+ "be warehouse-compabible. The model may need to be run in SPCS.",
85
+ category=UserWarning,
86
+ stacklevel=1,
87
+ )
88
+ dependencies["pip"] = runtime_dict["dependencies"]["pip"]
89
+
78
90
  manifest_dict = model_manifest_schema.ModelManifestDict(
79
91
  manifest_version=model_manifest_schema.MODEL_MANIFEST_VERSION,
80
92
  runtimes={
@@ -82,9 +94,7 @@ class ModelManifest:
82
94
  language="PYTHON",
83
95
  version=runtime_to_use.runtime_env.python_version,
84
96
  imports=runtime_dict["imports"],
85
- dependencies=model_manifest_schema.ModelRuntimeDependenciesDict(
86
- conda=runtime_dict["dependencies"]["conda"]
87
- ),
97
+ dependencies=dependencies,
88
98
  )
89
99
  },
90
100
  methods=[
@@ -127,12 +137,18 @@ class ModelManifest:
127
137
  result = []
128
138
  if data_sources:
129
139
  for source in data_sources:
130
- result.append(
131
- model_manifest_schema.LineageSourceDict(
132
- # Currently, we only support lineage from Dataset.
133
- type=model_manifest_schema.LineageSourceTypes.DATASET.value,
134
- entity=source.fully_qualified_name,
135
- version=source.version,
140
+ if isinstance(source, data_source.DatasetInfo):
141
+ result.append(
142
+ model_manifest_schema.LineageSourceDict(
143
+ type=model_manifest_schema.LineageSourceTypes.DATASET.value,
144
+ entity=source.fully_qualified_name,
145
+ version=source.version,
146
+ )
147
+ )
148
+ elif isinstance(source, data_source.DataFrameInfo):
149
+ result.append(
150
+ model_manifest_schema.LineageSourceDict(
151
+ type=model_manifest_schema.LineageSourceTypes.QUERY.value, entity=source.sql
152
+ )
136
153
  )
137
- )
138
154
  return result
@@ -18,7 +18,8 @@ class ModelMethodFunctionTypes(enum.Enum):
18
18
 
19
19
 
20
20
  class ModelRuntimeDependenciesDict(TypedDict):
21
- conda: Required[str]
21
+ conda: NotRequired[str]
22
+ pip: NotRequired[str]
22
23
 
23
24
 
24
25
  class ModelRuntimeDict(TypedDict):
@@ -56,12 +57,14 @@ class ModelFunctionInfo(TypedDict):
56
57
  target_method: actual target method name to be called.
57
58
  target_method_function_type: target method function type (FUNCTION or TABLE_FUNCTION).
58
59
  signature: The signature of the model method.
60
+ is_partitioned: Whether the function is partitioned.
59
61
  """
60
62
 
61
63
  name: Required[str]
62
64
  target_method: Required[str]
63
65
  target_method_function_type: Required[str]
64
66
  signature: Required[model_signature.ModelSignature]
67
+ is_partitioned: Required[bool]
65
68
 
66
69
 
67
70
  class ModelFunctionInfoDict(TypedDict):
@@ -77,6 +80,7 @@ class SnowparkMLDataDict(TypedDict):
77
80
 
78
81
  class LineageSourceTypes(enum.Enum):
79
82
  DATASET = "DATASET"
83
+ QUERY = "QUERY"
80
84
 
81
85
 
82
86
  class LineageSourceDict(TypedDict):
@@ -33,9 +33,9 @@ class FunctionGenerator:
33
33
 
34
34
  def __init__(
35
35
  self,
36
- model_file_rel_path: pathlib.PurePosixPath,
36
+ model_dir_rel_path: pathlib.PurePosixPath,
37
37
  ) -> None:
38
- self.model_file_rel_path = model_file_rel_path
38
+ self.model_dir_rel_path = model_dir_rel_path
39
39
 
40
40
  def generate(
41
41
  self,
@@ -67,7 +67,7 @@ class FunctionGenerator:
67
67
  )
68
68
 
69
69
  udf_code = function_template.format(
70
- model_file_name=self.model_file_rel_path.name,
70
+ model_dir_name=self.model_dir_rel_path.name,
71
71
  target_method=target_method,
72
72
  max_batch_size=options.get("max_batch_size", None),
73
73
  function_name=FunctionGenerator.FUNCTION_NAME,
@@ -1,12 +1,7 @@
1
- import fcntl
2
1
  import functools
3
2
  import inspect
4
3
  import os
5
4
  import sys
6
- import threading
7
- import zipfile
8
- from types import TracebackType
9
- from typing import Optional, Type
10
5
 
11
6
  import anyio
12
7
  import pandas as pd
@@ -15,42 +10,18 @@ from _snowflake import vectorized
15
10
  from snowflake.ml.model._packager import model_packager
16
11
 
17
12
 
18
- class FileLock:
19
- def __enter__(self) -> None:
20
- self._lock = threading.Lock()
21
- self._lock.acquire()
22
- self._fd = open("/tmp/lockfile.LOCK", "w+")
23
- fcntl.lockf(self._fd, fcntl.LOCK_EX)
24
-
25
- def __exit__(
26
- self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType]
27
- ) -> None:
28
- self._fd.close()
29
- self._lock.release()
30
-
31
-
32
13
  # User-defined parameters
33
- MODEL_FILE_NAME = "{model_file_name}"
14
+ MODEL_DIR_REL_PATH = "{model_dir_name}"
34
15
  TARGET_METHOD = "{target_method}"
35
16
  MAX_BATCH_SIZE = {max_batch_size}
36
17
 
37
-
38
18
  # Retrieve the model
39
19
  IMPORT_DIRECTORY_NAME = "snowflake_import_directory"
40
20
  import_dir = sys._xoptions[IMPORT_DIRECTORY_NAME]
41
-
42
- model_dir_name = os.path.splitext(MODEL_FILE_NAME)[0]
43
- zip_model_path = os.path.join(import_dir, MODEL_FILE_NAME)
44
- extracted = "/tmp/models"
45
- extracted_model_dir_path = os.path.join(extracted, model_dir_name)
46
-
47
- with FileLock():
48
- if not os.path.isdir(extracted_model_dir_path):
49
- with zipfile.ZipFile(zip_model_path, "r") as myzip:
50
- myzip.extractall(extracted_model_dir_path)
21
+ model_dir_path = os.path.join(import_dir, MODEL_DIR_REL_PATH)
51
22
 
52
23
  # Load the model
53
- pk = model_packager.ModelPackager(extracted_model_dir_path)
24
+ pk = model_packager.ModelPackager(model_dir_path)
54
25
  pk.load(as_custom_model=True)
55
26
  assert pk.model, "model is not loaded"
56
27
  assert pk.meta, "model metadata is not loaded"
@@ -15,42 +15,18 @@ from _snowflake import vectorized
15
15
  from snowflake.ml.model._packager import model_packager
16
16
 
17
17
 
18
- class FileLock:
19
- def __enter__(self) -> None:
20
- self._lock = threading.Lock()
21
- self._lock.acquire()
22
- self._fd = open("/tmp/lockfile.LOCK", "w+")
23
- fcntl.lockf(self._fd, fcntl.LOCK_EX)
24
-
25
- def __exit__(
26
- self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType]
27
- ) -> None:
28
- self._fd.close()
29
- self._lock.release()
30
-
31
-
32
18
  # User-defined parameters
33
- MODEL_FILE_NAME = "{model_file_name}"
19
+ MODEL_DIR_REL_PATH = "{model_dir_name}"
34
20
  TARGET_METHOD = "{target_method}"
35
21
  MAX_BATCH_SIZE = {max_batch_size}
36
22
 
37
-
38
23
  # Retrieve the model
39
24
  IMPORT_DIRECTORY_NAME = "snowflake_import_directory"
40
25
  import_dir = sys._xoptions[IMPORT_DIRECTORY_NAME]
41
-
42
- model_dir_name = os.path.splitext(MODEL_FILE_NAME)[0]
43
- zip_model_path = os.path.join(import_dir, MODEL_FILE_NAME)
44
- extracted = "/tmp/models"
45
- extracted_model_dir_path = os.path.join(extracted, model_dir_name)
46
-
47
- with FileLock():
48
- if not os.path.isdir(extracted_model_dir_path):
49
- with zipfile.ZipFile(zip_model_path, "r") as myzip:
50
- myzip.extractall(extracted_model_dir_path)
26
+ model_dir_path = os.path.join(import_dir, MODEL_DIR_REL_PATH)
51
27
 
52
28
  # Load the model
53
- pk = model_packager.ModelPackager(extracted_model_dir_path)
29
+ pk = model_packager.ModelPackager(model_dir_path)
54
30
  pk.load(as_custom_model=True)
55
31
  assert pk.model, "model is not loaded"
56
32
  assert pk.meta, "model metadata is not loaded"
@@ -1,12 +1,7 @@
1
- import fcntl
2
1
  import functools
3
2
  import inspect
4
3
  import os
5
4
  import sys
6
- import threading
7
- import zipfile
8
- from types import TracebackType
9
- from typing import Optional, Type
10
5
 
11
6
  import anyio
12
7
  import pandas as pd
@@ -15,42 +10,18 @@ from _snowflake import vectorized
15
10
  from snowflake.ml.model._packager import model_packager
16
11
 
17
12
 
18
- class FileLock:
19
- def __enter__(self) -> None:
20
- self._lock = threading.Lock()
21
- self._lock.acquire()
22
- self._fd = open("/tmp/lockfile.LOCK", "w+")
23
- fcntl.lockf(self._fd, fcntl.LOCK_EX)
24
-
25
- def __exit__(
26
- self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType]
27
- ) -> None:
28
- self._fd.close()
29
- self._lock.release()
30
-
31
-
32
13
  # User-defined parameters
33
- MODEL_FILE_NAME = "{model_file_name}"
14
+ MODEL_DIR_REL_PATH = "{model_dir_name}"
34
15
  TARGET_METHOD = "{target_method}"
35
16
  MAX_BATCH_SIZE = {max_batch_size}
36
17
 
37
-
38
18
  # Retrieve the model
39
19
  IMPORT_DIRECTORY_NAME = "snowflake_import_directory"
40
20
  import_dir = sys._xoptions[IMPORT_DIRECTORY_NAME]
41
-
42
- model_dir_name = os.path.splitext(MODEL_FILE_NAME)[0]
43
- zip_model_path = os.path.join(import_dir, MODEL_FILE_NAME)
44
- extracted = "/tmp/models"
45
- extracted_model_dir_path = os.path.join(extracted, model_dir_name)
46
-
47
- with FileLock():
48
- if not os.path.isdir(extracted_model_dir_path):
49
- with zipfile.ZipFile(zip_model_path, "r") as myzip:
50
- myzip.extractall(extracted_model_dir_path)
21
+ model_dir_path = os.path.join(import_dir, MODEL_DIR_REL_PATH)
51
22
 
52
23
  # Load the model
53
- pk = model_packager.ModelPackager(extracted_model_dir_path)
24
+ pk = model_packager.ModelPackager(model_dir_path)
54
25
  pk.load(as_custom_model=True)
55
26
  assert pk.model, "model is not loaded"
56
27
  assert pk.meta, "model metadata is not loaded"
@@ -26,11 +26,14 @@ class ModelMethodOptions(TypedDict):
26
26
  def get_model_method_options_from_options(
27
27
  options: type_hints.ModelSaveOption, target_method: str
28
28
  ) -> ModelMethodOptions:
29
+ default_function_type = model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value
30
+ if options.get("enable_explainability", False) and target_method.startswith("explain"):
31
+ default_function_type = model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
29
32
  method_option = options.get("method_options", {}).get(target_method, {})
30
- global_function_type = options.get("function_type", model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value)
33
+ global_function_type = options.get("function_type", default_function_type)
31
34
  function_type = method_option.get("function_type", global_function_type)
32
35
  if function_type not in [function_type.value for function_type in model_manifest_schema.ModelMethodFunctionTypes]:
33
- raise NotImplementedError
36
+ raise NotImplementedError(f"Function type {function_type} is not supported.")
34
37
 
35
38
  return ModelMethodOptions(
36
39
  case_sensitive=method_option.get("case_sensitive", False),
@@ -363,9 +363,14 @@ class ModelEnv:
363
363
  self.cuda_version = env_dict.get("cuda_version", None)
364
364
  self.snowpark_ml_version = env_dict["snowpark_ml_version"]
365
365
 
366
- def save_as_dict(self, base_dir: pathlib.Path) -> model_meta_schema.ModelEnvDict:
366
+ def save_as_dict(
367
+ self, base_dir: pathlib.Path, default_channel_override: str = env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
368
+ ) -> model_meta_schema.ModelEnvDict:
367
369
  env_utils.save_conda_env_file(
368
- pathlib.Path(base_dir / self.conda_env_rel_path), self._conda_dependencies, self.python_version
370
+ pathlib.Path(base_dir / self.conda_env_rel_path),
371
+ self._conda_dependencies,
372
+ self.python_version,
373
+ default_channel_override=default_channel_override,
369
374
  )
370
375
  env_utils.save_requirements_file(
371
376
  pathlib.Path(base_dir / self.pip_requirements_rel_path), self._pip_requirements