snowflake-ml-python 1.8.1 (py3-none-any wheel) → 1.8.3 (py3-none-any wheel)

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. snowflake/cortex/_classify_text.py +3 -3
  2. snowflake/cortex/_complete.py +64 -31
  3. snowflake/cortex/_embed_text_1024.py +4 -4
  4. snowflake/cortex/_embed_text_768.py +4 -4
  5. snowflake/cortex/_finetune.py +8 -8
  6. snowflake/cortex/_util.py +8 -12
  7. snowflake/ml/_internal/env.py +4 -3
  8. snowflake/ml/_internal/env_utils.py +63 -34
  9. snowflake/ml/_internal/file_utils.py +10 -21
  10. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  11. snowflake/ml/_internal/init_utils.py +2 -3
  12. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  13. snowflake/ml/_internal/platform_capabilities.py +41 -5
  14. snowflake/ml/_internal/telemetry.py +39 -52
  15. snowflake/ml/_internal/type_utils.py +3 -3
  16. snowflake/ml/_internal/utils/db_utils.py +2 -2
  17. snowflake/ml/_internal/utils/identifier.py +8 -8
  18. snowflake/ml/_internal/utils/import_utils.py +2 -2
  19. snowflake/ml/_internal/utils/parallelize.py +7 -7
  20. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  21. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  22. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  23. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  24. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  25. snowflake/ml/_internal/utils/table_manager.py +9 -9
  26. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  27. snowflake/ml/data/data_connector.py +40 -36
  28. snowflake/ml/data/data_ingestor.py +4 -15
  29. snowflake/ml/data/data_source.py +2 -2
  30. snowflake/ml/data/ingestor_utils.py +3 -3
  31. snowflake/ml/data/torch_utils.py +5 -5
  32. snowflake/ml/dataset/dataset.py +11 -11
  33. snowflake/ml/dataset/dataset_metadata.py +8 -8
  34. snowflake/ml/dataset/dataset_reader.py +12 -8
  35. snowflake/ml/feature_store/__init__.py +1 -1
  36. snowflake/ml/feature_store/access_manager.py +7 -7
  37. snowflake/ml/feature_store/entity.py +6 -6
  38. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  39. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  41. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  45. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  48. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  51. snowflake/ml/feature_store/feature_store.py +52 -64
  52. snowflake/ml/feature_store/feature_view.py +24 -24
  53. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  54. snowflake/ml/fileset/fileset.py +5 -5
  55. snowflake/ml/fileset/sfcfs.py +13 -13
  56. snowflake/ml/fileset/stage_fs.py +15 -15
  57. snowflake/ml/jobs/_utils/constants.py +2 -4
  58. snowflake/ml/jobs/_utils/interop_utils.py +442 -0
  59. snowflake/ml/jobs/_utils/payload_utils.py +86 -62
  60. snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
  61. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
  62. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +181 -0
  63. snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
  64. snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
  65. snowflake/ml/jobs/_utils/spec_utils.py +22 -36
  66. snowflake/ml/jobs/_utils/types.py +8 -2
  67. snowflake/ml/jobs/decorators.py +7 -8
  68. snowflake/ml/jobs/job.py +158 -26
  69. snowflake/ml/jobs/manager.py +78 -30
  70. snowflake/ml/lineage/lineage_node.py +5 -5
  71. snowflake/ml/model/_client/model/model_impl.py +3 -3
  72. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  73. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  74. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  75. snowflake/ml/model/_client/ops/service_ops.py +230 -50
  76. snowflake/ml/model/_client/service/model_deployment_spec.py +175 -48
  77. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
  78. snowflake/ml/model/_client/sql/model.py +8 -8
  79. snowflake/ml/model/_client/sql/model_version.py +26 -26
  80. snowflake/ml/model/_client/sql/service.py +22 -18
  81. snowflake/ml/model/_client/sql/stage.py +2 -2
  82. snowflake/ml/model/_client/sql/tag.py +6 -6
  83. snowflake/ml/model/_model_composer/model_composer.py +46 -25
  84. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  85. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  86. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  87. snowflake/ml/model/_packager/model_env/model_env.py +35 -26
  88. snowflake/ml/model/_packager/model_handler.py +4 -4
  89. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  90. snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
  91. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  92. snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
  93. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  94. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  95. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  96. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  97. snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
  98. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  99. snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
  100. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  101. snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
  102. snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
  103. snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
  104. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  105. snowflake/ml/model/_packager/model_meta/model_meta.py +42 -37
  106. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
  107. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  108. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  109. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  110. snowflake/ml/model/_packager/model_packager.py +12 -8
  111. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  112. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  113. snowflake/ml/model/_signatures/core.py +16 -24
  114. snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
  115. snowflake/ml/model/_signatures/utils.py +6 -6
  116. snowflake/ml/model/custom_model.py +8 -8
  117. snowflake/ml/model/model_signature.py +9 -20
  118. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  119. snowflake/ml/model/type_hints.py +5 -3
  120. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  121. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  122. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  123. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  124. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  125. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  126. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  127. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  128. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  129. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  130. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  131. snowflake/ml/modeling/framework/_utils.py +10 -10
  132. snowflake/ml/modeling/framework/base.py +32 -32
  133. snowflake/ml/modeling/impute/__init__.py +1 -1
  134. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  135. snowflake/ml/modeling/metrics/__init__.py +1 -1
  136. snowflake/ml/modeling/metrics/classification.py +39 -39
  137. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  138. snowflake/ml/modeling/metrics/ranking.py +7 -7
  139. snowflake/ml/modeling/metrics/regression.py +13 -13
  140. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  141. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  142. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  143. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  144. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  145. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  146. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  147. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  148. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  149. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  150. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  151. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  152. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  153. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  154. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  155. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  156. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  157. snowflake/ml/registry/_manager/model_manager.py +50 -29
  158. snowflake/ml/registry/registry.py +34 -23
  159. snowflake/ml/utils/authentication.py +2 -2
  160. snowflake/ml/utils/connection_params.py +5 -5
  161. snowflake/ml/utils/sparse.py +5 -4
  162. snowflake/ml/utils/sql_client.py +1 -2
  163. snowflake/ml/version.py +2 -1
  164. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +46 -6
  165. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +168 -164
  166. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
  167. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  168. snowflake/ml/modeling/_internal/constants.py +0 -2
  169. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  170. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- from typing import List, Optional, Union, cast
