snowflake-ml-python 1.9.0__py3-none-any.whl → 1.9.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. snowflake/ml/_internal/env_utils.py +44 -3
  2. snowflake/ml/_internal/platform_capabilities.py +52 -2
  3. snowflake/ml/_internal/type_utils.py +1 -1
  4. snowflake/ml/_internal/utils/mixins.py +54 -42
  5. snowflake/ml/_internal/utils/service_logger.py +105 -3
  6. snowflake/ml/data/_internal/arrow_ingestor.py +15 -2
  7. snowflake/ml/data/data_connector.py +13 -2
  8. snowflake/ml/data/data_ingestor.py +8 -0
  9. snowflake/ml/data/torch_utils.py +1 -1
  10. snowflake/ml/dataset/dataset.py +2 -1
  11. snowflake/ml/dataset/dataset_reader.py +14 -4
  12. snowflake/ml/experiment/__init__.py +3 -0
  13. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +98 -0
  14. snowflake/ml/experiment/_entities/__init__.py +4 -0
  15. snowflake/ml/experiment/_entities/experiment.py +10 -0
  16. snowflake/ml/experiment/_entities/run.py +62 -0
  17. snowflake/ml/experiment/_entities/run_metadata.py +68 -0
  18. snowflake/ml/experiment/_experiment_info.py +63 -0
  19. snowflake/ml/experiment/callback.py +121 -0
  20. snowflake/ml/experiment/experiment_tracking.py +319 -0
  21. snowflake/ml/jobs/_utils/constants.py +15 -4
  22. snowflake/ml/jobs/_utils/payload_utils.py +156 -54
  23. snowflake/ml/jobs/_utils/query_helper.py +16 -5
  24. snowflake/ml/jobs/_utils/scripts/constants.py +0 -22
  25. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +130 -23
  26. snowflake/ml/jobs/_utils/spec_utils.py +23 -8
  27. snowflake/ml/jobs/_utils/stage_utils.py +30 -14
  28. snowflake/ml/jobs/_utils/types.py +64 -4
  29. snowflake/ml/jobs/job.py +70 -75
  30. snowflake/ml/jobs/manager.py +59 -31
  31. snowflake/ml/lineage/lineage_node.py +2 -2
  32. snowflake/ml/model/_client/model/model_version_impl.py +16 -4
  33. snowflake/ml/model/_client/ops/service_ops.py +336 -137
  34. snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
  35. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -1
  36. snowflake/ml/model/_client/sql/service.py +1 -38
  37. snowflake/ml/model/_model_composer/model_composer.py +6 -1
  38. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +17 -3
  39. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  40. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +41 -2
  41. snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -5
  42. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -1
  43. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -3
  44. snowflake/ml/model/_signatures/pandas_handler.py +3 -0
  45. snowflake/ml/model/_signatures/utils.py +4 -0
  46. snowflake/ml/model/event_handler.py +117 -0
  47. snowflake/ml/model/model_signature.py +11 -9
  48. snowflake/ml/model/models/huggingface_pipeline.py +170 -1
  49. snowflake/ml/modeling/framework/base.py +1 -1
  50. snowflake/ml/modeling/metrics/classification.py +14 -14
  51. snowflake/ml/modeling/metrics/correlation.py +19 -8
  52. snowflake/ml/modeling/metrics/ranking.py +6 -6
  53. snowflake/ml/modeling/metrics/regression.py +9 -9
  54. snowflake/ml/monitoring/explain_visualize.py +12 -5
  55. snowflake/ml/registry/_manager/model_manager.py +32 -15
  56. snowflake/ml/registry/registry.py +48 -80
  57. snowflake/ml/version.py +1 -1
  58. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/METADATA +107 -5
  59. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/RECORD +62 -52
  60. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/WHEEL +0 -0
  61. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/licenses/LICENSE.txt +0 -0
  62. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,5 @@
1
1
  import dataclasses
2
+ import enum
2
3
  import hashlib
