snowflake-ml-python 1.9.2__py3-none-any.whl → 1.11.0__py3-none-any.whl

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. snowflake/ml/_internal/utils/service_logger.py +31 -17
  2. snowflake/ml/experiment/callback/keras.py +63 -0
  3. snowflake/ml/experiment/callback/lightgbm.py +59 -0
  4. snowflake/ml/experiment/callback/xgboost.py +67 -0
  5. snowflake/ml/experiment/utils.py +14 -0
  6. snowflake/ml/jobs/_utils/__init__.py +0 -0
  7. snowflake/ml/jobs/_utils/constants.py +4 -1
  8. snowflake/ml/jobs/_utils/payload_utils.py +55 -21
  9. snowflake/ml/jobs/_utils/query_helper.py +5 -1
  10. snowflake/ml/jobs/_utils/runtime_env_utils.py +63 -0
  11. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +2 -2
  12. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +5 -5
  13. snowflake/ml/jobs/_utils/spec_utils.py +41 -8
  14. snowflake/ml/jobs/_utils/stage_utils.py +22 -9
  15. snowflake/ml/jobs/_utils/types.py +5 -7
  16. snowflake/ml/jobs/job.py +1 -1
  17. snowflake/ml/jobs/manager.py +1 -13
  18. snowflake/ml/model/_client/model/model_version_impl.py +219 -55
  19. snowflake/ml/model/_client/ops/service_ops.py +230 -30
  20. snowflake/ml/model/_client/service/model_deployment_spec.py +103 -27
  21. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +11 -5
  22. snowflake/ml/model/_model_composer/model_composer.py +1 -70
  23. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +2 -43
  24. snowflake/ml/model/event_handler.py +87 -18
  25. snowflake/ml/model/inference_engine.py +5 -0
  26. snowflake/ml/model/models/huggingface_pipeline.py +74 -51
  27. snowflake/ml/model/type_hints.py +26 -1
  28. snowflake/ml/registry/_manager/model_manager.py +37 -70
  29. snowflake/ml/registry/_manager/model_parameter_reconciler.py +294 -0
  30. snowflake/ml/registry/registry.py +0 -19
  31. snowflake/ml/version.py +1 -1
  32. {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/METADATA +523 -491
  33. {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/RECORD +36 -29
  34. snowflake/ml/experiment/callback.py +0 -121
  35. {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/WHEEL +0 -0
  36. {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/licenses/LICENSE.txt +0 -0
  37. {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/top_level.txt +0 -0
@@ -1,17 +1,12 @@
1
1
  import pathlib
2
2
  import tempfile
3
3
  import uuid
4
- import warnings
5
4
  from types import ModuleType
6
5
  from typing import TYPE_CHECKING, Any, Optional, Union
7
6
  from urllib import parse
8
7
 
9
- from absl import logging
10
- from packaging import requirements
11
-
12
8
  from snowflake import snowpark
13
- from snowflake.ml import version as snowml_version
14
- from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
9
+ from snowflake.ml._internal import file_utils
15
10
  from snowflake.ml._internal.lineage import lineage_utils
16
11
  from snowflake.ml.data import data_source
17
12
  from snowflake.ml.model import model_signature, type_hints as model_types
@@ -19,7 +14,6 @@ from snowflake.ml.model._model_composer.model_manifest import model_manifest
19
14
  from snowflake.ml.model._packager import model_packager
20
15
  from snowflake.ml.model._packager.model_meta import model_meta
21
16
  from snowflake.snowpark import Session
22
- from snowflake.snowpark._internal import utils as snowpark_utils
23
17
 
24
18
  if TYPE_CHECKING:
25
19
  from snowflake.ml.experiment._experiment_info import ExperimentInfo
@@ -142,73 +136,10 @@ class ModelComposer:
142
136
  experiment_info: Optional["ExperimentInfo"] = None,
143
137
  options: Optional[model_types.ModelSaveOption] = None,
144
138
  ) -> model_meta.ModelMetadata:
145
- # set enable_explainability=False if the model is not runnable in WH or the target platforms include SPCS
146
- conda_dep_dict = env_utils.validate_conda_dependency_string_list(
147
- conda_dependencies if conda_dependencies else []
148
- )
149
-
150
- enable_explainability = None
151
-
152
- if options:
153
- enable_explainability = options.get("enable_explainability", None)
154
-
155
- # skip everything if user said False explicitly
156
- if enable_explainability is None or enable_explainability is True:
157
- is_warehouse_runnable = (
158
- not conda_dep_dict
159
- or all(
160
- chan == env_utils.DEFAULT_CHANNEL_NAME or chan == env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
161
- for chan in conda_dep_dict
162
- )
163
- ) and (not pip_requirements)
164
-
165
- only_spcs = (
166
- target_platforms
167
- and len(target_platforms) == 1
168
- and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms
169
- )
170
- if only_spcs or (not is_warehouse_runnable):
171
- # if only SPCS and user asked for explainability we fail
172
- if enable_explainability is True:
173
- raise ValueError(
174
- "`enable_explainability` cannot be set to True when the model is not runnable in WH "
175
- "or the target platforms include SPCS."
176
- )
177
- elif not options: # explicitly set flag to false in these cases if not specified
178
- options = model_types.BaseModelSaveOption()
179
- options["enable_explainability"] = False
180
- elif (
181
- target_platforms
182
- and len(target_platforms) > 1
183
- and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms
184
- ): # if both then only available for WH
185
- if enable_explainability is True:
186
- warnings.warn(
187
- ("Explain function will only be available for model deployed to warehouse."),
188
- category=UserWarning,
189
- stacklevel=2,
190
- )
191
139
 
192
140
  if not options:
193
141
  options = model_types.BaseModelSaveOption()
194
142
 
195
- if not snowpark_utils.is_in_stored_procedure() and target_platforms != [ # type: ignore[no-untyped-call]
196
- model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES # no information schema check for SPCS-only models
197
- ]:
198
- snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema(
199
- self.session,
200
- reqs=[requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}=={snowml_version.VERSION}")],
201
- python_version=python_version or snowml_env.PYTHON_VERSION,
202
- statement_params=self._statement_params,
203
- ).get(env_utils.SNOWPARK_ML_PKG_NAME, [])
204
-
205
- if len(snowml_matched_versions) < 1 and options.get("embed_local_ml_library", False) is False:
206
- logging.info(
207
- f"Local snowflake-ml-python library has version {snowml_version.VERSION},"
208
- " which is not available in the Snowflake server, embedding local ML library automatically."
209
- )
210
- options["embed_local_ml_library"] = True
211
-
212
143
  model_metadata: model_meta.ModelMetadata = self.packager.save(
213
144
  name=name,
214
145
  model=model,
@@ -1,13 +1,11 @@
1
1
  import collections
2
2
  import logging
3
3
  import pathlib
4
- import warnings
5
4
  from typing import TYPE_CHECKING, Optional, cast
6
5
 
7
6
  import yaml
8
7
 
9
8
  from snowflake.ml._internal import env_utils
10
- from snowflake.ml._internal.exceptions import error_codes, exceptions
11
9
  from snowflake.ml.data import data_source
12
10
  from snowflake.ml.model import type_hints
13
11
  from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
@@ -55,47 +53,8 @@ class ModelManifest:
55
53
  experiment_info: Optional["ExperimentInfo"] = None,
56
54
  target_platforms: Optional[list[type_hints.TargetPlatform]] = None,
57
55
  ) -> None:
58
- if options is None:
59
- options = {}
60
-
61
- has_pip_requirements = len(model_meta.env.pip_requirements) > 0
62
- only_spcs = (
63
- target_platforms
64
- and len(target_platforms) == 1
65
- and target_platforms[0] == type_hints.TargetPlatform.SNOWPARK_CONTAINER_SERVICES
66
- )
67
-
68
- if "relax_version" not in options:
69
- if has_pip_requirements or only_spcs:
70
- logger.info(
71
- "Setting `relax_version=False` as this model will run in Snowpark Container Services "
72
- "or in Warehouse with a specified artifact_repository_map where exact version "
73
- " specifications will be honored."
74
- )
75
- relax_version = False
76
- else:
77
- warnings.warn(
78
- (
79
- "`relax_version` is not set and therefore defaulted to True. Dependency version constraints"
80
- " relaxed from ==x.y.z to >=x.y, <(x+1). To use specific dependency versions for compatibility,"
81
- " reproducibility, etc., set `options={'relax_version': False}` when logging the model."
82
- ),
83
- category=UserWarning,
84
- stacklevel=2,
85
- )
86
- relax_version = True
87
- options["relax_version"] = relax_version
88
- else:
89
- relax_version = options.get("relax_version", True)
90
- if relax_version and (has_pip_requirements or only_spcs):
91
- raise exceptions.SnowflakeMLException(
92
- error_code=error_codes.INVALID_ARGUMENT,
93
- original_exception=ValueError(
94
- "Setting `relax_version=True` is only allowed for models to be run in Warehouse with "
95
- "Snowflake Conda Channel dependencies. It cannot be used with pip requirements or when "
96
- "targeting only Snowpark Container Services."
97
- ),
98
- )
56
+ assert options is not None, "ModelParameterReconciler should have set options with relax_version"
57
+ relax_version = options["relax_version"]
99
58
 
100
59
  runtime_to_use = model_runtime.ModelRuntime(
101
60
  name=self._DEFAULT_RUNTIME_NAME,
@@ -23,12 +23,24 @@ class _TqdmStatusContext:
23
23
  if state == "complete":
24
24
  self._progress_bar.update(self._progress_bar.total - self._progress_bar.n)
25
25
  self._progress_bar.set_description(label)
26
+ elif state == "error":
27
+ # For error state, use the label as-is and mark with ERROR prefix
28
+ # Don't update progress bar position for errors - leave it where it was
29
+ self._progress_bar.set_description(f"❌ ERROR: {label}")
26
30
  else:
27
- self._progress_bar.set_description(f"{self._label}: {label}")
31
+ combined_desc = f"{self._label}: {label}" if label != self._label else self._label
32
+ self._progress_bar.set_description(combined_desc)
28
33
 
29
- def increment(self, n: int = 1) -> None:
34
+ def increment(self) -> None:
30
35
  """Increment the progress bar."""
31
- self._progress_bar.update(n)
36
+ self._progress_bar.update(1)
37
+
38
+ def complete(self) -> None:
39
+ """Complete the progress bar to full state."""
40
+ if self._total:
41
+ remaining = self._total - self._progress_bar.n
42
+ if remaining > 0:
43
+ self._progress_bar.update(remaining)
32
44
 
33
45
 
34
46
  class _StreamlitStatusContext:
@@ -39,6 +51,7 @@ class _StreamlitStatusContext:
39
51
  self._streamlit = streamlit_module
40
52
  self._total = total
41
53
  self._current = 0
54
+ self._current_label = label
42
55
  self._progress_bar = None
43
56
 
44
57
  def __enter__(self) -> "_StreamlitStatusContext":
@@ -49,26 +62,70 @@ class _StreamlitStatusContext:
49
62
  return self
50
63
 
51
64
  def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
52
- self._status_container.update(state="complete")
65
+ # Only update to complete if there was no exception
66
+ if exc_type is None:
67
+ self._status_container.update(state="complete")
53
68
 
54
69
  def update(self, label: str, *, state: str = "running", expanded: bool = True) -> None:
55
70
  """Update the status label."""
56
- if state != "complete":
57
- label = f"{self._label}: {label}"
58
- self._status_container.update(label=label, state=state, expanded=expanded)
59
- if self._progress_bar is not None:
60
- self._progress_bar.progress(
61
- self._current / self._total if self._total > 0 else 0,
62
- text=f"{label} - {self._current}/{self._total}",
63
- )
64
-
65
- def increment(self, n: int = 1) -> None:
71
+ if state == "complete" or state == "error":
72
+ # For completion/error, use the message as-is and update main status
73
+ self._status_container.update(label=label, state=state, expanded=expanded)
74
+ self._current_label = label
75
+
76
+ # For error state, update progress bar text but preserve position
77
+ if state == "error" and self._total is not None and self._progress_bar is not None:
78
+ self._progress_bar.progress(
79
+ self._current / self._total if self._total > 0 else 0,
80
+ text=f"ERROR - ({self._current}/{self._total})",
81
+ )
82
+ else:
83
+ combined_label = f"{self._label}: {label}" if label != self._label else self._label
84
+ self._status_container.update(label=combined_label, state=state, expanded=expanded)
85
+ self._current_label = label
86
+ if self._total is not None and self._progress_bar is not None:
87
+ progress_value = self._current / self._total if self._total > 0 else 0
88
+ self._progress_bar.progress(progress_value, text=f"({self._current}/{self._total})")
89
+
90
+ def increment(self) -> None:
66
91
  """Increment the progress."""
67
92
  if self._total is not None:
68
- self._current = min(self._current + n, self._total)
93
+ self._current = min(self._current + 1, self._total)
69
94
  if self._progress_bar is not None:
70
95
  progress_value = self._current / self._total if self._total > 0 else 0
71
- self._progress_bar.progress(progress_value, text=f"{self._current}/{self._total}")
96
+ self._progress_bar.progress(progress_value, text=f"({self._current}/{self._total})")
97
+
98
+ def complete(self) -> None:
99
+ """Complete the progress bar to full state."""
100
+ if self._total is not None:
101
+ self._current = self._total
102
+ if self._progress_bar is not None:
103
+ self._progress_bar.progress(1.0, text=f"({self._current}/{self._total})")
104
+
105
+
106
+ class _NoOpStatusContext:
107
+ """A no-op context manager for when status updates should be disabled."""
108
+
109
+ def __init__(self, label: str) -> None:
110
+ self._label = label
111
+
112
+ def __enter__(self) -> "_NoOpStatusContext":
113
+ return self
114
+
115
+ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
116
+ pass
117
+
118
+ def update(self, label: str, *, state: str = "running", expanded: bool = True) -> None:
119
+ """No-op update method."""
120
+ pass
121
+
122
+ def increment(self) -> None:
123
+ """No-op increment method."""
124
+ pass
125
+
126
+ def complete(self) -> None:
127
+ """No-op complete method."""
128
+ pass
72
129
 
73
130
 
74
131
  class ModelEventHandler:
@@ -99,7 +156,15 @@ class ModelEventHandler:
99
156
  else:
100
157
  self._tqdm.tqdm.write(message)
101
158
 
102
- def status(self, label: str, *, state: str = "running", expanded: bool = True, total: Optional[int] = None) -> Any:
159
+ def status(
160
+ self,
161
+ label: str,
162
+ *,
163
+ state: str = "running",
164
+ expanded: bool = True,
165
+ total: Optional[int] = None,
166
+ block: bool = True,
167
+ ) -> Any:
103
168
  """Context manager that provides status updates with optional enhanced display capabilities.
104
169
 
105
170
  Args:
@@ -107,10 +172,14 @@ class ModelEventHandler:
107
172
  state: The initial state ("running", "complete", "error")
108
173
  expanded: Whether to show expanded view (streamlit only)
109
174
  total: Total number of steps for progress tracking (optional)
175
+ block: Whether to show progress updates (no-op if False)
110
176
 
111
177
  Returns:
112
- Status context (Streamlit or Tqdm)
178
+ Status context (Streamlit, Tqdm, or NoOp based on availability and block parameter)
113
179
  """
180
+ if not block:
181
+ return _NoOpStatusContext(label)
182
+
114
183
  if self._streamlit is not None:
115
184
  return _StreamlitStatusContext(label, self._streamlit, total)
116
185
  else:
@@ -0,0 +1,5 @@
1
+ import enum
2
+
3
+
4
+ class InferenceEngine(enum.Enum):
5
+ VLLM = "vllm"
@@ -258,7 +258,7 @@ class HuggingFacePipelineModel:
258
258
  # model_version_impl.create_service parameters
259
259
  service_name: str,
260
260
  service_compute_pool: str,
261
- image_repo: str,
261
+ image_repo: Optional[str] = None,
262
262
  image_build_compute_pool: Optional[str] = None,
263
263
  ingress_enabled: bool = False,
264
264
  max_instances: int = 1,
@@ -282,7 +282,8 @@ class HuggingFacePipelineModel:
282
282
  comment: Comment for the model. Defaults to None.
283
283
  service_name: The name of the service to create.
284
284
  service_compute_pool: The compute pool for the service.
285
- image_repo: The name of the image repository.
285
+ image_repo: The name of the image repository. This can be None, in that case a default hidden image
286
+ repository will be used.
286
287
  image_build_compute_pool: The name of the compute pool used to build the model inference image. It uses
287
288
  the service compute pool if None.
288
289
  ingress_enabled: Whether ingress is enabled. Defaults to False.
@@ -299,6 +300,7 @@ class HuggingFacePipelineModel:
299
300
  Raises:
300
301
  ValueError: if database and schema name is not provided and session doesn't have a
301
302
  database and schema name.
303
+ exceptions.SnowparkSQLException: if service already exists.
302
304
 
303
305
  Returns:
304
306
  The service ID or an async job object.
@@ -327,7 +329,6 @@ class HuggingFacePipelineModel:
327
329
  version_name = name_generator.generate()[1]
328
330
 
329
331
  service_db_id, service_schema_id, service_id = sql_identifier.parse_fully_qualified_name(service_name)
330
- image_repo_db_id, image_repo_schema_id, image_repo_id = sql_identifier.parse_fully_qualified_name(image_repo)
331
332
 
332
333
  service_operator = service_ops.ServiceOperator(
333
334
  session=session,
@@ -336,51 +337,73 @@ class HuggingFacePipelineModel:
336
337
  )
337
338
  logger.info(f"A service job is going to register the hf model as: {model_name}.{version_name}")
338
339
 
339
- return service_operator.create_service(
340
- database_name=database_name_id,
341
- schema_name=schema_name_id,
342
- model_name=model_name_id,
343
- version_name=sql_identifier.SqlIdentifier(version_name),
344
- service_database_name=service_db_id,
345
- service_schema_name=service_schema_id,
346
- service_name=service_id,
347
- image_build_compute_pool_name=(
348
- sql_identifier.SqlIdentifier(image_build_compute_pool)
349
- if image_build_compute_pool
350
- else sql_identifier.SqlIdentifier(service_compute_pool)
351
- ),
352
- service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
353
- image_repo_database_name=image_repo_db_id,
354
- image_repo_schema_name=image_repo_schema_id,
355
- image_repo_name=image_repo_id,
356
- ingress_enabled=ingress_enabled,
357
- max_instances=max_instances,
358
- cpu_requests=cpu_requests,
359
- memory_requests=memory_requests,
360
- gpu_requests=gpu_requests,
361
- num_workers=num_workers,
362
- max_batch_rows=max_batch_rows,
363
- force_rebuild=force_rebuild,
364
- build_external_access_integrations=(
365
- None
366
- if build_external_access_integrations is None
367
- else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
368
- ),
369
- block=block,
370
- statement_params=statement_params,
371
- # hf model
372
- hf_model_args=service_ops.HFModelArgs(
373
- hf_model_name=self.model,
374
- hf_task=self.task,
375
- hf_tokenizer=self.tokenizer,
376
- hf_revision=self.revision,
377
- hf_token=self.token,
378
- hf_trust_remote_code=bool(self.trust_remote_code),
379
- hf_model_kwargs=self.model_kwargs,
380
- pip_requirements=pip_requirements,
381
- conda_dependencies=conda_dependencies,
382
- comment=comment,
383
- # TODO: remove warehouse in the next release
384
- warehouse=session.get_current_warehouse(),
385
- ),
386
- )
340
+ from snowflake.ml.model import event_handler
341
+ from snowflake.snowpark import exceptions
342
+
343
+ hf_event_handler = event_handler.ModelEventHandler()
344
+ with hf_event_handler.status("Creating HuggingFace model service", total=6, block=block) as status:
345
+ try:
346
+ result = service_operator.create_service(
347
+ database_name=database_name_id,
348
+ schema_name=schema_name_id,
349
+ model_name=model_name_id,
350
+ version_name=sql_identifier.SqlIdentifier(version_name),
351
+ service_database_name=service_db_id,
352
+ service_schema_name=service_schema_id,
353
+ service_name=service_id,
354
+ image_build_compute_pool_name=(
355
+ sql_identifier.SqlIdentifier(image_build_compute_pool)
356
+ if image_build_compute_pool
357
+ else sql_identifier.SqlIdentifier(service_compute_pool)
358
+ ),
359
+ service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
360
+ image_repo_name=image_repo,
361
+ ingress_enabled=ingress_enabled,
362
+ max_instances=max_instances,
363
+ cpu_requests=cpu_requests,
364
+ memory_requests=memory_requests,
365
+ gpu_requests=gpu_requests,
366
+ num_workers=num_workers,
367
+ max_batch_rows=max_batch_rows,
368
+ force_rebuild=force_rebuild,
369
+ build_external_access_integrations=(
370
+ None
371
+ if build_external_access_integrations is None
372
+ else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
373
+ ),
374
+ block=block,
375
+ progress_status=status,
376
+ statement_params=statement_params,
377
+ # hf model
378
+ hf_model_args=service_ops.HFModelArgs(
379
+ hf_model_name=self.model,
380
+ hf_task=self.task,
381
+ hf_tokenizer=self.tokenizer,
382
+ hf_revision=self.revision,
383
+ hf_token=self.token,
384
+ hf_trust_remote_code=bool(self.trust_remote_code),
385
+ hf_model_kwargs=self.model_kwargs,
386
+ pip_requirements=pip_requirements,
387
+ conda_dependencies=conda_dependencies,
388
+ comment=comment,
389
+ # TODO: remove warehouse in the next release
390
+ warehouse=session.get_current_warehouse(),
391
+ ),
392
+ )
393
+ status.update(label="HuggingFace model service created successfully", state="complete", expanded=False)
394
+ return result
395
+ except exceptions.SnowparkSQLException as e:
396
+ # Check if the error is because the service already exists
397
+ if "already exists" in str(e).lower() or "100132" in str(
398
+ e
399
+ ): # 100132 is Snowflake error code for object already exists
400
+ # Update progress to show service already exists (preserve exception behavior)
401
+ status.update("service already exists")
402
+ status.complete() # Complete progress to full state
403
+ status.update(label="Service already exists", state="error", expanded=False)
404
+ # Re-raise the exception to preserve existing API behavior
405
+ raise
406
+ else:
407
+ # Re-raise other SQL exceptions
408
+ status.update(label="Service creation failed", state="error", expanded=False)
409
+ raise
@@ -1,5 +1,14 @@
1
1
  # mypy: disable-error-code="import"
2
- from typing import TYPE_CHECKING, Literal, Sequence, TypedDict, TypeVar, Union
2
+ from typing import (
3
+ TYPE_CHECKING,
4
+ Any,
5
+ Literal,
6
+ Protocol,
7
+ Sequence,
8
+ TypedDict,
9
+ TypeVar,
10
+ Union,
11
+ )
3
12
 
4
13
  import numpy.typing as npt
5
14
  from typing_extensions import NotRequired
@@ -326,4 +335,20 @@ ModelLoadOption = Union[
326
335
  SupportedTargetPlatformType = Union[TargetPlatform, str]
327
336
 
328
337
 
338
+ class ProgressStatus(Protocol):
339
+ """Protocol for tracking progress during long-running operations."""
340
+
341
+ def update(self, message: str, *, state: str = "running", expanded: bool = True, **kwargs: Any) -> None:
342
+ """Update the progress status with a new message."""
343
+ ...
344
+
345
+ def increment(self) -> None:
346
+ """Increment the progress by one step."""
347
+ ...
348
+
349
+ def complete(self) -> None:
350
+ """Complete the progress bar to full state."""
351
+ ...
352
+
353
+
329
354
  __all__ = ["TargetPlatform", "Task"]