snowflake-ml-python 1.20.0__py3-none-any.whl → 1.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/platform_capabilities.py +36 -0
- snowflake/ml/_internal/utils/url.py +42 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +67 -2
- snowflake/ml/data/data_connector.py +103 -1
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +8 -2
- snowflake/ml/experiment/callback/__init__.py +0 -0
- snowflake/ml/experiment/callback/keras.py +25 -2
- snowflake/ml/experiment/callback/lightgbm.py +27 -2
- snowflake/ml/experiment/callback/xgboost.py +25 -2
- snowflake/ml/experiment/experiment_tracking.py +93 -3
- snowflake/ml/experiment/utils.py +6 -0
- snowflake/ml/feature_store/feature_view.py +34 -24
- snowflake/ml/jobs/_interop/protocols.py +3 -0
- snowflake/ml/jobs/_utils/constants.py +1 -0
- snowflake/ml/jobs/_utils/payload_utils.py +354 -356
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +95 -8
- snowflake/ml/jobs/_utils/scripts/start_mlruntime.sh +92 -0
- snowflake/ml/jobs/_utils/scripts/startup.sh +112 -0
- snowflake/ml/jobs/_utils/spec_utils.py +1 -445
- snowflake/ml/jobs/_utils/stage_utils.py +22 -1
- snowflake/ml/jobs/_utils/types.py +14 -7
- snowflake/ml/jobs/job.py +2 -8
- snowflake/ml/jobs/manager.py +57 -135
- snowflake/ml/lineage/lineage_node.py +1 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/batch_inference_specs.py +16 -1
- snowflake/ml/model/_client/model/model_version_impl.py +130 -14
- snowflake/ml/model/_client/ops/deployment_step.py +36 -0
- snowflake/ml/model/_client/ops/model_ops.py +93 -8
- snowflake/ml/model/_client/ops/service_ops.py +32 -52
- snowflake/ml/model/_client/service/import_model_spec_schema.py +23 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +12 -4
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -0
- snowflake/ml/model/_client/sql/model_version.py +30 -6
- snowflake/ml/model/_client/sql/service.py +94 -5
- snowflake/ml/model/_model_composer/model_composer.py +1 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -0
- snowflake/ml/model/_model_composer/model_method/model_method.py +61 -2
- snowflake/ml/model/_packager/model_handler.py +8 -2
- snowflake/ml/model/_packager/model_handlers/custom.py +52 -0
- snowflake/ml/model/_packager/model_handlers/{huggingface_pipeline.py → huggingface.py} +203 -76
- snowflake/ml/model/_packager/model_handlers/mlflow.py +6 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +26 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +40 -7
- snowflake/ml/model/_packager/model_packager.py +1 -1
- snowflake/ml/model/_signatures/core.py +390 -8
- snowflake/ml/model/_signatures/utils.py +13 -4
- snowflake/ml/model/code_path.py +104 -0
- snowflake/ml/model/compute_pool.py +2 -0
- snowflake/ml/model/custom_model.py +55 -13
- snowflake/ml/model/model_signature.py +13 -1
- snowflake/ml/model/models/huggingface.py +285 -0
- snowflake/ml/model/models/huggingface_pipeline.py +19 -208
- snowflake/ml/model/type_hints.py +7 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +12 -0
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +12 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +5 -0
- snowflake/ml/registry/_manager/model_manager.py +230 -15
- snowflake/ml/registry/registry.py +4 -4
- snowflake/ml/utils/html_utils.py +67 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/METADATA +81 -7
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/RECORD +67 -59
- snowflake/ml/jobs/_utils/runtime_env_utils.py +0 -63
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/manager.py
CHANGED
|
@@ -1,27 +1,25 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import logging
|
|
3
|
+
import os
|
|
3
4
|
import pathlib
|
|
4
5
|
import sys
|
|
5
|
-
import textwrap
|
|
6
6
|
from pathlib import PurePath
|
|
7
7
|
from typing import Any, Callable, Optional, TypeVar, Union, cast, overload
|
|
8
8
|
from uuid import uuid4
|
|
9
9
|
|
|
10
10
|
import pandas as pd
|
|
11
|
-
import yaml
|
|
12
11
|
|
|
13
12
|
from snowflake import snowpark
|
|
14
13
|
from snowflake.ml._internal import telemetry
|
|
15
14
|
from snowflake.ml._internal.utils import identifier
|
|
16
15
|
from snowflake.ml.jobs import job as jb
|
|
17
16
|
from snowflake.ml.jobs._utils import (
|
|
17
|
+
constants,
|
|
18
18
|
feature_flags,
|
|
19
19
|
payload_utils,
|
|
20
20
|
query_helper,
|
|
21
|
-
spec_utils,
|
|
22
21
|
types,
|
|
23
22
|
)
|
|
24
|
-
from snowflake.snowpark._internal import utils as sp_utils
|
|
25
23
|
from snowflake.snowpark.context import get_active_session
|
|
26
24
|
from snowflake.snowpark.exceptions import SnowparkSQLException
|
|
27
25
|
from snowflake.snowpark.functions import coalesce, col, lit, when
|
|
@@ -259,7 +257,7 @@ def submit_directory(
|
|
|
259
257
|
dir_path: str,
|
|
260
258
|
compute_pool: str,
|
|
261
259
|
*,
|
|
262
|
-
entrypoint: str,
|
|
260
|
+
entrypoint: Union[str, list[str]],
|
|
263
261
|
stage_name: str,
|
|
264
262
|
args: Optional[list[str]] = None,
|
|
265
263
|
target_instances: int = 1,
|
|
@@ -274,7 +272,11 @@ def submit_directory(
|
|
|
274
272
|
Args:
|
|
275
273
|
dir_path: The path to the directory containing the job payload.
|
|
276
274
|
compute_pool: The compute pool to use for the job.
|
|
277
|
-
entrypoint: The
|
|
275
|
+
entrypoint: The entry point for job execution. Can be:
|
|
276
|
+
- A string path to the entry point script inside the source directory.
|
|
277
|
+
- A list of strings representing a custom command (e.g., ["arctic_training"])
|
|
278
|
+
which is passed through as-is without local resolution or validation.
|
|
279
|
+
This is useful for entrypoints that are installed via pip_requirements.
|
|
278
280
|
stage_name: The name of the stage where the job payload will be uploaded.
|
|
279
281
|
args: A list of arguments to pass to the job.
|
|
280
282
|
target_instances: The number of nodes in the job. If none specified, create a single node job.
|
|
@@ -315,7 +317,7 @@ def submit_from_stage(
|
|
|
315
317
|
source: str,
|
|
316
318
|
compute_pool: str,
|
|
317
319
|
*,
|
|
318
|
-
entrypoint: str,
|
|
320
|
+
entrypoint: Union[str, list[str]],
|
|
319
321
|
stage_name: str,
|
|
320
322
|
args: Optional[list[str]] = None,
|
|
321
323
|
target_instances: int = 1,
|
|
@@ -330,7 +332,11 @@ def submit_from_stage(
|
|
|
330
332
|
Args:
|
|
331
333
|
source: a stage path or a stage containing the job payload.
|
|
332
334
|
compute_pool: The compute pool to use for the job.
|
|
333
|
-
entrypoint:
|
|
335
|
+
entrypoint: The entry point for job execution. Can be:
|
|
336
|
+
- A string path to the entry point script inside the source directory.
|
|
337
|
+
- A list of strings representing a custom command (e.g., ["arctic_training"])
|
|
338
|
+
which is passed through as-is without local resolution or validation.
|
|
339
|
+
This is useful for entrypoints that are installed via pip_requirements.
|
|
334
340
|
stage_name: The name of the stage where the job payload will be uploaded.
|
|
335
341
|
args: A list of arguments to pass to the job.
|
|
336
342
|
target_instances: The number of nodes in the job. If none specified, create a single node job.
|
|
@@ -375,7 +381,7 @@ def _submit_job(
|
|
|
375
381
|
compute_pool: str,
|
|
376
382
|
*,
|
|
377
383
|
stage_name: str,
|
|
378
|
-
entrypoint: Optional[str] = None,
|
|
384
|
+
entrypoint: Optional[Union[str, list[str]]] = None,
|
|
379
385
|
args: Optional[list[str]] = None,
|
|
380
386
|
target_instances: int = 1,
|
|
381
387
|
pip_requirements: Optional[list[str]] = None,
|
|
@@ -392,7 +398,7 @@ def _submit_job(
|
|
|
392
398
|
compute_pool: str,
|
|
393
399
|
*,
|
|
394
400
|
stage_name: str,
|
|
395
|
-
entrypoint: Optional[str] = None,
|
|
401
|
+
entrypoint: Optional[Union[str, list[str]]] = None,
|
|
396
402
|
args: Optional[list[str]] = None,
|
|
397
403
|
target_instances: int = 1,
|
|
398
404
|
pip_requirements: Optional[list[str]] = None,
|
|
@@ -424,7 +430,7 @@ def _submit_job(
|
|
|
424
430
|
compute_pool: str,
|
|
425
431
|
*,
|
|
426
432
|
stage_name: str,
|
|
427
|
-
entrypoint: Optional[str] = None,
|
|
433
|
+
entrypoint: Optional[Union[str, list[str]]] = None,
|
|
428
434
|
args: Optional[list[str]] = None,
|
|
429
435
|
target_instances: int = 1,
|
|
430
436
|
session: Optional[snowpark.Session] = None,
|
|
@@ -437,7 +443,11 @@ def _submit_job(
|
|
|
437
443
|
source: The file/directory path containing payload source code or a serializable Python callable.
|
|
438
444
|
compute_pool: The compute pool to use for the job.
|
|
439
445
|
stage_name: The name of the stage where the job payload will be uploaded.
|
|
440
|
-
entrypoint: The entry point for the job execution.
|
|
446
|
+
entrypoint: The entry point for the job execution. Can be:
|
|
447
|
+
- A string path to a Python script (required if source is a directory).
|
|
448
|
+
- A list of strings representing a custom command (e.g., ["arctic_training"])
|
|
449
|
+
which is passed through as-is without local resolution or validation.
|
|
450
|
+
This is useful for entrypoints that are installed via pip_requirements.
|
|
441
451
|
args: A list of arguments to pass to the job.
|
|
442
452
|
target_instances: The number of instances to use for the job. If none specified, single node job is created.
|
|
443
453
|
session: The Snowpark session to use. If none specified, uses active session.
|
|
@@ -449,7 +459,6 @@ def _submit_job(
|
|
|
449
459
|
Raises:
|
|
450
460
|
ValueError: If database or schema value(s) are invalid
|
|
451
461
|
RuntimeError: If schema is not specified in session context or job submission
|
|
452
|
-
SnowparkSQLException: if failed to upload payload
|
|
453
462
|
"""
|
|
454
463
|
session = _ensure_session(session)
|
|
455
464
|
|
|
@@ -481,7 +490,8 @@ def _submit_job(
|
|
|
481
490
|
enable_metrics = kwargs.pop("enable_metrics", True)
|
|
482
491
|
query_warehouse = kwargs.pop("query_warehouse", session.get_current_warehouse())
|
|
483
492
|
imports = kwargs.pop("imports", None) or imports
|
|
484
|
-
|
|
493
|
+
# if the mljob is submitted from a notebook, we use the same image tag as the notebook
|
|
494
|
+
runtime_environment = kwargs.pop("runtime_environment", os.environ.get(constants.RUNTIME_IMAGE_TAG_ENV_VAR, None))
|
|
485
495
|
|
|
486
496
|
# Warn if there are unknown kwargs
|
|
487
497
|
if kwargs:
|
|
@@ -513,7 +523,7 @@ def _submit_job(
|
|
|
513
523
|
try:
|
|
514
524
|
# Upload payload
|
|
515
525
|
uploaded_payload = payload_utils.JobPayload(
|
|
516
|
-
source, entrypoint=entrypoint, pip_requirements=pip_requirements,
|
|
526
|
+
source, entrypoint=entrypoint, pip_requirements=pip_requirements, imports=imports
|
|
517
527
|
).upload(session, stage_path)
|
|
518
528
|
except SnowparkSQLException as e:
|
|
519
529
|
if e.sql_error_code == 90106:
|
|
@@ -528,125 +538,36 @@ def _submit_job(
|
|
|
528
538
|
) from e
|
|
529
539
|
raise
|
|
530
540
|
|
|
531
|
-
|
|
532
|
-
# Add default env vars (extracted from spec_utils.generate_service_spec)
|
|
533
|
-
combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})}
|
|
541
|
+
combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})}
|
|
534
542
|
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
) from e
|
|
561
|
-
# SNOW-2390287: SYSTEM$EXECUTE_ML_JOB() is erroneously blocked in owner's rights
|
|
562
|
-
# stored procedures. This will be fixed in an upcoming release.
|
|
563
|
-
logger.warning(
|
|
564
|
-
"Job submission using V2 failed with error {}. Falling back to V1.".format(
|
|
565
|
-
str(e).split("\n", 1)[0],
|
|
566
|
-
)
|
|
567
|
-
)
|
|
568
|
-
|
|
569
|
-
# Fall back to v1
|
|
570
|
-
# Generate service spec
|
|
571
|
-
spec = spec_utils.generate_service_spec(
|
|
572
|
-
session,
|
|
573
|
-
compute_pool=compute_pool,
|
|
574
|
-
payload=uploaded_payload,
|
|
575
|
-
args=args,
|
|
576
|
-
target_instances=target_instances,
|
|
577
|
-
min_instances=min_instances,
|
|
578
|
-
enable_metrics=enable_metrics,
|
|
579
|
-
runtime_environment=runtime_environment,
|
|
580
|
-
)
|
|
581
|
-
|
|
582
|
-
# Generate spec overrides
|
|
583
|
-
spec_overrides = spec_utils.generate_spec_overrides(
|
|
584
|
-
environment_vars=env_vars,
|
|
585
|
-
custom_overrides=spec_overrides,
|
|
586
|
-
)
|
|
587
|
-
if spec_overrides:
|
|
588
|
-
spec = spec_utils.merge_patch(spec, spec_overrides, display_name="spec_overrides")
|
|
589
|
-
|
|
590
|
-
return _do_submit_job_v1(
|
|
591
|
-
session, spec, external_access_integrations, query_warehouse, target_instances, compute_pool, job_id
|
|
592
|
-
)
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
def _do_submit_job_v1(
|
|
596
|
-
session: snowpark.Session,
|
|
597
|
-
spec: dict[str, Any],
|
|
598
|
-
external_access_integrations: list[str],
|
|
599
|
-
query_warehouse: Optional[str],
|
|
600
|
-
target_instances: int,
|
|
601
|
-
compute_pool: str,
|
|
602
|
-
job_id: str,
|
|
603
|
-
) -> jb.MLJob[Any]:
|
|
604
|
-
"""
|
|
605
|
-
Generate the SQL query for job submission.
|
|
606
|
-
|
|
607
|
-
Args:
|
|
608
|
-
session: The Snowpark session to use.
|
|
609
|
-
spec: The service spec for the job.
|
|
610
|
-
external_access_integrations: The external access integrations for the job.
|
|
611
|
-
query_warehouse: The query warehouse for the job.
|
|
612
|
-
target_instances: The number of instances for the job.
|
|
613
|
-
session: The Snowpark session to use.
|
|
614
|
-
compute_pool: The compute pool to use for the job.
|
|
615
|
-
job_id: The ID of the job.
|
|
616
|
-
|
|
617
|
-
Returns:
|
|
618
|
-
The job object.
|
|
619
|
-
"""
|
|
620
|
-
query_template = textwrap.dedent(
|
|
621
|
-
"""\
|
|
622
|
-
EXECUTE JOB SERVICE
|
|
623
|
-
IN COMPUTE POOL IDENTIFIER(?)
|
|
624
|
-
FROM SPECIFICATION $$
|
|
625
|
-
{}
|
|
626
|
-
$$
|
|
627
|
-
NAME = IDENTIFIER(?)
|
|
628
|
-
ASYNC = TRUE
|
|
629
|
-
"""
|
|
630
|
-
)
|
|
631
|
-
params: list[Any] = [compute_pool, job_id]
|
|
632
|
-
query = query_template.format(yaml.dump(spec)).splitlines()
|
|
633
|
-
if external_access_integrations:
|
|
634
|
-
external_access_integration_list = ",".join(f"{e}" for e in external_access_integrations)
|
|
635
|
-
query.append(f"EXTERNAL_ACCESS_INTEGRATIONS = ({external_access_integration_list})")
|
|
636
|
-
if query_warehouse:
|
|
637
|
-
query.append("QUERY_WAREHOUSE = IDENTIFIER(?)")
|
|
638
|
-
params.append(query_warehouse)
|
|
639
|
-
if target_instances > 1:
|
|
640
|
-
query.append("REPLICAS = ?")
|
|
641
|
-
params.append(target_instances)
|
|
642
|
-
|
|
643
|
-
query_text = "\n".join(line for line in query if line)
|
|
644
|
-
_ = query_helper.run_query(session, query_text, params=params)
|
|
645
|
-
|
|
646
|
-
return get_job(job_id, session=session)
|
|
543
|
+
try:
|
|
544
|
+
return _do_submit_job(
|
|
545
|
+
session=session,
|
|
546
|
+
payload=uploaded_payload,
|
|
547
|
+
args=args,
|
|
548
|
+
env_vars=combined_env_vars,
|
|
549
|
+
spec_overrides=spec_overrides,
|
|
550
|
+
compute_pool=compute_pool,
|
|
551
|
+
job_id=job_id,
|
|
552
|
+
external_access_integrations=external_access_integrations,
|
|
553
|
+
query_warehouse=query_warehouse,
|
|
554
|
+
target_instances=target_instances,
|
|
555
|
+
min_instances=min_instances,
|
|
556
|
+
enable_metrics=enable_metrics,
|
|
557
|
+
use_async=True,
|
|
558
|
+
runtime_environment=runtime_environment,
|
|
559
|
+
)
|
|
560
|
+
except SnowparkSQLException as e:
|
|
561
|
+
if e.sql_error_code == 3001 and "schema" in str(e).lower():
|
|
562
|
+
raise RuntimeError(
|
|
563
|
+
"please grant privileges on schema before submitting a job, see",
|
|
564
|
+
"https://docs.snowflake.com/en/developer-guide/snowflake-ml/ml-jobs/access-control-requirements"
|
|
565
|
+
" for more details",
|
|
566
|
+
) from e
|
|
567
|
+
raise
|
|
647
568
|
|
|
648
569
|
|
|
649
|
-
def
|
|
570
|
+
def _do_submit_job(
|
|
650
571
|
session: snowpark.Session,
|
|
651
572
|
payload: types.UploadedPayload,
|
|
652
573
|
args: Optional[list[str]],
|
|
@@ -684,9 +605,7 @@ def _do_submit_job_v2(
|
|
|
684
605
|
Returns:
|
|
685
606
|
The job object.
|
|
686
607
|
"""
|
|
687
|
-
args = [
|
|
688
|
-
(payload.stage_path.joinpath(v).as_posix() if isinstance(v, PurePath) else v) for v in payload.entrypoint
|
|
689
|
-
] + (args or [])
|
|
608
|
+
args = [(v.as_posix() if isinstance(v, PurePath) else v) for v in payload.entrypoint] + (args or [])
|
|
690
609
|
spec_options = {
|
|
691
610
|
"STAGE_PATH": payload.stage_path.as_posix(),
|
|
692
611
|
"ENTRYPOINT": ["/usr/local/bin/_entrypoint.sh"],
|
|
@@ -695,8 +614,8 @@ def _do_submit_job_v2(
|
|
|
695
614
|
"ENABLE_METRICS": enable_metrics,
|
|
696
615
|
"SPEC_OVERRIDES": spec_overrides,
|
|
697
616
|
}
|
|
698
|
-
# for the image tag or full image URL, we use that directly
|
|
699
617
|
if runtime_environment:
|
|
618
|
+
# for the image tag or full image URL, we use that directly
|
|
700
619
|
spec_options["RUNTIME"] = runtime_environment
|
|
701
620
|
elif feature_flags.FeatureFlags.ENABLE_RUNTIME_VERSIONS.is_enabled():
|
|
702
621
|
# when feature flag is enabled, we get the local python version and wrap it in a dict
|
|
@@ -710,6 +629,9 @@ def _do_submit_job_v2(
|
|
|
710
629
|
"MIN_INSTANCES": min_instances,
|
|
711
630
|
"ASYNC": use_async,
|
|
712
631
|
}
|
|
632
|
+
|
|
633
|
+
if feature_flags.FeatureFlags.ENABLE_STAGE_MOUNT_V2.is_enabled(default=True):
|
|
634
|
+
spec_options["ENABLE_STAGE_MOUNT_V2"] = True
|
|
713
635
|
if payload.payload_name:
|
|
714
636
|
job_options["GENERATE_SUFFIX"] = True
|
|
715
637
|
job_options = {k: v for k, v in job_options.items() if v is not None}
|
|
@@ -7,7 +7,7 @@ from snowflake.ml._internal import telemetry
|
|
|
7
7
|
from snowflake.ml._internal.utils import identifier, mixins
|
|
8
8
|
|
|
9
9
|
if TYPE_CHECKING:
|
|
10
|
-
from snowflake.ml import dataset
|
|
10
|
+
from snowflake.ml.dataset import dataset
|
|
11
11
|
from snowflake.ml.feature_store import feature_view
|
|
12
12
|
from snowflake.ml.model._client.model import model_version_impl
|
|
13
13
|
|
snowflake/ml/model/__init__.py
CHANGED
|
@@ -2,16 +2,20 @@ import sys
|
|
|
2
2
|
import warnings
|
|
3
3
|
|
|
4
4
|
from snowflake.ml.model._client.model.batch_inference_specs import (
|
|
5
|
+
ColumnHandlingOptions,
|
|
6
|
+
FileEncoding,
|
|
5
7
|
JobSpec,
|
|
6
8
|
OutputSpec,
|
|
7
9
|
SaveMode,
|
|
8
10
|
)
|
|
9
11
|
from snowflake.ml.model._client.model.model_impl import Model
|
|
10
12
|
from snowflake.ml.model._client.model.model_version_impl import ExportMode, ModelVersion
|
|
13
|
+
from snowflake.ml.model.code_path import CodePath
|
|
11
14
|
from snowflake.ml.model.models.huggingface_pipeline import HuggingFacePipelineModel
|
|
12
15
|
from snowflake.ml.model.volatility import Volatility
|
|
13
16
|
|
|
14
17
|
__all__ = [
|
|
18
|
+
"CodePath",
|
|
15
19
|
"Model",
|
|
16
20
|
"ModelVersion",
|
|
17
21
|
"ExportMode",
|
|
@@ -20,6 +24,8 @@ __all__ = [
|
|
|
20
24
|
"OutputSpec",
|
|
21
25
|
"SaveMode",
|
|
22
26
|
"Volatility",
|
|
27
|
+
"FileEncoding",
|
|
28
|
+
"ColumnHandlingOptions",
|
|
23
29
|
]
|
|
24
30
|
|
|
25
31
|
_deprecation_warning_msg_for_3_9 = (
|
|
@@ -2,6 +2,7 @@ from enum import Enum
|
|
|
2
2
|
from typing import Optional
|
|
3
3
|
|
|
4
4
|
from pydantic import BaseModel
|
|
5
|
+
from typing_extensions import TypedDict
|
|
5
6
|
|
|
6
7
|
|
|
7
8
|
class SaveMode(str, Enum):
|
|
@@ -18,6 +19,20 @@ class SaveMode(str, Enum):
|
|
|
18
19
|
ERROR = "error"
|
|
19
20
|
|
|
20
21
|
|
|
22
|
+
class FileEncoding(str, Enum):
|
|
23
|
+
"""The encoding of the file content that will be passed to the custom model."""
|
|
24
|
+
|
|
25
|
+
RAW_BYTES = "raw_bytes"
|
|
26
|
+
BASE64 = "base64"
|
|
27
|
+
BASE64_DATA_URL = "base64_data_url"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class ColumnHandlingOptions(TypedDict):
|
|
31
|
+
"""Options for handling specific columns during run_batch for file I/O."""
|
|
32
|
+
|
|
33
|
+
encoding: FileEncoding
|
|
34
|
+
|
|
35
|
+
|
|
21
36
|
class OutputSpec(BaseModel):
|
|
22
37
|
"""Specification for batch inference output.
|
|
23
38
|
|
|
@@ -74,7 +89,7 @@ class JobSpec(BaseModel):
|
|
|
74
89
|
the memory of the node.
|
|
75
90
|
gpu_requests (Optional[str]): The gpu limit for GPU based inference. Can be integer or
|
|
76
91
|
string values. Use CPU if None.
|
|
77
|
-
replicas (Optional[int]): Number of job
|
|
92
|
+
replicas (Optional[int]): Number of SPCS job nodes used for distributed inference.
|
|
78
93
|
If not specified, defaults to 1 replica.
|
|
79
94
|
|
|
80
95
|
Example:
|