3
4
  import logging
4
5
  import pathlib
@@ -22,20 +23,63 @@ module_logger = service_logger.get_logger(__name__, service_logger.LogColor.GREY
22
23
  module_logger.propagate = False
23
24
 
24
25
 
26
+ class DeploymentStep(enum.Enum):
27
+ MODEL_BUILD = ("model-build", "model_build_")
28
+ MODEL_INFERENCE = ("model-inference", None)
29
+ MODEL_LOGGING = ("model-logging", "model_logging_")
30
+
31
+ def __init__(self, container_name: str, service_name_prefix: Optional[str]) -> None:
32
+ self._container_name = container_name
33
+ self._service_name_prefix = service_name_prefix
34
+
35
+ @property
36
+ def container_name(self) -> str:
37
+ """Get the container name for the deployment step."""
38
+ return self._container_name
39
+
40
+ @property
41
+ def service_name_prefix(self) -> Optional[str]:
42
+ """Get the service name prefix for the deployment step."""
43
+ return self._service_name_prefix
44
+
45
+
25
46
  @dataclasses.dataclass
26
47
  class ServiceLogInfo:
27
48
  database_name: Optional[sql_identifier.SqlIdentifier]
28
49
  schema_name: Optional[sql_identifier.SqlIdentifier]
29
50
  service_name: sql_identifier.SqlIdentifier
30
- container_name: str
51
+ deployment_step: DeploymentStep
31
52
  instance_id: str = "0"
53
+ log_color: service_logger.LogColor = service_logger.LogColor.GREY
32
54
 
33
55
  def __post_init__(self) -> None:
34
56
  # service name used in logs for display
35
57
  self.display_service_name = sql_identifier.get_fully_qualified_name(
36
- self.database_name, self.schema_name, self.service_name
58
+ self.database_name,
59
+ self.schema_name,
60
+ self.service_name,
37
61
  )
38
62
 
63
+ def fetch_logs(
64
+ self,
65
+ service_client: service_sql.ServiceSQLClient,
66
+ offset: int,
67
+ statement_params: Optional[dict[str, Any]],
68
+ ) -> tuple[str, int]:
69
+ service_logs = service_client.get_service_logs(
70
+ database_name=self.database_name,
71
+ schema_name=self.schema_name,
72
+ service_name=self.service_name,
73
+ container_name=self.deployment_step.container_name,
74
+ statement_params=statement_params,
75
+ )
76
+
77
+ # return only new logs starting after the offset
78
+ new_logs = service_logs[offset:]
79
+ new_offset = max(offset, len(service_logs))
80
+
81
+ return new_logs, new_offset
82
+
39
83
 
40
84
  @dataclasses.dataclass
41
85
  class ServiceLogMetadata:
@@ -43,8 +87,49 @@ class ServiceLogMetadata:
43
87
  service: ServiceLogInfo
44
88
  service_status: Optional[service_sql.ServiceStatus]
45
89
  is_model_build_service_done: bool
90
+ is_model_logger_service_done: bool
46
91
  log_offset: int
47
92
 
93
+ def transition_service_log_metadata(
94
+ self,
95
+ to_service: ServiceLogInfo,
96
+ msg: str,
97
+ is_model_build_service_done: bool,
98
+ is_model_logger_service_done: bool,
99
+ operation_id: str,
100
+ propagate: bool = False,
101
+ ) -> None:
102
+ to_service_logger = service_logger.get_logger(
103
+ f"{to_service.display_service_name}-{to_service.instance_id}",
104
+ to_service.log_color,
105
+ operation_id=operation_id,
106
+ )
107
+ to_service_logger.propagate = propagate
108
+ self.service_logger = to_service_logger
109
+ self.service = to_service
110
+ self.service_status = None
111
+ self.is_model_build_service_done = is_model_build_service_done
112
+ self.is_model_logger_service_done = is_model_logger_service_done
113
+ self.log_offset = 0
114
+ block_size = 180
115
+ module_logger.info(msg)
116
+ module_logger.info("-" * block_size)
117
+
118
+
119
+ @dataclasses.dataclass
120
+ class HFModelArgs:
121
+ hf_model_name: str
122
+ hf_task: Optional[str] = None
123
+ hf_tokenizer: Optional[str] = None
124
+ hf_revision: Optional[str] = None
125
+ hf_token: Optional[str] = None
126
+ hf_trust_remote_code: bool = False
127
+ hf_model_kwargs: Optional[dict[str, Any]] = None
128
+ pip_requirements: Optional[list[str]] = None
129
+ conda_dependencies: Optional[list[str]] = None
130
+ comment: Optional[str] = None
131
+ warehouse: Optional[str] = None
132
+
48
133
 
49
134
  class ServiceOperator:
50
135
  """Service operator for container services logic."""
@@ -109,8 +194,13 @@ class ServiceOperator:
109
194
  build_external_access_integrations: Optional[list[sql_identifier.SqlIdentifier]],
110
195
  block: bool,
111
196
  statement_params: Optional[dict[str, Any]] = None,
197
+ # hf model
198
+ hf_model_args: Optional[HFModelArgs] = None,
112
199
  ) -> Union[str, async_job.AsyncJob]:
113
200
 
201
+ # Generate operation ID for this deployment
202
+ operation_id = service_logger.get_operation_id()
203
+
114
204
  # Fall back to the registry's database and schema if not provided
115
205
  database_name = database_name or self._database_name
116
206
  schema_name = schema_name or self._schema_name
@@ -153,6 +243,21 @@ class ServiceOperator:
153
243
  num_workers=num_workers,
154
244
  max_batch_rows=max_batch_rows,
155
245
  )
246
+ if hf_model_args:
247
+ # hf model
248
+ self._model_deployment_spec.add_hf_logger_spec(
249
+ hf_model_name=hf_model_args.hf_model_name,
250
+ hf_task=hf_model_args.hf_task,
251
+ hf_token=hf_model_args.hf_token,
252
+ hf_tokenizer=hf_model_args.hf_tokenizer,
253
+ hf_revision=hf_model_args.hf_revision,
254
+ hf_trust_remote_code=hf_model_args.hf_trust_remote_code,
255
+ pip_requirements=hf_model_args.pip_requirements,
256
+ conda_dependencies=hf_model_args.conda_dependencies,
257
+ comment=hf_model_args.comment,
258
+ warehouse=hf_model_args.warehouse,
259
+ **(hf_model_args.hf_model_kwargs if hf_model_args.hf_model_kwargs else {}),
260
+ )
156
261
  spec_yaml_str_or_path = self._model_deployment_spec.save()
157
262
  if self._workspace:
158
263
  assert stage_path is not None
@@ -187,22 +292,48 @@ class ServiceOperator:
187
292
  )
188
293
 
189
294
  # stream service logs in a thread
190
- model_build_service_name = sql_identifier.SqlIdentifier(self._get_model_build_service_name(query_id))
295
+ model_build_service_name = sql_identifier.SqlIdentifier(
296
+ self._get_service_id_from_deployment_step(query_id, DeploymentStep.MODEL_BUILD)
297
+ )
191
298
  model_build_service = ServiceLogInfo(
192
299
  database_name=service_database_name,
193
300
  schema_name=service_schema_name,
194
301
  service_name=model_build_service_name,
195
- container_name="model-build",
302
+ deployment_step=DeploymentStep.MODEL_BUILD,
303
+ log_color=service_logger.LogColor.GREEN,
196
304
  )
197
305
  model_inference_service = ServiceLogInfo(
198
306
  database_name=service_database_name,
199
307
  schema_name=service_schema_name,
200
308
  service_name=service_name,
201
- container_name="model-inference",
309
+ deployment_step=DeploymentStep.MODEL_INFERENCE,
310
+ log_color=service_logger.LogColor.BLUE,
202
311
  )