1
+ from typing import Optional, Union, cast
2
2
 
3
3
  from typing_extensions import deprecated
4
4
 
@@ -12,7 +12,7 @@ from snowflake.ml._internal import telemetry
12
12
  )
13
13
  def classify_text(
14
14
  str_input: Union[str, snowpark.Column],
15
- categories: Union[List[str], snowpark.Column],
15
+ categories: Union[list[str], snowpark.Column],
16
16
  session: Optional[snowpark.Session] = None,
17
17
  ) -> Union[str, snowpark.Column]:
18
18
  """Use the LLM inference service to classify the INPUT text into one of the target CATEGORIES.
@@ -32,7 +32,7 @@ def classify_text(
32
32
  def _classify_text_impl(
33
33
  function: str,
34
34
  str_input: Union[str, snowpark.Column],
35
- categories: Union[List[str], snowpark.Column],
35
+ categories: Union[list[str], snowpark.Column],
36
36
  session: Optional[snowpark.Session] = None,
37
37
  ) -> Union[str, snowpark.Column]:
38
38
  return cast(Union[str, snowpark.Column], call_sql_function(function, session, str_input, categories))
@@ -1,11 +1,13 @@
1
1
  import json
2
2
  import logging
3
3
  import time
4
+ import typing
4
5
  from io import BytesIO
5
- from typing import Any, Callable, Dict, Iterator, List, Optional, TypedDict, Union, cast
6
+ from typing import Any, Callable, Iterator, Optional, TypedDict, Union, cast
6
7
  from urllib.parse import urlunparse
7
8
 
8
9
  import requests
10
+ from snowflake.core.rest import RESTResponse
9
11
  from typing_extensions import NotRequired, deprecated
10
12
 
11
13
  from snowflake import snowpark
@@ -28,7 +30,7 @@ class ResponseFormat(TypedDict):
28
30
 
29
31
  type: str
30
32
  """The response format type (e.g. "json")"""
31
- schema: Dict[str, Any]
33
+ schema: dict[str, Any]
32
34
  """The schema defining the structure of the response. For json it should be a valid json schema object"""
33
35
 
34
36
 
@@ -69,7 +71,27 @@ class CompleteOptions(TypedDict):
69
71
  class ResponseParseException(Exception):
70
72
  """This exception is raised when the server response cannot be parsed."""
71
73
 
72
- pass
74
+
75
+ class MidStreamException(Exception):
76
+ """The SSE (Server-sent Event) stream can contain error messages in the middle of the stream,
77
+ using the “error” event type. This exception is raised when there is such a mid-stream error.
78
+ """
79
+
80
+ def __init__(
81
+ self,
82
+ reason: typing.Optional[str] = None,
83
+ http_resp: typing.Optional["RESTResponse"] = None,
84
+ request_id: typing.Optional[str] = None,
85
+ ) -> None:
86
+ message = ""
87
+ if reason is not None:
88
+ message = reason
89
+ if http_resp:
90
+ message = f"Error in stream (HTTP Response: {http_resp.status}) - {http_resp.reason}"
91
+ if request_id != "":
92
+ # add request_id to error message
93
+ message += f" (Request ID: {request_id})"
94
+ super().__init__(message)
73
95
 
74
96
 
75
97
  class GuardrailsOptions(TypedDict):
@@ -112,7 +134,7 @@ def retry(func: Callable[..., requests.Response]) -> Callable[..., requests.Resp
112
134
  return inner
113
135
 
114
136
 
115
- def _make_common_request_headers() -> Dict[str, str]:
137
+ def _make_common_request_headers() -> dict[str, str]:
116
138
  headers = {
117
139
  "Content-Type": "application/json",
118
140
  "Accept": "application/json, text/event-stream",
@@ -120,6 +142,18 @@ def _make_common_request_headers() -> Dict[str, str]:
120
142
  return headers
121
143
 
122
144
 
145
+ def _get_request_id(resp: dict[str, Any]) -> Optional[Any]:
146
+ request_id = None
147
+ if "headers" in resp:
148
+ for key, value in resp["headers"].items():
149
+ # Note: There is some whitespace in the headers making it not possible
150
+ # to directly index the header reliably.
151
+ if key.strip().lower() == "x-snowflake-request-id":
152
+ request_id = value
153
+ break
154
+ return request_id
155
+
156
+
123
157
  def _validate_response_format_object(options: CompleteOptions) -> None:
124
158
  """Validate the response format object for structured-output mode.
