snowflake_ml_python-1.9.2-py3-none-any.whl → snowflake_ml_python-1.10.0-py3-none-any.whl
This diff compares the contents of two publicly available versions of the package, each of which has been released to one of the supported registries. The information in this diff is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
- snowflake/ml/_internal/utils/service_logger.py +31 -17
- snowflake/ml/experiment/callback/lightgbm.py +55 -0
- snowflake/ml/experiment/callback/xgboost.py +63 -0
- snowflake/ml/experiment/utils.py +14 -0
- snowflake/ml/jobs/_utils/payload_utils.py +13 -7
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +2 -2
- snowflake/ml/model/_client/model/model_version_impl.py +56 -48
- snowflake/ml/model/_client/ops/service_ops.py +177 -12
- snowflake/ml/model/event_handler.py +87 -18
- snowflake/ml/model/models/huggingface_pipeline.py +71 -49
- snowflake/ml/model/type_hints.py +26 -1
- snowflake/ml/registry/_manager/model_manager.py +30 -35
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +105 -0
- snowflake/ml/registry/registry.py +0 -19
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.10.0.dist-info}/METADATA +505 -492
- {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.10.0.dist-info}/RECORD +20 -17
- snowflake/ml/experiment/callback.py +0 -121
- {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.10.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.10.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.10.0.dist-info}/top_level.txt +0 -0
|
@@ -180,9 +180,7 @@ class ServiceOperator:
|
|
|
180
180
|
service_name: sql_identifier.SqlIdentifier,
|
|
181
181
|
image_build_compute_pool_name: sql_identifier.SqlIdentifier,
|
|
182
182
|
service_compute_pool_name: sql_identifier.SqlIdentifier,
|
|
183
|
-
|
|
184
|
-
image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
|
|
185
|
-
image_repo_name: sql_identifier.SqlIdentifier,
|
|
183
|
+
image_repo: str,
|
|
186
184
|
ingress_enabled: bool,
|
|
187
185
|
max_instances: int,
|
|
188
186
|
cpu_requests: Optional[str],
|
|
@@ -193,6 +191,7 @@ class ServiceOperator:
|
|
|
193
191
|
force_rebuild: bool,
|
|
194
192
|
build_external_access_integrations: Optional[list[sql_identifier.SqlIdentifier]],
|
|
195
193
|
block: bool,
|
|
194
|
+
progress_status: type_hints.ProgressStatus,
|
|
196
195
|
statement_params: Optional[dict[str, Any]] = None,
|
|
197
196
|
# hf model
|
|
198
197
|
hf_model_args: Optional[HFModelArgs] = None,
|
|
@@ -209,8 +208,17 @@ class ServiceOperator:
|
|
|
209
208
|
service_database_name = service_database_name or database_name or self._database_name
|
|
210
209
|
service_schema_name = service_schema_name or schema_name or self._schema_name
|
|
211
210
|
|
|
211
|
+
# Parse image repo
|
|
212
|
+
image_repo_database_name, image_repo_schema_name, image_repo_name = sql_identifier.parse_fully_qualified_name(
|
|
213
|
+
image_repo
|
|
214
|
+
)
|
|
212
215
|
image_repo_database_name = image_repo_database_name or database_name or self._database_name
|
|
213
216
|
image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name
|
|
217
|
+
|
|
218
|
+
# Step 1: Preparing deployment artifacts
|
|
219
|
+
progress_status.update("preparing deployment artifacts...")
|
|
220
|
+
progress_status.increment()
|
|
221
|
+
|
|
214
222
|
if self._workspace:
|
|
215
223
|
stage_path = self._create_temp_stage(database_name, schema_name, statement_params)
|
|
216
224
|
else:
|
|
@@ -259,6 +267,11 @@ class ServiceOperator:
|
|
|
259
267
|
**(hf_model_args.hf_model_kwargs if hf_model_args.hf_model_kwargs else {}),
|
|
260
268
|
)
|
|
261
269
|
spec_yaml_str_or_path = self._model_deployment_spec.save()
|
|
270
|
+
|
|
271
|
+
# Step 2: Uploading deployment artifacts
|
|
272
|
+
progress_status.update("uploading deployment artifacts...")
|
|
273
|
+
progress_status.increment()
|
|
274
|
+
|
|
262
275
|
if self._workspace:
|
|
263
276
|
assert stage_path is not None
|
|
264
277
|
file_utils.upload_directory_to_stage(
|
|
@@ -281,6 +294,10 @@ class ServiceOperator:
|
|
|
281
294
|
statement_params=statement_params,
|
|
282
295
|
)
|
|
283
296
|
|
|
297
|
+
# Step 3: Initiating model deployment
|
|
298
|
+
progress_status.update("initiating model deployment...")
|
|
299
|
+
progress_status.increment()
|
|
300
|
+
|
|
284
301
|
# deploy the model service
|
|
285
302
|
query_id, async_job = self._service_client.deploy_model(
|
|
286
303
|
stage_path=stage_path if self._workspace else None,
|
|
@@ -337,13 +354,63 @@ class ServiceOperator:
|
|
|
337
354
|
)
|
|
338
355
|
|
|
339
356
|
if block:
|
|
340
|
-
|
|
357
|
+
try:
|
|
358
|
+
# Step 4: Starting model build: waits for build to start
|
|
359
|
+
progress_status.update("starting model image build...")
|
|
360
|
+
progress_status.increment()
|
|
361
|
+
|
|
362
|
+
# Poll for model build to start if not using existing service
|
|
363
|
+
if not model_inference_service_exists:
|
|
364
|
+
self._wait_for_service_status(
|
|
365
|
+
model_build_service_name,
|
|
366
|
+
service_sql.ServiceStatus.RUNNING,
|
|
367
|
+
service_database_name,
|
|
368
|
+
service_schema_name,
|
|
369
|
+
async_job,
|
|
370
|
+
statement_params,
|
|
371
|
+
)
|
|
341
372
|
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
373
|
+
# Step 5: Building model image
|
|
374
|
+
progress_status.update("building model image...")
|
|
375
|
+
progress_status.increment()
|
|
376
|
+
|
|
377
|
+
# Poll for model build completion
|
|
378
|
+
if not model_inference_service_exists:
|
|
379
|
+
self._wait_for_service_status(
|
|
380
|
+
model_build_service_name,
|
|
381
|
+
service_sql.ServiceStatus.DONE,
|
|
382
|
+
service_database_name,
|
|
383
|
+
service_schema_name,
|
|
384
|
+
async_job,
|
|
385
|
+
statement_params,
|
|
386
|
+
)
|
|
387
|
+
|
|
388
|
+
# Step 6: Deploying model service (push complete, starting inference service)
|
|
389
|
+
progress_status.update("deploying model service...")
|
|
390
|
+
progress_status.increment()
|
|
391
|
+
|
|
392
|
+
log_thread.join()
|
|
393
|
+
|
|
394
|
+
res = cast(str, cast(list[row.Row], async_job.result())[0][0])
|
|
395
|
+
module_logger.info(f"Inference service {service_name} deployment complete: {res}")
|
|
396
|
+
return res
|
|
397
|
+
|
|
398
|
+
except RuntimeError as e:
|
|
399
|
+
# Handle service creation/deployment failures
|
|
400
|
+
error_msg = f"Model service deployment failed: {str(e)}"
|
|
401
|
+
module_logger.error(error_msg)
|
|
402
|
+
|
|
403
|
+
# Update progress status to show failure
|
|
404
|
+
progress_status.update(error_msg, state="error")
|
|
405
|
+
|
|
406
|
+
# Stop the log thread if it's running
|
|
407
|
+
if "log_thread" in locals() and log_thread.is_alive():
|
|
408
|
+
log_thread.join(timeout=5) # Give it a few seconds to finish gracefully
|
|
409
|
+
|
|
410
|
+
# Re-raise the exception to propagate the error
|
|
411
|
+
raise RuntimeError(error_msg) from e
|
|
412
|
+
|
|
413
|
+
return async_job
|
|
347
414
|
|
|
348
415
|
def _start_service_log_streaming(
|
|
349
416
|
self,
|
|
@@ -579,6 +646,7 @@ class ServiceOperator:
|
|
|
579
646
|
is_snowpark_sql_exception = isinstance(ex, exceptions.SnowparkSQLException)
|
|
580
647
|
contains_msg = any(msg in str(ex) for msg in ["Pending scheduling", "Waiting to start"])
|
|
581
648
|
matches_pattern = service_log_meta.service_status is None and re.search(pattern, str(ex)) is not None
|
|
649
|
+
|
|
582
650
|
if not (is_snowpark_sql_exception and (contains_msg or matches_pattern)):
|
|
583
651
|
module_logger.warning(f"Caught an exception when logging: {repr(ex)}")
|
|
584
652
|
time.sleep(5)
|
|
@@ -618,6 +686,101 @@ class ServiceOperator:
|
|
|
618
686
|
except Exception as ex:
|
|
619
687
|
module_logger.warning(f"Caught an exception when logging: {repr(ex)}")
|
|
620
688
|
|
|
689
|
+
def _wait_for_service_status(
|
|
690
|
+
self,
|
|
691
|
+
service_name: sql_identifier.SqlIdentifier,
|
|
692
|
+
target_status: service_sql.ServiceStatus,
|
|
693
|
+
service_database_name: Optional[sql_identifier.SqlIdentifier],
|
|
694
|
+
service_schema_name: Optional[sql_identifier.SqlIdentifier],
|
|
695
|
+
async_job: snowpark.AsyncJob,
|
|
696
|
+
statement_params: Optional[dict[str, Any]] = None,
|
|
697
|
+
timeout_minutes: int = 30,
|
|
698
|
+
) -> None:
|
|
699
|
+
"""Wait for service to reach the specified status while monitoring async job for failures.
|
|
700
|
+
|
|
701
|
+
Args:
|
|
702
|
+
service_name: The service to monitor
|
|
703
|
+
target_status: The target status to wait for
|
|
704
|
+
service_database_name: Database containing the service
|
|
705
|
+
service_schema_name: Schema containing the service
|
|
706
|
+
async_job: The async job to monitor for completion/failure
|
|
707
|
+
statement_params: SQL statement parameters
|
|
708
|
+
timeout_minutes: Maximum time to wait before timing out
|
|
709
|
+
|
|
710
|
+
Raises:
|
|
711
|
+
RuntimeError: If service fails, times out, or enters an error state
|
|
712
|
+
"""
|
|
713
|
+
start_time = time.time()
|
|
714
|
+
timeout_seconds = timeout_minutes * 60
|
|
715
|
+
service_seen_before = False
|
|
716
|
+
|
|
717
|
+
while True:
|
|
718
|
+
# Check if async job has failed (but don't return on success - we need specific service status)
|
|
719
|
+
if async_job.is_done():
|
|
720
|
+
try:
|
|
721
|
+
async_job.result()
|
|
722
|
+
# Async job completed successfully, but we're waiting for a specific service status
|
|
723
|
+
# This might mean the service completed and was cleaned up
|
|
724
|
+
module_logger.debug(
|
|
725
|
+
f"Async job completed but we're still waiting for {service_name} to reach {target_status.value}"
|
|
726
|
+
)
|
|
727
|
+
except Exception as e:
|
|
728
|
+
raise RuntimeError(f"Service deployment failed: {e}")
|
|
729
|
+
|
|
730
|
+
try:
|
|
731
|
+
statuses = self._service_client.get_service_container_statuses(
|
|
732
|
+
database_name=service_database_name,
|
|
733
|
+
schema_name=service_schema_name,
|
|
734
|
+
service_name=service_name,
|
|
735
|
+
include_message=True,
|
|
736
|
+
statement_params=statement_params,
|
|
737
|
+
)
|
|
738
|
+
|
|
739
|
+
if statuses:
|
|
740
|
+
service_seen_before = True
|
|
741
|
+
current_status = statuses[0].service_status
|
|
742
|
+
|
|
743
|
+
# Check if we've reached the target status
|
|
744
|
+
if current_status == target_status:
|
|
745
|
+
return
|
|
746
|
+
|
|
747
|
+
# Check for failure states
|
|
748
|
+
if current_status in [service_sql.ServiceStatus.FAILED, service_sql.ServiceStatus.INTERNAL_ERROR]:
|
|
749
|
+
error_msg = f"Service {service_name} failed with status {current_status.value}"
|
|
750
|
+
if statuses[0].message:
|
|
751
|
+
error_msg += f": {statuses[0].message}"
|
|
752
|
+
raise RuntimeError(error_msg)
|
|
753
|
+
|
|
754
|
+
except exceptions.SnowparkSQLException as e:
|
|
755
|
+
# Service might not exist yet - this is expected during initial deployment
|
|
756
|
+
if "does not exist" in str(e) or "002003" in str(e):
|
|
757
|
+
# If we're waiting for DONE status and we've seen the service before,
|
|
758
|
+
# it likely completed and was cleaned up
|
|
759
|
+
if target_status == service_sql.ServiceStatus.DONE and service_seen_before:
|
|
760
|
+
module_logger.debug(
|
|
761
|
+
f"Service {service_name} disappeared after being seen, "
|
|
762
|
+
f"assuming it reached {target_status.value} and was cleaned up"
|
|
763
|
+
)
|
|
764
|
+
return
|
|
765
|
+
|
|
766
|
+
module_logger.debug(f"Service {service_name} not found yet, continuing to wait...")
|
|
767
|
+
else:
|
|
768
|
+
# Re-raise unexpected SQL exceptions
|
|
769
|
+
raise RuntimeError(f"Error checking service status: {e}")
|
|
770
|
+
except Exception as e:
|
|
771
|
+
# Re-raise unexpected exceptions instead of masking them
|
|
772
|
+
raise RuntimeError(f"Unexpected error while waiting for service status: {e}")
|
|
773
|
+
|
|
774
|
+
# Check timeout
|
|
775
|
+
elapsed_time = time.time() - start_time
|
|
776
|
+
if elapsed_time > timeout_seconds:
|
|
777
|
+
raise RuntimeError(
|
|
778
|
+
f"Timeout waiting for service {service_name} to reach status {target_status.value} "
|
|
779
|
+
f"after {timeout_minutes} minutes"
|
|
780
|
+
)
|
|
781
|
+
|
|
782
|
+
time.sleep(2) # Poll every 2 seconds
|
|
783
|
+
|
|
621
784
|
@staticmethod
|
|
622
785
|
def _get_service_id_from_deployment_step(query_id: str, deployment_step: DeploymentStep) -> str:
|
|
623
786
|
"""Get the service ID through the server-side logic."""
|
|
@@ -675,9 +838,7 @@ class ServiceOperator:
|
|
|
675
838
|
job_name: sql_identifier.SqlIdentifier,
|
|
676
839
|
compute_pool_name: sql_identifier.SqlIdentifier,
|
|
677
840
|
warehouse_name: sql_identifier.SqlIdentifier,
|
|
678
|
-
|
|
679
|
-
image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
|
|
680
|
-
image_repo_name: sql_identifier.SqlIdentifier,
|
|
841
|
+
image_repo: str,
|
|
681
842
|
output_table_database_name: Optional[sql_identifier.SqlIdentifier],
|
|
682
843
|
output_table_schema_name: Optional[sql_identifier.SqlIdentifier],
|
|
683
844
|
output_table_name: sql_identifier.SqlIdentifier,
|
|
@@ -698,6 +859,10 @@ class ServiceOperator:
|
|
|
698
859
|
job_database_name = job_database_name or database_name or self._database_name
|
|
699
860
|
job_schema_name = job_schema_name or schema_name or self._schema_name
|
|
700
861
|
|
|
862
|
+
# Parse image repo
|
|
863
|
+
image_repo_database_name, image_repo_schema_name, image_repo_name = sql_identifier.parse_fully_qualified_name(
|
|
864
|
+
image_repo
|
|
865
|
+
)
|
|
701
866
|
image_repo_database_name = image_repo_database_name or database_name or self._database_name
|
|
702
867
|
image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name
|
|
703
868
|
|
|
@@ -23,12 +23,24 @@ class _TqdmStatusContext:
|
|
|
23
23
|
if state == "complete":
|
|
24
24
|
self._progress_bar.update(self._progress_bar.total - self._progress_bar.n)
|
|
25
25
|
self._progress_bar.set_description(label)
|
|
26
|
+
elif state == "error":
|
|
27
|
+
# For error state, use the label as-is and mark with ERROR prefix
|
|
28
|
+
# Don't update progress bar position for errors - leave it where it was
|
|
29
|
+
self._progress_bar.set_description(f"❌ ERROR: {label}")
|
|
26
30
|
else:
|
|
27
|
-
|
|
31
|
+
combined_desc = f"{self._label}: {label}" if label != self._label else self._label
|
|
32
|
+
self._progress_bar.set_description(combined_desc)
|
|
28
33
|
|
|
29
|
-
def increment(self
|
|
34
|
+
def increment(self) -> None:
|
|
30
35
|
"""Increment the progress bar."""
|
|
31
|
-
self._progress_bar.update(
|
|
36
|
+
self._progress_bar.update(1)
|
|
37
|
+
|
|
38
|
+
def complete(self) -> None:
|
|
39
|
+
"""Complete the progress bar to full state."""
|
|
40
|
+
if self._total:
|
|
41
|
+
remaining = self._total - self._progress_bar.n
|
|
42
|
+
if remaining > 0:
|
|
43
|
+
self._progress_bar.update(remaining)
|
|
32
44
|
|
|
33
45
|
|
|
34
46
|
class _StreamlitStatusContext:
|
|
@@ -39,6 +51,7 @@ class _StreamlitStatusContext:
|
|
|
39
51
|
self._streamlit = streamlit_module
|
|
40
52
|
self._total = total
|
|
41
53
|
self._current = 0
|
|
54
|
+
self._current_label = label
|
|
42
55
|
self._progress_bar = None
|
|
43
56
|
|
|
44
57
|
def __enter__(self) -> "_StreamlitStatusContext":
|
|
@@ -49,26 +62,70 @@ class _StreamlitStatusContext:
|
|
|
49
62
|
return self
|
|
50
63
|
|
|
51
64
|
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
|
52
|
-
|
|
65
|
+
# Only update to complete if there was no exception
|
|
66
|
+
if exc_type is None:
|
|
67
|
+
self._status_container.update(state="complete")
|
|
53
68
|
|
|
54
69
|
def update(self, label: str, *, state: str = "running", expanded: bool = True) -> None:
|
|
55
70
|
"""Update the status label."""
|
|
56
|
-
if state
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
71
|
+
if state == "complete" or state == "error":
|
|
72
|
+
# For completion/error, use the message as-is and update main status
|
|
73
|
+
self._status_container.update(label=label, state=state, expanded=expanded)
|
|
74
|
+
self._current_label = label
|
|
75
|
+
|
|
76
|
+
# For error state, update progress bar text but preserve position
|
|
77
|
+
if state == "error" and self._total is not None and self._progress_bar is not None:
|
|
78
|
+
self._progress_bar.progress(
|
|
79
|
+
self._current / self._total if self._total > 0 else 0,
|
|
80
|
+
text=f"ERROR - ({self._current}/{self._total})",
|
|
81
|
+
)
|
|
82
|
+
else:
|
|
83
|
+
combined_label = f"{self._label}: {label}" if label != self._label else self._label
|
|
84
|
+
self._status_container.update(label=combined_label, state=state, expanded=expanded)
|
|
85
|
+
self._current_label = label
|
|
86
|
+
if self._total is not None and self._progress_bar is not None:
|
|
87
|
+
progress_value = self._current / self._total if self._total > 0 else 0
|
|
88
|
+
self._progress_bar.progress(progress_value, text=f"({self._current}/{self._total})")
|
|
89
|
+
|
|
90
|
+
def increment(self) -> None:
|
|
66
91
|
"""Increment the progress."""
|
|
67
92
|
if self._total is not None:
|
|
68
|
-
self._current = min(self._current +
|
|
93
|
+
self._current = min(self._current + 1, self._total)
|
|
69
94
|
if self._progress_bar is not None:
|
|
70
95
|
progress_value = self._current / self._total if self._total > 0 else 0
|
|
71
|
-
self._progress_bar.progress(progress_value, text=f"{self._current}/{self._total}")
|
|
96
|
+
self._progress_bar.progress(progress_value, text=f"({self._current}/{self._total})")
|
|
97
|
+
|
|
98
|
+
def complete(self) -> None:
|
|
99
|
+
"""Complete the progress bar to full state."""
|
|
100
|
+
if self._total is not None:
|
|
101
|
+
self._current = self._total
|
|
102
|
+
if self._progress_bar is not None:
|
|
103
|
+
self._progress_bar.progress(1.0, text=f"({self._current}/{self._total})")
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class _NoOpStatusContext:
|
|
107
|
+
"""A no-op context manager for when status updates should be disabled."""
|
|
108
|
+
|
|
109
|
+
def __init__(self, label: str) -> None:
|
|
110
|
+
self._label = label
|
|
111
|
+
|
|
112
|
+
def __enter__(self) -> "_NoOpStatusContext":
|
|
113
|
+
return self
|
|
114
|
+
|
|
115
|
+
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
|
116
|
+
pass
|
|
117
|
+
|
|
118
|
+
def update(self, label: str, *, state: str = "running", expanded: bool = True) -> None:
|
|
119
|
+
"""No-op update method."""
|
|
120
|
+
pass
|
|
121
|
+
|
|
122
|
+
def increment(self) -> None:
|
|
123
|
+
"""No-op increment method."""
|
|
124
|
+
pass
|
|
125
|
+
|
|
126
|
+
def complete(self) -> None:
|
|
127
|
+
"""No-op complete method."""
|
|
128
|
+
pass
|
|
72
129
|
|
|
73
130
|
|
|
74
131
|
class ModelEventHandler:
|
|
@@ -99,7 +156,15 @@ class ModelEventHandler:
|
|
|
99
156
|
else:
|
|
100
157
|
self._tqdm.tqdm.write(message)
|
|
101
158
|
|
|
102
|
-
def status(
|
|
159
|
+
def status(
|
|
160
|
+
self,
|
|
161
|
+
label: str,
|
|
162
|
+
*,
|
|
163
|
+
state: str = "running",
|
|
164
|
+
expanded: bool = True,
|
|
165
|
+
total: Optional[int] = None,
|
|
166
|
+
block: bool = True,
|
|
167
|
+
) -> Any:
|
|
103
168
|
"""Context manager that provides status updates with optional enhanced display capabilities.
|
|
104
169
|
|
|
105
170
|
Args:
|
|
@@ -107,10 +172,14 @@ class ModelEventHandler:
|
|
|
107
172
|
state: The initial state ("running", "complete", "error")
|
|
108
173
|
expanded: Whether to show expanded view (streamlit only)
|
|
109
174
|
total: Total number of steps for progress tracking (optional)
|
|
175
|
+
block: Whether to show progress updates (no-op if False)
|
|
110
176
|
|
|
111
177
|
Returns:
|
|
112
|
-
Status context (Streamlit or
|
|
178
|
+
Status context (Streamlit, Tqdm, or NoOp based on availability and block parameter)
|
|
113
179
|
"""
|
|
180
|
+
if not block:
|
|
181
|
+
return _NoOpStatusContext(label)
|
|
182
|
+
|
|
114
183
|
if self._streamlit is not None:
|
|
115
184
|
return _StreamlitStatusContext(label, self._streamlit, total)
|
|
116
185
|
else:
|
|
@@ -299,6 +299,7 @@ class HuggingFacePipelineModel:
|
|
|
299
299
|
Raises:
|
|
300
300
|
ValueError: if database and schema name is not provided and session doesn't have a
|
|
301
301
|
database and schema name.
|
|
302
|
+
exceptions.SnowparkSQLException: if service already exists.
|
|
302
303
|
|
|
303
304
|
Returns:
|
|
304
305
|
The service ID or an async job object.
|
|
@@ -327,7 +328,6 @@ class HuggingFacePipelineModel:
|
|
|
327
328
|
version_name = name_generator.generate()[1]
|
|
328
329
|
|
|
329
330
|
service_db_id, service_schema_id, service_id = sql_identifier.parse_fully_qualified_name(service_name)
|
|
330
|
-
image_repo_db_id, image_repo_schema_id, image_repo_id = sql_identifier.parse_fully_qualified_name(image_repo)
|
|
331
331
|
|
|
332
332
|
service_operator = service_ops.ServiceOperator(
|
|
333
333
|
session=session,
|
|
@@ -336,51 +336,73 @@ class HuggingFacePipelineModel:
|
|
|
336
336
|
)
|
|
337
337
|
logger.info(f"A service job is going to register the hf model as: {model_name}.{version_name}")
|
|
338
338
|
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
339
|
+
from snowflake.ml.model import event_handler
|
|
340
|
+
from snowflake.snowpark import exceptions
|
|
341
|
+
|
|
342
|
+
hf_event_handler = event_handler.ModelEventHandler()
|
|
343
|
+
with hf_event_handler.status("Creating HuggingFace model service", total=6, block=block) as status:
|
|
344
|
+
try:
|
|
345
|
+
result = service_operator.create_service(
|
|
346
|
+
database_name=database_name_id,
|
|
347
|
+
schema_name=schema_name_id,
|
|
348
|
+
model_name=model_name_id,
|
|
349
|
+
version_name=sql_identifier.SqlIdentifier(version_name),
|
|
350
|
+
service_database_name=service_db_id,
|
|
351
|
+
service_schema_name=service_schema_id,
|
|
352
|
+
service_name=service_id,
|
|
353
|
+
image_build_compute_pool_name=(
|
|
354
|
+
sql_identifier.SqlIdentifier(image_build_compute_pool)
|
|
355
|
+
if image_build_compute_pool
|
|
356
|
+
else sql_identifier.SqlIdentifier(service_compute_pool)
|
|
357
|
+
),
|
|
358
|
+
service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
|
|
359
|
+
image_repo=image_repo,
|
|
360
|
+
ingress_enabled=ingress_enabled,
|
|
361
|
+
max_instances=max_instances,
|
|
362
|
+
cpu_requests=cpu_requests,
|
|
363
|
+
memory_requests=memory_requests,
|
|
364
|
+
gpu_requests=gpu_requests,
|
|
365
|
+
num_workers=num_workers,
|
|
366
|
+
max_batch_rows=max_batch_rows,
|
|
367
|
+
force_rebuild=force_rebuild,
|
|
368
|
+
build_external_access_integrations=(
|
|
369
|
+
None
|
|
370
|
+
if build_external_access_integrations is None
|
|
371
|
+
else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
|
|
372
|
+
),
|
|
373
|
+
block=block,
|
|
374
|
+
progress_status=status,
|
|
375
|
+
statement_params=statement_params,
|
|
376
|
+
# hf model
|
|
377
|
+
hf_model_args=service_ops.HFModelArgs(
|
|
378
|
+
hf_model_name=self.model,
|
|
379
|
+
hf_task=self.task,
|
|
380
|
+
hf_tokenizer=self.tokenizer,
|
|
381
|
+
hf_revision=self.revision,
|
|
382
|
+
hf_token=self.token,
|
|
383
|
+
hf_trust_remote_code=bool(self.trust_remote_code),
|
|
384
|
+
hf_model_kwargs=self.model_kwargs,
|
|
385
|
+
pip_requirements=pip_requirements,
|
|
386
|
+
conda_dependencies=conda_dependencies,
|
|
387
|
+
comment=comment,
|
|
388
|
+
# TODO: remove warehouse in the next release
|
|
389
|
+
warehouse=session.get_current_warehouse(),
|
|
390
|
+
),
|
|
391
|
+
)
|
|
392
|
+
status.update(label="HuggingFace model service created successfully", state="complete", expanded=False)
|
|
393
|
+
return result
|
|
394
|
+
except exceptions.SnowparkSQLException as e:
|
|
395
|
+
# Check if the error is because the service already exists
|
|
396
|
+
if "already exists" in str(e).lower() or "100132" in str(
|
|
397
|
+
e
|
|
398
|
+
): # 100132 is Snowflake error code for object already exists
|
|
399
|
+
# Update progress to show service already exists (preserve exception behavior)
|
|
400
|
+
status.update("service already exists")
|
|
401
|
+
status.complete() # Complete progress to full state
|
|
402
|
+
status.update(label="Service already exists", state="error", expanded=False)
|
|
403
|
+
# Re-raise the exception to preserve existing API behavior
|
|
404
|
+
raise
|
|
405
|
+
else:
|
|
406
|
+
# Re-raise other SQL exceptions
|
|
407
|
+
status.update(label="Service creation failed", state="error", expanded=False)
|
|
408
|
+
raise
|
snowflake/ml/model/type_hints.py
CHANGED
|
@@ -1,5 +1,14 @@
|
|
|
1
1
|
# mypy: disable-error-code="import"
|
|
2
|
-
from typing import
|
|
2
|
+
from typing import (
|
|
3
|
+
TYPE_CHECKING,
|
|
4
|
+
Any,
|
|
5
|
+
Literal,
|
|
6
|
+
Protocol,
|
|
7
|
+
Sequence,
|
|
8
|
+
TypedDict,
|
|
9
|
+
TypeVar,
|
|
10
|
+
Union,
|
|
11
|
+
)
|
|
3
12
|
|
|
4
13
|
import numpy.typing as npt
|
|
5
14
|
from typing_extensions import NotRequired
|
|
@@ -326,4 +335,20 @@ ModelLoadOption = Union[
|
|
|
326
335
|
SupportedTargetPlatformType = Union[TargetPlatform, str]
|
|
327
336
|
|
|
328
337
|
|
|
338
|
+
class ProgressStatus(Protocol):
|
|
339
|
+
"""Protocol for tracking progress during long-running operations."""
|
|
340
|
+
|
|
341
|
+
def update(self, message: str, *, state: str = "running", expanded: bool = True, **kwargs: Any) -> None:
|
|
342
|
+
"""Update the progress status with a new message."""
|
|
343
|
+
...
|
|
344
|
+
|
|
345
|
+
def increment(self) -> None:
|
|
346
|
+
"""Increment the progress by one step."""
|
|
347
|
+
...
|
|
348
|
+
|
|
349
|
+
def complete(self) -> None:
|
|
350
|
+
"""Complete the progress bar to full state."""
|
|
351
|
+
...
|
|
352
|
+
|
|
353
|
+
|
|
329
354
|
__all__ = ["TargetPlatform", "Task"]
|