203
- services = [model_build_service, model_inference_service]
312
+
313
+ model_logger_service: Optional[ServiceLogInfo] = None
314
+ if hf_model_args:
315
+ model_logger_service_name = sql_identifier.SqlIdentifier(
316
+ self._get_service_id_from_deployment_step(query_id, DeploymentStep.MODEL_LOGGING)
317
+ )
318
+
319
+ model_logger_service = ServiceLogInfo(
320
+ database_name=service_database_name,
321
+ schema_name=service_schema_name,
322
+ service_name=model_logger_service_name,
323
+ deployment_step=DeploymentStep.MODEL_LOGGING,
324
+ log_color=service_logger.LogColor.ORANGE,
325
+ )
326
+
327
+ # start service log streaming
204
328
  log_thread = self._start_service_log_streaming(
205
- async_job, services, model_inference_service_exists, force_rebuild, statement_params
329
+ async_job=async_job,
330
+ model_logger_service=model_logger_service,
331
+ model_build_service=model_build_service,
332
+ model_inference_service=model_inference_service,
333
+ model_inference_service_exists=model_inference_service_exists,
334
+ force_rebuild=force_rebuild,
335
+ operation_id=operation_id,
336
+ statement_params=statement_params,
206
337
  )
207
338
 
208
339
  if block:
@@ -217,170 +348,232 @@ class ServiceOperator:
217
348
  def _start_service_log_streaming(
218
349
  self,
219
350
  async_job: snowpark.AsyncJob,
220
- services: list[ServiceLogInfo],
351
+ model_logger_service: Optional[ServiceLogInfo],
352
+ model_build_service: ServiceLogInfo,
353
+ model_inference_service: ServiceLogInfo,
221
354
  model_inference_service_exists: bool,
222
355
  force_rebuild: bool,
356
+ operation_id: str,
223
357
  statement_params: Optional[dict[str, Any]] = None,
224
358
  ) -> threading.Thread:
225
359
  """Start the service log streaming in a separate thread."""
360
+ # TODO: create a DAG of services and stream logs in the order of the DAG
226
361
  log_thread = threading.Thread(
227
362
  target=self._stream_service_logs,
228
363
  args=(
229
364
  async_job,
230
- services,
365
+ model_logger_service,
366
+ model_build_service,
367
+ model_inference_service,
231
368
  model_inference_service_exists,
232
369
  force_rebuild,
370
+ operation_id,
233
371
  statement_params,
234
372
  ),
235
373
  )
236
374
  log_thread.start()
237
375
  return log_thread
238
376
 