125
159
 
@@ -148,14 +182,14 @@ def _validate_response_format_object(options: CompleteOptions) -> None:
148
182
 
149
183
  def _make_request_body(
150
184
  model: str,
151
- prompt: Union[str, List[ConversationMessage]],
185
+ prompt: Union[str, list[ConversationMessage]],
152
186
  options: Optional[CompleteOptions] = None,
153
- ) -> Dict[str, Any]:
187
+ ) -> dict[str, Any]:
154
188
  data = {
155
189
  "model": model,
156
190
  "stream": True,
157
191
  }
158
- if isinstance(prompt, List):
192
+ if isinstance(prompt, list):
159
193
  data["messages"] = prompt
160
194
  else:
161
195
  data["messages"] = [{"content": prompt}]
@@ -182,19 +216,13 @@ def _make_request_body(
182
216
 
183
217
  # XP endpoint returns a dict response which needs to be converted to a format which can
184
218
  # be consumed by the SSEClient. This method does that.
185
- def _xp_dict_to_response(raw_resp: Dict[str, Any]) -> requests.Response:
219
+ def _xp_dict_to_response(raw_resp: dict[str, Any]) -> requests.Response:
186
220
 
187
221
  response = requests.Response()
188
222
  response.status_code = int(raw_resp["status"])
189
223
  response.headers = raw_resp["headers"]
190
224
 
191
- request_id = None
192
- for key, value in raw_resp["headers"].items():
193
- # Note: there is some whitespace in the headers making it not possible
194
- # to directly index the header reliably.
195
- if key.strip().lower() == "x-snowflake-request-id":
196
- request_id = value
197
- break
225
+ request_id = _get_request_id(raw_resp)
198
226
 
199
227
  data = raw_resp["content"]
200
228
  try:
@@ -222,9 +250,9 @@ def _xp_dict_to_response(raw_resp: Dict[str, Any]) -> requests.Response:
222
250
 
223
251
  @retry
224
252
  def _call_complete_xp(
225
- snow_api_xp_request_handler: Optional[Callable[..., Dict[str, Any]]],
253
+ snow_api_xp_request_handler: Optional[Callable[..., dict[str, Any]]],
226
254
  model: str,
227
- prompt: Union[str, List[ConversationMessage]],
255
+ prompt: Union[str, list[ConversationMessage]],
228
256
  options: Optional[CompleteOptions] = None,
229
257
  deadline: Optional[float] = None,
230
258
  ) -> requests.Response:
@@ -238,7 +266,7 @@ def _call_complete_xp(
238
266
  @retry
239
267
  def _call_complete_rest(
240
268
  model: str,
241
- prompt: Union[str, List[ConversationMessage]],
269
+ prompt: Union[str, list[ConversationMessage]],
242
270
  options: Optional[CompleteOptions] = None,
243
271
  session: Optional[snowpark.Session] = None,
244
272
  ) -> requests.Response:
@@ -276,7 +304,12 @@ def _call_complete_rest(
276
304
  )
277
305
 
278
306
 
279
- def _return_stream_response(response: requests.Response, deadline: Optional[float]) -> Iterator[str]:
307
+ def _return_stream_response(
308
+ response: requests.Response,
309
+ deadline: Optional[float],
310
+ session: Optional[snowpark.Session] = None,
311
+ ) -> Iterator[str]:
312
+ request_id = _get_request_id(dict(response.headers))
280
313
  client = SSEClient(response)
281
314
  for event in client.events():
282
315
  if deadline is not None and time.time() > deadline:
@@ -294,7 +327,7 @@ def _return_stream_response(response: requests.Response, deadline: Optional[floa
294
327
  # This is the case of midstream errors which were introduced specifically for structured output.
295
328
  # TODO: discuss during code review
296
329
  if parsed_resp.get("error"):
297
- yield json.dumps(parsed_resp)
330
+ raise MidStreamException(reason=response.text, request_id=request_id)
298
331
  else:
299
332
  pass
300
333
 
@@ -306,9 +339,9 @@ def _complete_call_sql_function_snowpark(
306
339
 
307
340
 
308
341
  def _complete_non_streaming_immediate(
309
- snow_api_xp_request_handler: Optional[Callable[..., Dict[str, Any]]],
342
+ snow_api_xp_request_handler: Optional[Callable[..., dict[str, Any]]],
310
343
  model: str,
311
- prompt: Union[str, List[ConversationMessage]],
344
+ prompt: Union[str, list[ConversationMessage]],
312
345
  options: Optional[CompleteOptions],
313
346
  session: Optional[snowpark.Session] = None,
314
347
  deadline: Optional[float] = None,
@@ -325,10 +358,10 @@ def _complete_non_streaming_immediate(
325
358
 
326
359
 
327
360
  def _complete_non_streaming_impl(
328
- snow_api_xp_request_handler: Optional[Callable[..., Dict[str, Any]]],
361
+ snow_api_xp_request_handler: Optional[Callable[..., dict[str, Any]]],
329
362
  function: str,
330
363
  model: Union[str, snowpark.Column],
331
- prompt: Union[str, List[ConversationMessage], snowpark.Column],
364
+ prompt: Union[str, list[ConversationMessage], snowpark.Column],
332
365
  options: Optional[Union[CompleteOptions, snowpark.Column]],
333
366
  session: Optional[snowpark.Session] = None,
334
367
  deadline: Optional[float] = None,
@@ -355,9 +388,9 @@ def _complete_non_streaming_impl(
355
388
 
356
389
 
357
390
  def _complete_rest(
358
- snow_api_xp_request_handler: Optional[Callable[..., Dict[str, Any]]],
391
+ snow_api_xp_request_handler: Optional[Callable[..., dict[str, Any]]],
359
392
  model: str,
360
- prompt: Union[str, List[ConversationMessage]],
393
+ prompt: Union[str, list[ConversationMessage]],
361
394
  options: Optional[CompleteOptions] = None,
362
395
  session: Optional[snowpark.Session] = None,
363
396
  deadline: Optional[float] = None,
@@ -375,13 +408,13 @@ def _complete_rest(
375
408
  else:
376
409
  response = _call_complete_rest(model=model, prompt=prompt, options=options, session=session, deadline=deadline)
377
410
  assert response.status_code >= 200 and response.status_code < 300
378
- return _return_stream_response(response, deadline)
411
+ return _return_stream_response(response, deadline, session)
379
412
 
380
413
 
381
414
  def _complete_impl(
382
415
  model: Union[str, snowpark.Column],
383
- prompt: Union[str, List[ConversationMessage], snowpark.Column],
384
- snow_api_xp_request_handler: Optional[Callable[..., Dict[str, Any]]] = None,
416
+ prompt: Union[str, list[ConversationMessage], snowpark.Column],
417
+ snow_api_xp_request_handler: Optional[Callable[..., dict[str, Any]]] = None,
385
418
  function: str = "snowflake.cortex.complete",
386
419
  options: Optional[CompleteOptions] = None,
387
420
  session: Optional[snowpark.Session] = None,
@@ -396,7 +429,7 @@ def _complete_impl(
396
429
  if stream:
397
430
  if not isinstance(model, str):
398
431
  raise ValueError("in REST mode, 'model' must be a string")
399
- if not isinstance(prompt, str) and not isinstance(prompt, List):
432
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
400
433
  raise ValueError("in REST mode, 'prompt' must be a string or a list of ConversationMessage")
401
434
  return _complete_rest(
402
435
  snow_api_xp_request_handler=snow_api_xp_request_handler,
@@ -422,7 +455,7 @@ def _complete_impl(
422
455
  )
423
456
  def complete(
424
457
  model: Union[str, snowpark.Column],
425
- prompt: Union[str, List[ConversationMessage], snowpark.Column],
458
+ prompt: Union[str, list[ConversationMessage], snowpark.Column],
426
459
  *,
427
460
  options: Optional[CompleteOptions] = None,
428
461
  session: Optional[snowpark.Session] = None,
@@ -1,4 +1,4 @@
1
- from typing import List, Optional, Union, cast
1
+ from typing import Optional, Union, cast
2
2
 
3
3
  from typing_extensions import deprecated
4
4
 
@@ -14,7 +14,7 @@ def embed_text_1024(
14
14
  model: Union[str, snowpark.Column],
15
15
  text: Union[str, snowpark.Column],
16
16
  session: Optional[snowpark.Session] = None,
17
- ) -> Union[List[float], snowpark.Column]:
17
+ ) -> Union[list[float], snowpark.Column]:
18
18
  """Calls into the LLM inference service to embed the text.
19
19
 
20
20
  Args:
@@ -35,8 +35,8 @@ def _embed_text_1024_impl(
35
35
  model: Union[str, snowpark.Column],
36
36
  text: Union[str, snowpark.Column],
37
37
  session: Optional[snowpark.Session] = None,
38
- ) -> Union[List[float], snowpark.Column]:
39
- return cast(Union[List[float], snowpark.Column], call_sql_function(function, session, model, text))
38
+ ) -> Union[list[float], snowpark.Column]:
39
+ return cast(Union[list[float], snowpark.Column], call_sql_function(function, session, model, text))
40
40
 
41
41
 
42
42
  EmbedText1024 = deprecated(
@@ -1,4 +1,4 @@
1
- from typing import List, Optional, Union, cast
1
+ from typing import Optional, Union, cast
2
2
 
3
3
  from typing_extensions import deprecated
4
4
 
@@ -14,7 +14,7 @@ def embed_text_768(
14
14
  model: Union[str, snowpark.Column],
15
15
  text: Union[str, snowpark.Column],
16
16
  session: Optional[snowpark.Session] = None,
17
- ) -> Union[List[float], snowpark.Column]:
17
+ ) -> Union[list[float], snowpark.Column]:
18
18
  """Calls into the LLM inference service to embed the text.
19
19
 
20
20
  Args:
@@ -35,8 +35,8 @@ def _embed_text_768_impl(
35
35
  model: Union[str, snowpark.Column],
36
36
  text: Union[str, snowpark.Column],
37
37
  session: Optional[snowpark.Session] = None,
38
- ) -> Union[List[float], snowpark.Column]:
39
- return cast(Union[List[float], snowpark.Column], call_sql_function(function, session, model, text))
38
+ ) -> Union[list[float], snowpark.Column]:
39
+ return cast(Union[list[float], snowpark.Column], call_sql_function(function, session, model, text))
40
40
 
41
41
 
42
42
  EmbedText768 = deprecated(
@@ -1,6 +1,6 @@
1
1
  import json
2
2
  from dataclasses import dataclass
3
- from typing import Any, Dict, List, Optional, Union, cast
3
+ from typing import Any, Optional, Union, cast
4
4
 
5
5
  from snowflake import snowpark
6
6
  from snowflake.cortex._util import (
@@ -53,7 +53,7 @@ class FinetuneStatus:
53
53
  created_on: Optional[int] = None
54
54
  """Creation timestamp of the Fine-tuning job in milliseconds."""
55
55
 
56
- error: Optional[Dict[str, Any]] = None
56
+ error: Optional[dict[str, Any]] = None
57
57
  """Error message propagated from the job."""
58
58
 
59
59
  finished_on: Optional[int] = None
@@ -62,7 +62,7 @@ class FinetuneStatus:
62
62
  progress: Optional[float] = None
63
63
  """Progress made as a fraction of total [0.0,1.0]."""
64
64
 
65
- training_result: Optional[List[Dict[str, Any]]] = None
65
+ training_result: Optional[list[dict[str, Any]]] = None
66
66
  """Detailed metrics report for a completed training."""
67
67
 
68
68
  trained_tokens: Optional[int] = None
@@ -135,7 +135,7 @@ class FinetuneJob:
135
135
  """
136
136
  result_string = _finetune_impl(operation="DESCRIBE", session=self._session, function_args=[self.status.id])
137
137
 
138
- result = FinetuneStatus(**cast(Dict[str, Any], _try_load_json(result_string)))
138
+ result = FinetuneStatus(**cast(dict[str, Any], _try_load_json(result_string)))
139
139
  return result
140
140
 
141
141
 
@@ -167,7 +167,7 @@ class Finetune:
167
167
  base_model: str,
168
168
  training_data: Union[str, snowpark.DataFrame],
169
169
  validation_data: Optional[Union[str, snowpark.DataFrame]] = None,
170
- options: Optional[Dict[str, Any]] = None,
170
+ options: Optional[dict[str, Any]] = None,
171
171
  ) -> FinetuneJob:
172
172
  """Create a new fine-tuning runs.
173
173
 
@@ -240,7 +240,7 @@ class Finetune:
240
240
  project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
241
241
  subproject=CORTEX_FINETUNE_TELEMETRY_SUBPROJECT,
242
242
  )
243
- def list_jobs(self) -> List["FinetuneJob"]:
243
+ def list_jobs(self) -> list["FinetuneJob"]:
244
244
  """Show current and past fine-tuning runs.
245
245
 
246
246
  Returns:
@@ -253,7 +253,7 @@ class Finetune:
253
253
  return [FinetuneJob(session=self._session, status=FinetuneStatus(**run_status)) for run_status in result]
254
254
 
255
255
 
256
- def _try_load_json(json_string: str) -> Union[Dict[Any, Any], List[Any]]:
256
+ def _try_load_json(json_string: str) -> Union[dict[Any, Any], list[Any]]:
257
257
  try:
258
258
  result = json.loads(str(json_string))
259
259
  except json.JSONDecodeError as e:
@@ -269,5 +269,5 @@ def _try_load_json(json_string: str) -> Union[Dict[Any, Any], List[Any]]:
269
269
  return result
270
270
 
271
271
 
272
- def _finetune_impl(operation: str, session: Optional[snowpark.Session], function_args: List[Any]) -> str:
272
+ def _finetune_impl(operation: str, session: Optional[snowpark.Session], function_args: list[Any]) -> str:
273
273
  return call_sql_function_literals(_CORTEX_FINETUNE_SYSTEM_FUNCTION_NAME, session, operation, *function_args)
snowflake/cortex/_util.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Any, Dict, List, Optional, Union, cast
1
+ from typing import Any, Optional, Union, cast
2
2
 
3
3
  from snowflake import snowpark
4
4
  from snowflake.ml._internal.exceptions import error_codes, exceptions
@@ -11,22 +11,18 @@ CORTEX_FUNCTIONS_TELEMETRY_PROJECT = "CortexFunctions"
11
11
  class SnowflakeAuthenticationException(Exception):
12
12
  """This exception is raised when there is an issue with Snowflake's configuration."""
13
13
 
14
- pass
15
-
16
14
 
17
15
  class SnowflakeConfigurationException(Exception):
18
16
  """This exception is raised when there is an issue with Snowflake's configuration."""
19
17
 
20
- pass
21
-
22
18
 
23
19
  # Calls a sql function, handling both immediate (e.g. python types) and batch
24
20
  # (e.g. snowpark column and literal type modes).
25
21
  def call_sql_function(
26
22
  function: str,
27
23
  session: Optional[snowpark.Session],
28
- *args: Union[str, List[str], snowpark.Column, Dict[str, Union[int, float]]],
29
- ) -> Union[str, List[float], snowpark.Column]:
24
+ *args: Union[str, list[str], snowpark.Column, dict[str, Union[int, float]]],
25
+ ) -> Union[str, list[float], snowpark.Column]:
30
26
  handle_as_column = False
31
27
 
32
28
  for arg in args:
@@ -34,15 +30,15 @@ def call_sql_function(
34
30
  handle_as_column = True
35
31
 
36
32
  if handle_as_column:
37
- return cast(Union[str, List[float], snowpark.Column], _call_sql_function_column(function, *args))
33
+ return cast(Union[str, list[float], snowpark.Column], _call_sql_function_column(function, *args))
38
34
  return cast(
39
- Union[str, List[float], snowpark.Column],
35
+ Union[str, list[float], snowpark.Column],
40
36
  _call_sql_function_immediate(function, session, *args),
41
37
  )
42
38
 
43
39
 
44
40
  def _call_sql_function_column(
45
- function: str, *args: Union[str, List[str], snowpark.Column, Dict[str, Union[int, float]]]
41
+ function: str, *args: Union[str, list[str], snowpark.Column, dict[str, Union[int, float]]]
46
42
  ) -> snowpark.Column:
47
43
  return cast(snowpark.Column, functions.builtin(function)(*args))
48
44
 
@@ -50,8 +46,8 @@ def _call_sql_function_column(
50
46
  def _call_sql_function_immediate(
51
47
  function: str,
52
48
  session: Optional[snowpark.Session],
53
- *args: Union[str, List[str], snowpark.Column, Dict[str, Union[int, float]]],
54
- ) -> Union[str, List[float]]:
49
+ *args: Union[str, list[str], snowpark.Column, dict[str, Union[int, float]]],
50
+ ) -> Union[str, list[float]]:
55
51
  session = session or context.get_active_session()
56
52
  if session is None:
57
53
  raise SnowflakeAuthenticationException(
@@ -1,8 +1,9 @@
1
+ import os
1
2
  import platform
2
3
 
3
- from snowflake.ml import version
4
-
5
4
  SOURCE = "SnowML"
6
- VERSION = version.VERSION
7
5
  PYTHON_VERSION = platform.python_version()
8
6
  OS = platform.system()
7
+ IN_ML_RUNTIME_ENV_VAR = "IN_SPCS_ML_RUNTIME"
8
+ IN_ML_RUNTIME = os.getenv(IN_ML_RUNTIME_ENV_VAR)
9
+ USE_OPTIMIZED_DATA_INGESTOR = "USE_OPTIMIZED_DATA_INGESTOR"