snowflake-ml-python 1.9.2__py3-none-any.whl → 1.10.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the changes between those package versions.
@@ -180,9 +180,7 @@ class ServiceOperator:
  service_name: sql_identifier.SqlIdentifier,
  image_build_compute_pool_name: sql_identifier.SqlIdentifier,
  service_compute_pool_name: sql_identifier.SqlIdentifier,
- image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
- image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
- image_repo_name: sql_identifier.SqlIdentifier,
+ image_repo: str,
  ingress_enabled: bool,
  max_instances: int,
  cpu_requests: Optional[str],
@@ -193,6 +191,7 @@ class ServiceOperator:
  force_rebuild: bool,
  build_external_access_integrations: Optional[list[sql_identifier.SqlIdentifier]],
  block: bool,
+ progress_status: type_hints.ProgressStatus,
  statement_params: Optional[dict[str, Any]] = None,
  # hf model
  hf_model_args: Optional[HFModelArgs] = None,
@@ -209,8 +208,17 @@ class ServiceOperator:
  service_database_name = service_database_name or database_name or self._database_name
  service_schema_name = service_schema_name or schema_name or self._schema_name

+ # Parse image repo
+ image_repo_database_name, image_repo_schema_name, image_repo_name = sql_identifier.parse_fully_qualified_name(
+ image_repo
+ )
  image_repo_database_name = image_repo_database_name or database_name or self._database_name
  image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name
+
+ # Step 1: Preparing deployment artifacts
+ progress_status.update("preparing deployment artifacts...")
+ progress_status.increment()
+
  if self._workspace:
  stage_path = self._create_temp_stage(database_name, schema_name, statement_params)
  else:
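Note: the three image-repo identifier parameters collapse into a single image_repo string that create_service now splits with sql_identifier.parse_fully_qualified_name and backfills from the session defaults. A minimal standalone sketch of that resolution, assuming a dot-separated "DB.SCHEMA.REPO" form and ignoring quoted identifiers that may themselves contain dots (resolve_image_repo is illustrative, not the library helper):

    # Hedged sketch only - not the library implementation.
    from typing import Optional, Tuple

    def resolve_image_repo(
        image_repo: str,
        default_db: Optional[str],
        default_schema: Optional[str],
    ) -> Tuple[Optional[str], Optional[str], str]:
        parts = image_repo.split(".")  # naive split; quoted identifiers are not handled
        if len(parts) == 3:
            db, schema, name = parts
        elif len(parts) == 2:
            db, schema, name = None, parts[0], parts[1]
        else:
            db, schema, name = None, None, parts[0]
        return db or default_db, schema or default_schema, name

    # resolve_image_repo("MY_REPO", "MY_DB", "PUBLIC") -> ("MY_DB", "PUBLIC", "MY_REPO")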
@@ -259,6 +267,11 @@ class ServiceOperator:
  **(hf_model_args.hf_model_kwargs if hf_model_args.hf_model_kwargs else {}),
  )
  spec_yaml_str_or_path = self._model_deployment_spec.save()
+
+ # Step 2: Uploading deployment artifacts
+ progress_status.update("uploading deployment artifacts...")
+ progress_status.increment()
+
  if self._workspace:
  assert stage_path is not None
  file_utils.upload_directory_to_stage(
@@ -281,6 +294,10 @@ class ServiceOperator:
  statement_params=statement_params,
  )

+ # Step 3: Initiating model deployment
+ progress_status.update("initiating model deployment...")
+ progress_status.increment()
+
  # deploy the model service
  query_id, async_job = self._service_client.deploy_model(
  stage_path=stage_path if self._workspace else None,
@@ -337,13 +354,63 @@ class ServiceOperator:
  )

  if block:
- log_thread.join()
+ try:
+ # Step 4: Starting model build: waits for build to start
+ progress_status.update("starting model image build...")
+ progress_status.increment()
+
+ # Poll for model build to start if not using existing service
+ if not model_inference_service_exists:
+ self._wait_for_service_status(
+ model_build_service_name,
+ service_sql.ServiceStatus.RUNNING,
+ service_database_name,
+ service_schema_name,
+ async_job,
+ statement_params,
+ )

- res = cast(str, cast(list[row.Row], async_job.result())[0][0])
- module_logger.info(f"Inference service {service_name} deployment complete: {res}")
- return res
- else:
- return async_job
+ # Step 5: Building model image
+ progress_status.update("building model image...")
+ progress_status.increment()
+
+ # Poll for model build completion
+ if not model_inference_service_exists:
+ self._wait_for_service_status(
+ model_build_service_name,
+ service_sql.ServiceStatus.DONE,
+ service_database_name,
+ service_schema_name,
+ async_job,
+ statement_params,
+ )
+
+ # Step 6: Deploying model service (push complete, starting inference service)
+ progress_status.update("deploying model service...")
+ progress_status.increment()
+
+ log_thread.join()
+
+ res = cast(str, cast(list[row.Row], async_job.result())[0][0])
+ module_logger.info(f"Inference service {service_name} deployment complete: {res}")
+ return res
+
+ except RuntimeError as e:
+ # Handle service creation/deployment failures
+ error_msg = f"Model service deployment failed: {str(e)}"
+ module_logger.error(error_msg)
+
+ # Update progress status to show failure
+ progress_status.update(error_msg, state="error")
+
+ # Stop the log thread if it's running
+ if "log_thread" in locals() and log_thread.is_alive():
+ log_thread.join(timeout=5) # Give it a few seconds to finish gracefully
+
+ # Re-raise the exception to propagate the error
+ raise RuntimeError(error_msg) from e
+
+ return async_job

  def _start_service_log_streaming(
  self,
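With this change the return type of create_service still depends on block: the final result string when block=True, or the still-running snowpark.AsyncJob when block=False. A hedged caller-side sketch of handling either return (handle_deployment is illustrative and not part of the package; it assumes snowflake-snowpark-python is installed):

    from typing import Union

    from snowflake import snowpark

    def handle_deployment(result: Union[str, snowpark.AsyncJob]) -> str:
        if isinstance(result, snowpark.AsyncJob):
            # Non-blocking path: wait explicitly when the result is needed.
            rows = result.result()
            return str(rows[0][0])
        # Blocking path: the operator already waited and returned the result string.
        return result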
@@ -579,6 +646,7 @@ class ServiceOperator:
  is_snowpark_sql_exception = isinstance(ex, exceptions.SnowparkSQLException)
  contains_msg = any(msg in str(ex) for msg in ["Pending scheduling", "Waiting to start"])
  matches_pattern = service_log_meta.service_status is None and re.search(pattern, str(ex)) is not None
+
  if not (is_snowpark_sql_exception and (contains_msg or matches_pattern)):
  module_logger.warning(f"Caught an exception when logging: {repr(ex)}")
  time.sleep(5)
@@ -618,6 +686,101 @@ class ServiceOperator:
  except Exception as ex:
  module_logger.warning(f"Caught an exception when logging: {repr(ex)}")

+ def _wait_for_service_status(
+ self,
+ service_name: sql_identifier.SqlIdentifier,
+ target_status: service_sql.ServiceStatus,
+ service_database_name: Optional[sql_identifier.SqlIdentifier],
+ service_schema_name: Optional[sql_identifier.SqlIdentifier],
+ async_job: snowpark.AsyncJob,
+ statement_params: Optional[dict[str, Any]] = None,
+ timeout_minutes: int = 30,
+ ) -> None:
+ """Wait for service to reach the specified status while monitoring async job for failures.
+
+ Args:
+ service_name: The service to monitor
+ target_status: The target status to wait for
+ service_database_name: Database containing the service
+ service_schema_name: Schema containing the service
+ async_job: The async job to monitor for completion/failure
+ statement_params: SQL statement parameters
+ timeout_minutes: Maximum time to wait before timing out
+
+ Raises:
+ RuntimeError: If service fails, times out, or enters an error state
+ """
+ start_time = time.time()
+ timeout_seconds = timeout_minutes * 60
+ service_seen_before = False
+
+ while True:
+ # Check if async job has failed (but don't return on success - we need specific service status)
+ if async_job.is_done():
+ try:
+ async_job.result()
+ # Async job completed successfully, but we're waiting for a specific service status
+ # This might mean the service completed and was cleaned up
+ module_logger.debug(
+ f"Async job completed but we're still waiting for {service_name} to reach {target_status.value}"
+ )
+ except Exception as e:
+ raise RuntimeError(f"Service deployment failed: {e}")
+
+ try:
+ statuses = self._service_client.get_service_container_statuses(
+ database_name=service_database_name,
+ schema_name=service_schema_name,
+ service_name=service_name,
+ include_message=True,
+ statement_params=statement_params,
+ )
+
+ if statuses:
+ service_seen_before = True
+ current_status = statuses[0].service_status
+
+ # Check if we've reached the target status
+ if current_status == target_status:
+ return
+
+ # Check for failure states
+ if current_status in [service_sql.ServiceStatus.FAILED, service_sql.ServiceStatus.INTERNAL_ERROR]:
+ error_msg = f"Service {service_name} failed with status {current_status.value}"
+ if statuses[0].message:
+ error_msg += f": {statuses[0].message}"
+ raise RuntimeError(error_msg)
+
+ except exceptions.SnowparkSQLException as e:
+ # Service might not exist yet - this is expected during initial deployment
+ if "does not exist" in str(e) or "002003" in str(e):
+ # If we're waiting for DONE status and we've seen the service before,
+ # it likely completed and was cleaned up
+ if target_status == service_sql.ServiceStatus.DONE and service_seen_before:
+ module_logger.debug(
+ f"Service {service_name} disappeared after being seen, "
+ f"assuming it reached {target_status.value} and was cleaned up"
+ )
+ return
+
+ module_logger.debug(f"Service {service_name} not found yet, continuing to wait...")
+ else:
+ # Re-raise unexpected SQL exceptions
+ raise RuntimeError(f"Error checking service status: {e}")
+ except Exception as e:
+ # Re-raise unexpected exceptions instead of masking them
+ raise RuntimeError(f"Unexpected error while waiting for service status: {e}")
+
+ # Check timeout
+ elapsed_time = time.time() - start_time
+ if elapsed_time > timeout_seconds:
+ raise RuntimeError(
+ f"Timeout waiting for service {service_name} to reach status {target_status.value} "
+ f"after {timeout_minutes} minutes"
+ )
+
+ time.sleep(2) # Poll every 2 seconds
+
  @staticmethod
  def _get_service_id_from_deployment_step(query_id: str, deployment_step: DeploymentStep) -> str:
  """Get the service ID through the server-side logic."""
@@ -675,9 +838,7 @@ class ServiceOperator:
  job_name: sql_identifier.SqlIdentifier,
  compute_pool_name: sql_identifier.SqlIdentifier,
  warehouse_name: sql_identifier.SqlIdentifier,
- image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
- image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
- image_repo_name: sql_identifier.SqlIdentifier,
+ image_repo: str,
  output_table_database_name: Optional[sql_identifier.SqlIdentifier],
  output_table_schema_name: Optional[sql_identifier.SqlIdentifier],
  output_table_name: sql_identifier.SqlIdentifier,
@@ -698,6 +859,10 @@ class ServiceOperator:
  job_database_name = job_database_name or database_name or self._database_name
  job_schema_name = job_schema_name or schema_name or self._schema_name

+ # Parse image repo
+ image_repo_database_name, image_repo_schema_name, image_repo_name = sql_identifier.parse_fully_qualified_name(
+ image_repo
+ )
  image_repo_database_name = image_repo_database_name or database_name or self._database_name
  image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name

@@ -23,12 +23,24 @@ class _TqdmStatusContext:
  if state == "complete":
  self._progress_bar.update(self._progress_bar.total - self._progress_bar.n)
  self._progress_bar.set_description(label)
+ elif state == "error":
+ # For error state, use the label as-is and mark with ERROR prefix
+ # Don't update progress bar position for errors - leave it where it was
+ self._progress_bar.set_description(f"❌ ERROR: {label}")
  else:
- self._progress_bar.set_description(f"{self._label}: {label}")
+ combined_desc = f"{self._label}: {label}" if label != self._label else self._label
+ self._progress_bar.set_description(combined_desc)

- def increment(self, n: int = 1) -> None:
+ def increment(self) -> None:
  """Increment the progress bar."""
- self._progress_bar.update(n)
+ self._progress_bar.update(1)
+
+ def complete(self) -> None:
+ """Complete the progress bar to full state."""
+ if self._total:
+ remaining = self._total - self._progress_bar.n
+ if remaining > 0:
+ self._progress_bar.update(remaining)


  class _StreamlitStatusContext:
@@ -39,6 +51,7 @@ class _StreamlitStatusContext:
  self._streamlit = streamlit_module
  self._total = total
  self._current = 0
+ self._current_label = label
  self._progress_bar = None

  def __enter__(self) -> "_StreamlitStatusContext":
@@ -49,26 +62,70 @@ class _StreamlitStatusContext:
  return self

  def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
- self._status_container.update(state="complete")
+ # Only update to complete if there was no exception
+ if exc_type is None:
+ self._status_container.update(state="complete")

  def update(self, label: str, *, state: str = "running", expanded: bool = True) -> None:
  """Update the status label."""
- if state != "complete":
- label = f"{self._label}: {label}"
- self._status_container.update(label=label, state=state, expanded=expanded)
- if self._progress_bar is not None:
- self._progress_bar.progress(
- self._current / self._total if self._total > 0 else 0,
- text=f"{label} - {self._current}/{self._total}",
- )
-
- def increment(self, n: int = 1) -> None:
+ if state == "complete" or state == "error":
+ # For completion/error, use the message as-is and update main status
+ self._status_container.update(label=label, state=state, expanded=expanded)
+ self._current_label = label
+
+ # For error state, update progress bar text but preserve position
+ if state == "error" and self._total is not None and self._progress_bar is not None:
+ self._progress_bar.progress(
+ self._current / self._total if self._total > 0 else 0,
+ text=f"ERROR - ({self._current}/{self._total})",
+ )
+ else:
+ combined_label = f"{self._label}: {label}" if label != self._label else self._label
+ self._status_container.update(label=combined_label, state=state, expanded=expanded)
+ self._current_label = label
+ if self._total is not None and self._progress_bar is not None:
+ progress_value = self._current / self._total if self._total > 0 else 0
+ self._progress_bar.progress(progress_value, text=f"({self._current}/{self._total})")
+
+ def increment(self) -> None:
  """Increment the progress."""
  if self._total is not None:
- self._current = min(self._current + n, self._total)
+ self._current = min(self._current + 1, self._total)
  if self._progress_bar is not None:
  progress_value = self._current / self._total if self._total > 0 else 0
- self._progress_bar.progress(progress_value, text=f"{self._current}/{self._total}")
+ self._progress_bar.progress(progress_value, text=f"({self._current}/{self._total})")
+
+ def complete(self) -> None:
+ """Complete the progress bar to full state."""
+ if self._total is not None:
+ self._current = self._total
+ if self._progress_bar is not None:
+ self._progress_bar.progress(1.0, text=f"({self._current}/{self._total})")
+
+
+ class _NoOpStatusContext:
+ """A no-op context manager for when status updates should be disabled."""
+
+ def __init__(self, label: str) -> None:
+ self._label = label
+
+ def __enter__(self) -> "_NoOpStatusContext":
+ return self
+
+ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+ pass
+
+ def update(self, label: str, *, state: str = "running", expanded: bool = True) -> None:
+ """No-op update method."""
+ pass
+
+ def increment(self) -> None:
+ """No-op increment method."""
+ pass
+
+ def complete(self) -> None:
+ """No-op complete method."""
+ pass


  class ModelEventHandler:
@@ -99,7 +156,15 @@ class ModelEventHandler:
  else:
  self._tqdm.tqdm.write(message)

- def status(self, label: str, *, state: str = "running", expanded: bool = True, total: Optional[int] = None) -> Any:
+ def status(
+ self,
+ label: str,
+ *,
+ state: str = "running",
+ expanded: bool = True,
+ total: Optional[int] = None,
+ block: bool = True,
+ ) -> Any:
  """Context manager that provides status updates with optional enhanced display capabilities.

  Args:
@@ -107,10 +172,14 @@ class ModelEventHandler:
  state: The initial state ("running", "complete", "error")
  expanded: Whether to show expanded view (streamlit only)
  total: Total number of steps for progress tracking (optional)
+ block: Whether to show progress updates (no-op if False)

  Returns:
- Status context (Streamlit or Tqdm)
+ Status context (Streamlit, Tqdm, or NoOp based on availability and block parameter)
  """
+ if not block:
+ return _NoOpStatusContext(label)
+
  if self._streamlit is not None:
  return _StreamlitStatusContext(label, self._streamlit, total)
  else:
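The new block parameter is what lets a non-blocking deployment skip progress rendering entirely. A short usage sketch of the updated entry point, mirroring how the HuggingFace path below drives it (the label and step count here are illustrative):

    # Hedged usage sketch of ModelEventHandler.status with the new block flag.
    from snowflake.ml.model import event_handler

    handler = event_handler.ModelEventHandler()
    # block=False would hand back the _NoOpStatusContext and render nothing.
    with handler.status("Deploying model", total=3, block=True) as status:
        status.update("step one...")
        status.increment()
        status.update("step two...")
        status.increment()
        status.complete()  # fill the remaining progress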
@@ -299,6 +299,7 @@ class HuggingFacePipelineModel:
  Raises:
  ValueError: if database and schema name is not provided and session doesn't have a
  database and schema name.
+ exceptions.SnowparkSQLException: if service already exists.

  Returns:
  The service ID or an async job object.
@@ -327,7 +328,6 @@ class HuggingFacePipelineModel:
  version_name = name_generator.generate()[1]

  service_db_id, service_schema_id, service_id = sql_identifier.parse_fully_qualified_name(service_name)
- image_repo_db_id, image_repo_schema_id, image_repo_id = sql_identifier.parse_fully_qualified_name(image_repo)

  service_operator = service_ops.ServiceOperator(
  session=session,
@@ -336,51 +336,73 @@ class HuggingFacePipelineModel:
  )
  logger.info(f"A service job is going to register the hf model as: {model_name}.{version_name}")

- return service_operator.create_service(
- database_name=database_name_id,
- schema_name=schema_name_id,
- model_name=model_name_id,
- version_name=sql_identifier.SqlIdentifier(version_name),
- service_database_name=service_db_id,
- service_schema_name=service_schema_id,
- service_name=service_id,
- image_build_compute_pool_name=(
- sql_identifier.SqlIdentifier(image_build_compute_pool)
- if image_build_compute_pool
- else sql_identifier.SqlIdentifier(service_compute_pool)
- ),
- service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
- image_repo_database_name=image_repo_db_id,
- image_repo_schema_name=image_repo_schema_id,
- image_repo_name=image_repo_id,
- ingress_enabled=ingress_enabled,
- max_instances=max_instances,
- cpu_requests=cpu_requests,
- memory_requests=memory_requests,
- gpu_requests=gpu_requests,
- num_workers=num_workers,
- max_batch_rows=max_batch_rows,
- force_rebuild=force_rebuild,
- build_external_access_integrations=(
- None
- if build_external_access_integrations is None
- else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
- ),
- block=block,
- statement_params=statement_params,
- # hf model
- hf_model_args=service_ops.HFModelArgs(
- hf_model_name=self.model,
- hf_task=self.task,
- hf_tokenizer=self.tokenizer,
- hf_revision=self.revision,
- hf_token=self.token,
- hf_trust_remote_code=bool(self.trust_remote_code),
- hf_model_kwargs=self.model_kwargs,
- pip_requirements=pip_requirements,
- conda_dependencies=conda_dependencies,
- comment=comment,
- # TODO: remove warehouse in the next release
- warehouse=session.get_current_warehouse(),
- ),
- )
+ from snowflake.ml.model import event_handler
+ from snowflake.snowpark import exceptions
+
+ hf_event_handler = event_handler.ModelEventHandler()
+ with hf_event_handler.status("Creating HuggingFace model service", total=6, block=block) as status:
+ try:
+ result = service_operator.create_service(
+ database_name=database_name_id,
+ schema_name=schema_name_id,
+ model_name=model_name_id,
+ version_name=sql_identifier.SqlIdentifier(version_name),
+ service_database_name=service_db_id,
+ service_schema_name=service_schema_id,
+ service_name=service_id,
+ image_build_compute_pool_name=(
+ sql_identifier.SqlIdentifier(image_build_compute_pool)
+ if image_build_compute_pool
+ else sql_identifier.SqlIdentifier(service_compute_pool)
+ ),
+ service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
+ image_repo=image_repo,
+ ingress_enabled=ingress_enabled,
+ max_instances=max_instances,
+ cpu_requests=cpu_requests,
+ memory_requests=memory_requests,
+ gpu_requests=gpu_requests,
+ num_workers=num_workers,
+ max_batch_rows=max_batch_rows,
+ force_rebuild=force_rebuild,
+ build_external_access_integrations=(
+ None
+ if build_external_access_integrations is None
+ else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
+ ),
+ block=block,
+ progress_status=status,
+ statement_params=statement_params,
+ # hf model
+ hf_model_args=service_ops.HFModelArgs(
+ hf_model_name=self.model,
+ hf_task=self.task,
+ hf_tokenizer=self.tokenizer,
+ hf_revision=self.revision,
+ hf_token=self.token,
+ hf_trust_remote_code=bool(self.trust_remote_code),
+ hf_model_kwargs=self.model_kwargs,
+ pip_requirements=pip_requirements,
+ conda_dependencies=conda_dependencies,
+ comment=comment,
+ # TODO: remove warehouse in the next release
+ warehouse=session.get_current_warehouse(),
+ ),
+ )
+ status.update(label="HuggingFace model service created successfully", state="complete", expanded=False)
+ return result
+ except exceptions.SnowparkSQLException as e:
+ # Check if the error is because the service already exists
+ if "already exists" in str(e).lower() or "100132" in str(
+ e
+ ): # 100132 is Snowflake error code for object already exists
+ # Update progress to show service already exists (preserve exception behavior)
+ status.update("service already exists")
+ status.complete() # Complete progress to full state
+ status.update(label="Service already exists", state="error", expanded=False)
+ # Re-raise the exception to preserve existing API behavior
+ raise
+ else:
+ # Re-raise other SQL exceptions
+ status.update(label="Service creation failed", state="error", expanded=False)
+ raise
@@ -1,5 +1,14 @@
  # mypy: disable-error-code="import"
- from typing import TYPE_CHECKING, Literal, Sequence, TypedDict, TypeVar, Union
+ from typing import (
+ TYPE_CHECKING,
+ Any,
+ Literal,
+ Protocol,
+ Sequence,
+ TypedDict,
+ TypeVar,
+ Union,
+ )

  import numpy.typing as npt
  from typing_extensions import NotRequired
@@ -326,4 +335,20 @@ ModelLoadOption = Union[
  SupportedTargetPlatformType = Union[TargetPlatform, str]


+ class ProgressStatus(Protocol):
+ """Protocol for tracking progress during long-running operations."""
+
+ def update(self, message: str, *, state: str = "running", expanded: bool = True, **kwargs: Any) -> None:
+ """Update the progress status with a new message."""
+ ...
+
+ def increment(self) -> None:
+ """Increment the progress by one step."""
+ ...
+
+ def complete(self) -> None:
+ """Complete the progress bar to full state."""
+ ...
+
+
  __all__ = ["TargetPlatform", "Task"]
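Because ProgressStatus is a typing Protocol, any object with matching update/increment/complete methods satisfies it structurally; the tqdm, Streamlit, and no-op contexts above all do. A minimal logging-based implementation, written as an illustrative sketch rather than anything shipped in the package:

    # Hedged sketch: LoggingProgressStatus is illustrative, not a library class.
    import logging
    from typing import Any

    logger = logging.getLogger("deployment_progress")

    class LoggingProgressStatus:
        """Satisfies the ProgressStatus protocol by logging each step."""

        def __init__(self, total: int) -> None:
            self._total = total
            self._step = 0

        def update(self, message: str, *, state: str = "running", expanded: bool = True, **kwargs: Any) -> None:
            logger.info("[%s] %s", state, message)

        def increment(self) -> None:
            self._step = min(self._step + 1, self._total)
            logger.info("progress %d/%d", self._step, self._total)

        def complete(self) -> None:
            self._step = self._total
            logger.info("progress %d/%d (complete)", self._step, self._total)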