239
- def _stream_service_logs(
377
+ def _fetch_log_and_update_meta(
240
378
  self,
241
- async_job: snowpark.AsyncJob,
242
- services: list[ServiceLogInfo],
243
- model_inference_service_exists: bool,
244
379
  force_rebuild: bool,
380
+ service_log_meta: ServiceLogMetadata,
381
+ model_build_service: ServiceLogInfo,
382
+ model_inference_service: ServiceLogInfo,
383
+ operation_id: str,
245
384
  statement_params: Optional[dict[str, Any]] = None,
246
385
  ) -> None:
247
- """Stream service logs while the async job is running."""
386
+ """Helper function to fetch logs and update the service log metadata if needed.
387
+
388
+ This function checks the service status and fetches logs if the service exists.
389
+ It also updates the service log metadata with the
390
+ new service status and logs.
391
+ If the service is done, it transitions the service log metadata.
392
+
393
+ Args:
394
+ force_rebuild: Whether to force rebuild the model build image.
395
+ service_log_meta: The ServiceLogMetadata holds the state of the service log metadata.
396
+ model_build_service: The ServiceLogInfo for the model build service.
397
+ model_inference_service: The ServiceLogInfo for the model inference service.
398
+ operation_id: The operation ID for the service, e.g. "model_deploy_a1b2c3d4_1703875200"
399
+ statement_params: The statement parameters to use for the service client.
400
+ """
401
+
402
+ service = service_log_meta.service
403
+ # check if using an existing model build image
404
+ if (
405
+ service.deployment_step == DeploymentStep.MODEL_BUILD
406
+ and not force_rebuild
407
+ and service_log_meta.is_model_logger_service_done
408
+ and not service_log_meta.is_model_build_service_done
409
+ ):
410
+ model_build_service_exists = self._check_if_service_exists(
411
+ database_name=service.database_name,
412
+ schema_name=service.schema_name,
413
+ service_name=service.service_name,
414
+ statement_params=statement_params,
415
+ )
416
+ new_model_inference_service_exists = self._check_if_service_exists(
417
+ database_name=model_inference_service.database_name,
418
+ schema_name=model_inference_service.schema_name,
419
+ service_name=model_inference_service.service_name,
420
+ statement_params=statement_params,
421
+ )
422
+ if not model_build_service_exists and new_model_inference_service_exists:
423
+ service_log_meta.transition_service_log_metadata(
424
+ model_inference_service,
425
+ "Model build is not rebuilding the inference image, but using a previously built image.",
426
+ is_model_build_service_done=True,
427
+ is_model_logger_service_done=service_log_meta.is_model_logger_service_done,
428
+ operation_id=operation_id,
429
+ )
248
430
 
249
- def fetch_logs(service: ServiceLogInfo, offset: int) -> tuple[str, int]:
250
- service_logs = self._service_client.get_service_logs(
431
+ try:
432
+ statuses = self._service_client.get_service_container_statuses(
251
433
  database_name=service.database_name,
252
434
  schema_name=service.schema_name,
253
435
  service_name=service.service_name,
254
- container_name=service.container_name,
436
+ include_message=True,
255
437
  statement_params=statement_params,
256
438
  )
439
+ service_status = statuses[0].service_status
440
+ except exceptions.SnowparkSQLException:
441
+ # If the service is not found, log that the service is not found
442
+ # and wait for a few seconds before returning
443
+ module_logger.info(f"Service status for service {service.display_service_name} not found.")
444
+ time.sleep(5)
445
+ return
257
446
 
258
- # return only new logs starting after the offset
259
- if len(service_logs) > offset:
260
- new_logs = service_logs[offset:]
261
- new_offset = len(service_logs)
262
- else:
263
- new_logs = ""
264
- new_offset = offset
447
+ # Case 1: service_status is PENDING or the service_status changed
448
+ if (service_status != service_sql.ServiceStatus.RUNNING) or (service_status != service_log_meta.service_status):
449
+ service_log_meta.service_status = service_status
265
450
 
266
- return new_logs, new_offset
451
+ if service.deployment_step == DeploymentStep.MODEL_BUILD:
452
+ module_logger.info(
453
+ f"Image build service {service.display_service_name} is "
454
+ f"{service_log_meta.service_status.value}."
455
+ )
456
+ elif service.deployment_step == DeploymentStep.MODEL_INFERENCE:
457
+ module_logger.info(
458
+ f"Inference service {service.display_service_name} is {service_log_meta.service_status.value}."
459
+ )
460
+ elif service.deployment_step == DeploymentStep.MODEL_LOGGING:
461
+ module_logger.info(
462
+ f"Model logger service {service.display_service_name} is "
463
+ f"{service_log_meta.service_status.value}."
464
+ )
465
+ for status in statuses:
466
+ if status.instance_id is not None:
467
+ instance_status, container_status = None, None
468
+ if status.instance_status is not None:
469
+ instance_status = status.instance_status.value
470
+ if status.container_status is not None:
471
+ container_status = status.container_status.value
472
+ module_logger.info(
473
+ f"Instance[{status.instance_id}]: "
474
+ f"instance status: {instance_status}, "
475
+ f"container status: {container_status}, "
476
+ f"message: {status.message}"
477
+ )
478
+ time.sleep(5)
267
479
 
268
- def set_service_log_metadata_to_model_inference(
269
- meta: ServiceLogMetadata, inference_service: ServiceLogInfo, msg: str
270
- ) -> None:
271
- model_inference_service_logger = service_logger.get_logger( # InferenceServiceName-InstanceId
272
- f"{inference_service.display_service_name}-{inference_service.instance_id}",
273
- service_logger.LogColor.BLUE,
480
+ # Case 2: service_status is RUNNING
481
+ # stream logs and update the log offset
482
+ if service_status == service_sql.ServiceStatus.RUNNING:
483
+ new_logs, new_offset = service.fetch_logs(
484
+ self._service_client,
485
+ service_log_meta.log_offset,
486
+ statement_params=statement_params,
274
487
  )
275
- model_inference_service_logger.propagate = False
276
- meta.service_logger = model_inference_service_logger
277
- meta.service = inference_service
278
- meta.service_status = None
279
- meta.is_model_build_service_done = True
280
- meta.log_offset = 0
281
- block_size = 180
282
- module_logger.info(msg)
283
- module_logger.info("-" * block_size)
284
-
285
- model_build_service, model_inference_service = services[0], services[1]
488
+ if new_logs:
489
+ service_log_meta.service_logger.info(new_logs)
490
+ service_log_meta.log_offset = new_offset
491
+
492
+ # Case 3: service_status is DONE
493
+ if service_status == service_sql.ServiceStatus.DONE:
494
+ # check if model logger service is done
495
+ # and transition the service log metadata to the model image build service
496
+ if service.deployment_step == DeploymentStep.MODEL_LOGGING:
497
+ service_log_meta.transition_service_log_metadata(
498
+ model_build_service,
499
+ f"Model Logger service {service.display_service_name} complete.",
500
+ is_model_build_service_done=False,
501
+ is_model_logger_service_done=service_log_meta.is_model_logger_service_done,
502
+ operation_id=operation_id,
503
+ )
504
+ # check if model build service is done
505
+ # and transition the service log metadata to the model inference service
506
+ elif service.deployment_step == DeploymentStep.MODEL_BUILD:
507
+ service_log_meta.transition_service_log_metadata(
508
+ model_inference_service,
509
+ f"Image build service {service.display_service_name} complete.",
510
+ is_model_build_service_done=True,
511
+ is_model_logger_service_done=service_log_meta.is_model_logger_service_done,
512
+ operation_id=operation_id,
513
+ )
514
+ else:
515
+ module_logger.warning(f"Service {service.display_service_name} is done, but not transitioning.")
516
+
517
+ def _stream_service_logs(
518
+ self,
519
+ async_job: snowpark.AsyncJob,
520
+ model_logger_service: Optional[ServiceLogInfo],
521
+ model_build_service: ServiceLogInfo,
522
+ model_inference_service: ServiceLogInfo,
523
+ model_inference_service_exists: bool,
524
+ force_rebuild: bool,
525
+ operation_id: str,
526
+ statement_params: Optional[dict[str, Any]] = None,
527
+ ) -> None:
528
+ """Stream service logs while the async job is running."""
529
+
286
530
  model_build_service_logger = service_logger.get_logger( # BuildJobName
287
- model_build_service.display_service_name, service_logger.LogColor.GREEN
288
- )
289
- model_build_service_logger.propagate = False
290
- service_log_meta = ServiceLogMetadata(
291
- service_logger=model_build_service_logger,
292
- service=model_build_service,
293
- service_status=None,
294
- is_model_build_service_done=False,
295
- log_offset=0,
531
+ model_build_service.display_service_name,
532
+ model_build_service.log_color,
533
+ operation_id=operation_id,
296
534
  )
535
+ if model_logger_service:
536
+ model_logger_service_logger = service_logger.get_logger( # ModelLoggerName
537
+ model_logger_service.display_service_name,
538
+ model_logger_service.log_color,
539
+ operation_id=operation_id,
540
+ )
541
+
542
+ service_log_meta = ServiceLogMetadata(
543
+ service_logger=model_logger_service_logger,
544
+ service=model_logger_service,
545
+ service_status=None,
546
+ is_model_build_service_done=False,
547
+ is_model_logger_service_done=False,
548
+ log_offset=0,
549
+ )
550
+ else:
551
+ service_log_meta = ServiceLogMetadata(
552
+ service_logger=model_build_service_logger,
553
+ service=model_build_service,
554
+ service_status=None,
555
+ is_model_build_service_done=False,
556
+ is_model_logger_service_done=True,
557
+ log_offset=0,
558
+ )
559
+
297
560
  while not async_job.is_done():
298
561
  if model_inference_service_exists:
299
562
  time.sleep(5)
300
563
  continue
301
564
 
302
565
  try:
303
- # check if using an existing model build image
304
- if not force_rebuild and not service_log_meta.is_model_build_service_done:
305
- model_build_service_exists = self._check_if_service_exists(
306
- database_name=model_build_service.database_name,
307
- schema_name=model_build_service.schema_name,
308
- service_name=model_build_service.service_name,
309
- statement_params=statement_params,
310
- )
311
- new_model_inference_service_exists = self._check_if_service_exists(
312
- database_name=model_inference_service.database_name,
313
- schema_name=model_inference_service.schema_name,
314
- service_name=model_inference_service.service_name,
315
- statement_params=statement_params,
316
- )
317
- if not model_build_service_exists and new_model_inference_service_exists:
318
- set_service_log_metadata_to_model_inference(
319
- service_log_meta,
320
- model_inference_service,
321
- (
322
- "Model Inference image build is not rebuilding the image, but using a previously built "
323
- "image."
324
- ),
325
- )
326
- continue
327
-
328
- statuses = self._service_client.get_service_container_statuses(
329
- database_name=service_log_meta.service.database_name,
330
- schema_name=service_log_meta.service.schema_name,
331
- service_name=service_log_meta.service.service_name,
332
- include_message=True,
566
+ # fetch logs for the service
567
+ # (model logging, model build, or model inference)
568
+ # upon completion, transition to the next service if any
569
+ self._fetch_log_and_update_meta(
570
+ service_log_meta=service_log_meta,
571
+ force_rebuild=force_rebuild,
572
+ model_build_service=model_build_service,
573
+ model_inference_service=model_inference_service,
574
+ operation_id=operation_id,
333
575
  statement_params=statement_params,
334
576
  )
335
- service_status = statuses[0].service_status
336
- if (service_status != service_sql.ServiceStatus.RUNNING) or (
337
- service_status != service_log_meta.service_status
338
- ):
339
- service_log_meta.service_status = service_status
340
- module_logger.info(
341
- f"{'Inference' if service_log_meta.is_model_build_service_done else 'Image build'} service "
342
- f"{service_log_meta.service.display_service_name} is "
343
- f"{service_log_meta.service_status.value}."
344
- )
345
- for status in statuses:
346
- if status.instance_id is not None:
347
- instance_status, container_status = None, None
348
- if status.instance_status is not None:
349
- instance_status = status.instance_status.value
350
- if status.container_status is not None:
351
- container_status = status.container_status.value
352
- module_logger.info(
353
- f"Instance[{status.instance_id}]: "
354
- f"instance status: {instance_status}, "
355
- f"container status: {container_status}, "
356
- f"message: {status.message}"
357
- )
358
-
359
- new_logs, new_offset = fetch_logs(
360
- service_log_meta.service,
361
- service_log_meta.log_offset,
362
- )
363
- if new_logs:
364
- service_log_meta.service_logger.info(new_logs)
365
- service_log_meta.log_offset = new_offset
366
-
367
- # check if model build service is done
368
- if not service_log_meta.is_model_build_service_done:
369
- statuses = self._service_client.get_service_container_statuses(
370
- database_name=model_build_service.database_name,
371
- schema_name=model_build_service.schema_name,
372
- service_name=model_build_service.service_name,
373
- include_message=False,
374
- statement_params=statement_params,
375
- )
376
- service_status = statuses[0].service_status
377
-
378
- if service_status == service_sql.ServiceStatus.DONE:
379
- set_service_log_metadata_to_model_inference(
380
- service_log_meta,
381
- model_inference_service,
382
- f"Image build service {model_build_service.display_service_name} complete.",
383
- )
384
577
  except Exception as ex:
385
578
  pattern = r"002003 \(02000\)" # error code: service does not exist
386
579
  is_snowpark_sql_exception = isinstance(ex, exceptions.SnowparkSQLException)
@@ -388,8 +581,7 @@ class ServiceOperator:
388
581
  matches_pattern = service_log_meta.service_status is None and re.search(pattern, str(ex)) is not None
389
582
  if not (is_snowpark_sql_exception and (contains_msg or matches_pattern)):
390
583
  module_logger.warning(f"Caught an exception when logging: {repr(ex)}")
391
-
392
- time.sleep(5)
584
+ time.sleep(5)
393
585
 
394
586
  if model_inference_service_exists:
395
587
  module_logger.info(
@@ -397,7 +589,10 @@ class ServiceOperator:
397
589
  )
398
590
  else:
399
591
  self._finalize_logs(
400
- service_log_meta.service_logger, service_log_meta.service, service_log_meta.log_offset, statement_params
592
+ service_log_meta.service_logger,
593
+ service_log_meta.service,
594
+ service_log_meta.log_offset,
595
+ statement_params,
401
596
  )
402
597
 
403
598
  def _finalize_logs(
@@ -414,7 +609,7 @@ class ServiceOperator:
414
609
  database_name=service.database_name,
415
610
  schema_name=service.schema_name,
416
611
  service_name=service.service_name,
417
- container_name=service.container_name,
612
+ container_name=service.deployment_step.container_name,
418
613
  statement_params=statement_params,
419
614
  )
420
615
 
@@ -424,13 +619,17 @@ class ServiceOperator:
424
619
  module_logger.warning(f"Caught an exception when logging: {repr(ex)}")
425
620
 
426
621
  @staticmethod
427
- def _get_model_build_service_name(query_id: str) -> str:
428
- """Get the model build service name through the server-side logic."""
622
+ def _get_service_id_from_deployment_step(query_id: str, deployment_step: DeploymentStep) -> str:
623
+ """Get the service ID through the server-side logic."""
429
624
  uuid = query_id.replace("-", "")
430
625
  big_int = int(uuid, 16)
431
626
  md5_hash = hashlib.md5(str(big_int).encode()).hexdigest()
432
627
  identifier = md5_hash[:8]
433
- return ("model_build_" + identifier).upper()
628
+ service_name_prefix = deployment_step.service_name_prefix
629
+ if service_name_prefix is None:
630
+ # raise an exception if the service name prefix is None
631
+ raise ValueError(f"Service name prefix is {service_name_prefix} for deployment step {deployment_step}.")
632
+ return (service_name_prefix + identifier).upper()
434
633
 
435
634
  def _check_if_service_exists(
436
635
  self,
@@ -347,7 +347,7 @@ class ModelDeploymentSpec:
347
347
  hf_model = model_deployment_spec_schema.HuggingFaceModel(
348
348
  hf_model_name=hf_model_name,
349
349
  task=hf_task,
350
- hf_token=hf_token,
350
+ token=hf_token,
351
351
  tokenizer=hf_tokenizer,
352
352
  trust_remote_code=hf_trust_remote_code,
353
353
  revision=hf_revision,
@@ -55,7 +55,7 @@ class HuggingFaceModel(BaseModel):
55
55
  hf_model_name: str
56
56
  task: Optional[str] = None
57
57
  tokenizer: Optional[str] = None
58
- hf_token: Optional[str] = None
58
+ token: Optional[str] = None
59
59
  trust_remote_code: Optional[bool] = False
60
60
  revision: Optional[str] = None
61
61
  hf_model_kwargs: Optional[str] = "{}"