chalkpy 2.89.22__py3-none-any.whl → 2.95.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chalk/__init__.py +2 -1
- chalk/_gen/chalk/arrow/v1/arrow_pb2.py +7 -5
- chalk/_gen/chalk/arrow/v1/arrow_pb2.pyi +6 -0
- chalk/_gen/chalk/artifacts/v1/chart_pb2.py +36 -33
- chalk/_gen/chalk/artifacts/v1/chart_pb2.pyi +41 -1
- chalk/_gen/chalk/artifacts/v1/cron_query_pb2.py +8 -7
- chalk/_gen/chalk/artifacts/v1/cron_query_pb2.pyi +5 -0
- chalk/_gen/chalk/common/v1/offline_query_pb2.py +19 -13
- chalk/_gen/chalk/common/v1/offline_query_pb2.pyi +37 -0
- chalk/_gen/chalk/common/v1/online_query_pb2.py +54 -54
- chalk/_gen/chalk/common/v1/online_query_pb2.pyi +13 -1
- chalk/_gen/chalk/common/v1/script_task_pb2.py +13 -11
- chalk/_gen/chalk/common/v1/script_task_pb2.pyi +19 -1
- chalk/_gen/chalk/dataframe/__init__.py +0 -0
- chalk/_gen/chalk/dataframe/v1/__init__.py +0 -0
- chalk/_gen/chalk/dataframe/v1/dataframe_pb2.py +48 -0
- chalk/_gen/chalk/dataframe/v1/dataframe_pb2.pyi +123 -0
- chalk/_gen/chalk/dataframe/v1/dataframe_pb2_grpc.py +4 -0
- chalk/_gen/chalk/dataframe/v1/dataframe_pb2_grpc.pyi +4 -0
- chalk/_gen/chalk/graph/v1/graph_pb2.py +150 -149
- chalk/_gen/chalk/graph/v1/graph_pb2.pyi +25 -0
- chalk/_gen/chalk/graph/v1/sources_pb2.py +94 -84
- chalk/_gen/chalk/graph/v1/sources_pb2.pyi +56 -0
- chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2.py +79 -0
- chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2.pyi +377 -0
- chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2_grpc.py +4 -0
- chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2_grpc.pyi +4 -0
- chalk/_gen/chalk/kubernetes/v1/scaledobject_pb2.py +43 -7
- chalk/_gen/chalk/kubernetes/v1/scaledobject_pb2.pyi +252 -2
- chalk/_gen/chalk/protosql/v1/sql_service_pb2.py +54 -27
- chalk/_gen/chalk/protosql/v1/sql_service_pb2.pyi +131 -3
- chalk/_gen/chalk/protosql/v1/sql_service_pb2_grpc.py +45 -0
- chalk/_gen/chalk/protosql/v1/sql_service_pb2_grpc.pyi +14 -0
- chalk/_gen/chalk/python/v1/types_pb2.py +14 -14
- chalk/_gen/chalk/python/v1/types_pb2.pyi +8 -0
- chalk/_gen/chalk/server/v1/benchmark_pb2.py +76 -0
- chalk/_gen/chalk/server/v1/benchmark_pb2.pyi +156 -0
- chalk/_gen/chalk/server/v1/benchmark_pb2_grpc.py +258 -0
- chalk/_gen/chalk/server/v1/benchmark_pb2_grpc.pyi +84 -0
- chalk/_gen/chalk/server/v1/billing_pb2.py +40 -38
- chalk/_gen/chalk/server/v1/billing_pb2.pyi +17 -1
- chalk/_gen/chalk/server/v1/branches_pb2.py +45 -0
- chalk/_gen/chalk/server/v1/branches_pb2.pyi +80 -0
- chalk/_gen/chalk/server/v1/branches_pb2_grpc.pyi +36 -0
- chalk/_gen/chalk/server/v1/builder_pb2.py +372 -272
- chalk/_gen/chalk/server/v1/builder_pb2.pyi +479 -12
- chalk/_gen/chalk/server/v1/builder_pb2_grpc.py +360 -0
- chalk/_gen/chalk/server/v1/builder_pb2_grpc.pyi +96 -0
- chalk/_gen/chalk/server/v1/chart_pb2.py +10 -10
- chalk/_gen/chalk/server/v1/chart_pb2.pyi +18 -2
- chalk/_gen/chalk/server/v1/clickhouse_pb2.py +42 -0
- chalk/_gen/chalk/server/v1/clickhouse_pb2.pyi +17 -0
- chalk/_gen/chalk/server/v1/clickhouse_pb2_grpc.py +78 -0
- chalk/_gen/chalk/server/v1/clickhouse_pb2_grpc.pyi +38 -0
- chalk/_gen/chalk/server/v1/cloud_components_pb2.py +153 -107
- chalk/_gen/chalk/server/v1/cloud_components_pb2.pyi +146 -4
- chalk/_gen/chalk/server/v1/cloud_components_pb2_grpc.py +180 -0
- chalk/_gen/chalk/server/v1/cloud_components_pb2_grpc.pyi +48 -0
- chalk/_gen/chalk/server/v1/cloud_credentials_pb2.py +11 -3
- chalk/_gen/chalk/server/v1/cloud_credentials_pb2.pyi +20 -0
- chalk/_gen/chalk/server/v1/cloud_credentials_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/cloud_credentials_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2.py +59 -35
- chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2.pyi +127 -1
- chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2_grpc.py +135 -0
- chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2_grpc.pyi +36 -0
- chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2.py +90 -0
- chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2.pyi +264 -0
- chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2_grpc.py +170 -0
- chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2_grpc.pyi +62 -0
- chalk/_gen/chalk/server/v1/datasets_pb2.py +36 -24
- chalk/_gen/chalk/server/v1/datasets_pb2.pyi +71 -2
- chalk/_gen/chalk/server/v1/datasets_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/datasets_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/deploy_pb2.py +9 -3
- chalk/_gen/chalk/server/v1/deploy_pb2.pyi +12 -0
- chalk/_gen/chalk/server/v1/deploy_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/deploy_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/deployment_pb2.py +20 -15
- chalk/_gen/chalk/server/v1/deployment_pb2.pyi +25 -0
- chalk/_gen/chalk/server/v1/environment_pb2.py +25 -15
- chalk/_gen/chalk/server/v1/environment_pb2.pyi +93 -1
- chalk/_gen/chalk/server/v1/eventbus_pb2.py +44 -0
- chalk/_gen/chalk/server/v1/eventbus_pb2.pyi +64 -0
- chalk/_gen/chalk/server/v1/eventbus_pb2_grpc.py +4 -0
- chalk/_gen/chalk/server/v1/eventbus_pb2_grpc.pyi +4 -0
- chalk/_gen/chalk/server/v1/files_pb2.py +65 -0
- chalk/_gen/chalk/server/v1/files_pb2.pyi +167 -0
- chalk/_gen/chalk/server/v1/files_pb2_grpc.py +4 -0
- chalk/_gen/chalk/server/v1/files_pb2_grpc.pyi +4 -0
- chalk/_gen/chalk/server/v1/graph_pb2.py +41 -3
- chalk/_gen/chalk/server/v1/graph_pb2.pyi +191 -0
- chalk/_gen/chalk/server/v1/graph_pb2_grpc.py +92 -0
- chalk/_gen/chalk/server/v1/graph_pb2_grpc.pyi +32 -0
- chalk/_gen/chalk/server/v1/incident_pb2.py +57 -0
- chalk/_gen/chalk/server/v1/incident_pb2.pyi +165 -0
- chalk/_gen/chalk/server/v1/incident_pb2_grpc.py +4 -0
- chalk/_gen/chalk/server/v1/incident_pb2_grpc.pyi +4 -0
- chalk/_gen/chalk/server/v1/indexing_job_pb2.py +44 -0
- chalk/_gen/chalk/server/v1/indexing_job_pb2.pyi +38 -0
- chalk/_gen/chalk/server/v1/indexing_job_pb2_grpc.py +78 -0
- chalk/_gen/chalk/server/v1/indexing_job_pb2_grpc.pyi +38 -0
- chalk/_gen/chalk/server/v1/integrations_pb2.py +11 -9
- chalk/_gen/chalk/server/v1/integrations_pb2.pyi +34 -2
- chalk/_gen/chalk/server/v1/kube_pb2.py +29 -19
- chalk/_gen/chalk/server/v1/kube_pb2.pyi +28 -0
- chalk/_gen/chalk/server/v1/kube_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/kube_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/log_pb2.py +21 -3
- chalk/_gen/chalk/server/v1/log_pb2.pyi +68 -0
- chalk/_gen/chalk/server/v1/log_pb2_grpc.py +90 -0
- chalk/_gen/chalk/server/v1/log_pb2_grpc.pyi +24 -0
- chalk/_gen/chalk/server/v1/metadataplanejobqueue_pb2.py +73 -0
- chalk/_gen/chalk/server/v1/metadataplanejobqueue_pb2.pyi +212 -0
- chalk/_gen/chalk/server/v1/metadataplanejobqueue_pb2_grpc.py +217 -0
- chalk/_gen/chalk/server/v1/metadataplanejobqueue_pb2_grpc.pyi +74 -0
- chalk/_gen/chalk/server/v1/model_registry_pb2.py +10 -10
- chalk/_gen/chalk/server/v1/model_registry_pb2.pyi +4 -1
- chalk/_gen/chalk/server/v1/monitoring_pb2.py +84 -75
- chalk/_gen/chalk/server/v1/monitoring_pb2.pyi +1 -0
- chalk/_gen/chalk/server/v1/monitoring_pb2_grpc.py +136 -0
- chalk/_gen/chalk/server/v1/monitoring_pb2_grpc.pyi +38 -0
- chalk/_gen/chalk/server/v1/offline_queries_pb2.py +32 -10
- chalk/_gen/chalk/server/v1/offline_queries_pb2.pyi +73 -0
- chalk/_gen/chalk/server/v1/offline_queries_pb2_grpc.py +90 -0
- chalk/_gen/chalk/server/v1/offline_queries_pb2_grpc.pyi +24 -0
- chalk/_gen/chalk/server/v1/plandebug_pb2.py +53 -0
- chalk/_gen/chalk/server/v1/plandebug_pb2.pyi +86 -0
- chalk/_gen/chalk/server/v1/plandebug_pb2_grpc.py +168 -0
- chalk/_gen/chalk/server/v1/plandebug_pb2_grpc.pyi +60 -0
- chalk/_gen/chalk/server/v1/queries_pb2.py +76 -48
- chalk/_gen/chalk/server/v1/queries_pb2.pyi +155 -2
- chalk/_gen/chalk/server/v1/queries_pb2_grpc.py +180 -0
- chalk/_gen/chalk/server/v1/queries_pb2_grpc.pyi +48 -0
- chalk/_gen/chalk/server/v1/scheduled_query_pb2.py +4 -2
- chalk/_gen/chalk/server/v1/scheduled_query_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/scheduled_query_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/scheduled_query_run_pb2.py +12 -6
- chalk/_gen/chalk/server/v1/scheduled_query_run_pb2.pyi +75 -2
- chalk/_gen/chalk/server/v1/scheduler_pb2.py +24 -12
- chalk/_gen/chalk/server/v1/scheduler_pb2.pyi +61 -1
- chalk/_gen/chalk/server/v1/scheduler_pb2_grpc.py +90 -0
- chalk/_gen/chalk/server/v1/scheduler_pb2_grpc.pyi +24 -0
- chalk/_gen/chalk/server/v1/script_tasks_pb2.py +26 -14
- chalk/_gen/chalk/server/v1/script_tasks_pb2.pyi +33 -3
- chalk/_gen/chalk/server/v1/script_tasks_pb2_grpc.py +90 -0
- chalk/_gen/chalk/server/v1/script_tasks_pb2_grpc.pyi +24 -0
- chalk/_gen/chalk/server/v1/sql_interface_pb2.py +75 -0
- chalk/_gen/chalk/server/v1/sql_interface_pb2.pyi +142 -0
- chalk/_gen/chalk/server/v1/sql_interface_pb2_grpc.py +349 -0
- chalk/_gen/chalk/server/v1/sql_interface_pb2_grpc.pyi +114 -0
- chalk/_gen/chalk/server/v1/sql_queries_pb2.py +48 -0
- chalk/_gen/chalk/server/v1/sql_queries_pb2.pyi +150 -0
- chalk/_gen/chalk/server/v1/sql_queries_pb2_grpc.py +123 -0
- chalk/_gen/chalk/server/v1/sql_queries_pb2_grpc.pyi +52 -0
- chalk/_gen/chalk/server/v1/team_pb2.py +156 -137
- chalk/_gen/chalk/server/v1/team_pb2.pyi +56 -10
- chalk/_gen/chalk/server/v1/team_pb2_grpc.py +90 -0
- chalk/_gen/chalk/server/v1/team_pb2_grpc.pyi +24 -0
- chalk/_gen/chalk/server/v1/topic_pb2.py +5 -3
- chalk/_gen/chalk/server/v1/topic_pb2.pyi +10 -1
- chalk/_gen/chalk/server/v1/trace_pb2.py +50 -28
- chalk/_gen/chalk/server/v1/trace_pb2.pyi +121 -0
- chalk/_gen/chalk/server/v1/trace_pb2_grpc.py +135 -0
- chalk/_gen/chalk/server/v1/trace_pb2_grpc.pyi +42 -0
- chalk/_gen/chalk/server/v1/webhook_pb2.py +9 -3
- chalk/_gen/chalk/server/v1/webhook_pb2.pyi +18 -0
- chalk/_gen/chalk/server/v1/webhook_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/webhook_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/streaming/v1/debug_service_pb2.py +62 -0
- chalk/_gen/chalk/streaming/v1/debug_service_pb2.pyi +75 -0
- chalk/_gen/chalk/streaming/v1/debug_service_pb2_grpc.py +221 -0
- chalk/_gen/chalk/streaming/v1/debug_service_pb2_grpc.pyi +88 -0
- chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2.py +19 -7
- chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2.pyi +96 -3
- chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2_grpc.py +48 -0
- chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2_grpc.pyi +20 -0
- chalk/_gen/chalk/utils/v1/field_change_pb2.py +32 -0
- chalk/_gen/chalk/utils/v1/field_change_pb2.pyi +42 -0
- chalk/_gen/chalk/utils/v1/field_change_pb2_grpc.py +4 -0
- chalk/_gen/chalk/utils/v1/field_change_pb2_grpc.pyi +4 -0
- chalk/_lsp/error_builder.py +11 -0
- chalk/_monitoring/Chart.py +1 -3
- chalk/_version.py +1 -1
- chalk/cli.py +5 -10
- chalk/client/client.py +178 -64
- chalk/client/client_async.py +154 -0
- chalk/client/client_async_impl.py +22 -0
- chalk/client/client_grpc.py +738 -112
- chalk/client/client_impl.py +541 -136
- chalk/client/dataset.py +27 -6
- chalk/client/models.py +99 -2
- chalk/client/serialization/model_serialization.py +126 -10
- chalk/config/project_config.py +1 -1
- chalk/df/LazyFramePlaceholder.py +1154 -0
- chalk/df/ast_parser.py +2 -10
- chalk/features/_class_property.py +7 -0
- chalk/features/_embedding/embedding.py +1 -0
- chalk/features/_embedding/sentence_transformer.py +1 -1
- chalk/features/_encoding/converter.py +83 -2
- chalk/features/_encoding/pyarrow.py +20 -4
- chalk/features/_encoding/rich.py +1 -3
- chalk/features/_tensor.py +1 -2
- chalk/features/dataframe/_filters.py +14 -5
- chalk/features/dataframe/_impl.py +91 -36
- chalk/features/dataframe/_validation.py +11 -7
- chalk/features/feature_field.py +40 -30
- chalk/features/feature_set.py +1 -2
- chalk/features/feature_set_decorator.py +1 -0
- chalk/features/feature_wrapper.py +42 -3
- chalk/features/hooks.py +81 -12
- chalk/features/inference.py +65 -10
- chalk/features/resolver.py +338 -56
- chalk/features/tag.py +1 -3
- chalk/features/underscore_features.py +2 -1
- chalk/functions/__init__.py +456 -21
- chalk/functions/holidays.py +1 -3
- chalk/gitignore/gitignore_parser.py +5 -1
- chalk/importer.py +186 -74
- chalk/ml/__init__.py +6 -2
- chalk/ml/model_hooks.py +368 -51
- chalk/ml/model_reference.py +68 -10
- chalk/ml/model_version.py +34 -21
- chalk/ml/utils.py +143 -40
- chalk/operators/_utils.py +14 -3
- chalk/parsed/_proto/export.py +22 -0
- chalk/parsed/duplicate_input_gql.py +4 -0
- chalk/parsed/expressions.py +1 -3
- chalk/parsed/json_conversions.py +21 -14
- chalk/parsed/to_proto.py +16 -4
- chalk/parsed/user_types_to_json.py +31 -10
- chalk/parsed/validation_from_registries.py +182 -0
- chalk/queries/named_query.py +16 -6
- chalk/queries/scheduled_query.py +13 -1
- chalk/serialization/parsed_annotation.py +25 -12
- chalk/sql/__init__.py +221 -0
- chalk/sql/_internal/integrations/athena.py +6 -1
- chalk/sql/_internal/integrations/bigquery.py +22 -2
- chalk/sql/_internal/integrations/databricks.py +61 -18
- chalk/sql/_internal/integrations/mssql.py +281 -0
- chalk/sql/_internal/integrations/postgres.py +11 -3
- chalk/sql/_internal/integrations/redshift.py +4 -0
- chalk/sql/_internal/integrations/snowflake.py +11 -2
- chalk/sql/_internal/integrations/util.py +2 -1
- chalk/sql/_internal/sql_file_resolver.py +55 -10
- chalk/sql/_internal/sql_source.py +36 -2
- chalk/streams/__init__.py +1 -3
- chalk/streams/_kafka_source.py +5 -1
- chalk/streams/_windows.py +16 -4
- chalk/streams/types.py +1 -2
- chalk/utils/__init__.py +1 -3
- chalk/utils/_otel_version.py +13 -0
- chalk/utils/async_helpers.py +14 -5
- chalk/utils/df_utils.py +2 -2
- chalk/utils/duration.py +1 -3
- chalk/utils/job_log_display.py +538 -0
- chalk/utils/missing_dependency.py +5 -4
- chalk/utils/notebook.py +255 -2
- chalk/utils/pl_helpers.py +190 -37
- chalk/utils/pydanticutil/pydantic_compat.py +1 -2
- chalk/utils/storage_client.py +246 -0
- chalk/utils/threading.py +1 -3
- chalk/utils/tracing.py +194 -86
- {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/METADATA +53 -21
- {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/RECORD +268 -198
- {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/WHEEL +0 -0
- {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/entry_points.txt +0 -0
- {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/top_level.txt +0 -0
chalk/ml/model_hooks.py
CHANGED
|
@@ -1,68 +1,385 @@
|
|
|
1
|
-
from typing import
|
|
1
|
+
from typing import TYPE_CHECKING, Any, Dict, Optional, Protocol, Tuple
|
|
2
2
|
|
|
3
|
-
import
|
|
3
|
+
from chalk.ml.utils import ModelClass, ModelEncoding, ModelType
|
|
4
4
|
|
|
5
|
-
|
|
5
|
+
if TYPE_CHECKING:
|
|
6
|
+
from chalk.features.resolver import ResourceHint
|
|
6
7
|
|
|
7
8
|
|
|
8
|
-
|
|
9
|
-
|
|
9
|
+
class ModelInference(Protocol):
|
|
10
|
+
"""Abstract base class for model loading and inference."""
|
|
10
11
|
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
12
|
+
def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
|
|
13
|
+
"""Load a model from the given path."""
|
|
14
|
+
pass
|
|
14
15
|
|
|
16
|
+
def predict(self, model: Any, X: Any) -> Any:
|
|
17
|
+
"""Run inference on the model with input X."""
|
|
18
|
+
pass
|
|
15
19
|
|
|
16
|
-
def
|
|
17
|
-
|
|
20
|
+
def prepare_input(self, feature_table: Any) -> Any:
|
|
21
|
+
"""Convert PyArrow table to model input format.
|
|
18
22
|
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
23
|
+
Default implementation converts to numpy array via __array__().
|
|
24
|
+
Override for model-specific input formats (e.g., ONNX struct arrays).
|
|
25
|
+
"""
|
|
26
|
+
return feature_table.__array__()
|
|
22
27
|
|
|
28
|
+
def extract_output(self, result: Any, output_feature_name: str) -> Any:
|
|
29
|
+
"""Extract single output from model result.
|
|
23
30
|
|
|
24
|
-
|
|
25
|
-
|
|
31
|
+
Default implementation returns result as-is (for single outputs).
|
|
32
|
+
Override for models with structured outputs (e.g., ONNX struct arrays).
|
|
33
|
+
"""
|
|
34
|
+
return result
|
|
26
35
|
|
|
27
|
-
torch.set_grad_enabled(False)
|
|
28
|
-
model = torch.jit.load(f)
|
|
29
|
-
model.input_to_tensor = lambda X: torch.from_numpy(X.__array__()).float()
|
|
30
|
-
return model
|
|
31
36
|
|
|
37
|
+
class XGBoostClassifierInference(ModelInference):
|
|
38
|
+
"""Model inference for XGBoost classifiers."""
|
|
32
39
|
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
(ModelType.SKLEARN, ModelEncoding.PICKLE, None): lambda f: __import__("joblib").load(f),
|
|
36
|
-
(ModelType.TENSORFLOW, ModelEncoding.HDF5, None): lambda f: __import__("tensorflow").keras.models.load_model(f),
|
|
37
|
-
(ModelType.TENSORFLOW, ModelEncoding.SAFETENSOR, None): lambda f: __import__("tensorflow").keras.models.load_model(
|
|
38
|
-
f
|
|
39
|
-
),
|
|
40
|
-
(ModelType.XGBOOST, ModelEncoding.JSON, None): load_xgb_regressor,
|
|
41
|
-
(ModelType.XGBOOST, ModelEncoding.JSON, "classifier"): load_xgb_classifier,
|
|
42
|
-
(ModelType.XGBOOST, ModelEncoding.JSON, "regressor"): load_xgb_regressor,
|
|
43
|
-
(ModelType.LIGHTGBM, ModelEncoding.TEXT, None): lambda f: __import__("lightgbm").Booster(model_file=f),
|
|
44
|
-
(ModelType.CATBOOST, ModelEncoding.CBM, None): lambda f: __import__("catboost").CatBoost().load_model(f),
|
|
45
|
-
(ModelType.ONNX, ModelEncoding.PROTOBUF, None): lambda f: __import__("onnxruntime").InferenceSession(f),
|
|
46
|
-
}
|
|
40
|
+
def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
|
|
41
|
+
import xgboost # pyright: ignore[reportMissingImports]
|
|
47
42
|
|
|
43
|
+
model = xgboost.XGBClassifier()
|
|
44
|
+
model.load_model(path)
|
|
45
|
+
return model
|
|
48
46
|
|
|
49
|
-
def
|
|
50
|
-
|
|
51
|
-
result = outputs.detach().numpy().astype("float64")
|
|
52
|
-
result = result.squeeze()
|
|
53
|
-
# Convert 0-dimensional array to scalar, or ensure we have a proper 1D array
|
|
54
|
-
if result.ndim == 0:
|
|
55
|
-
return result.item()
|
|
56
|
-
return result
|
|
47
|
+
def predict(self, model: Any, X: Any) -> Any:
|
|
48
|
+
return model.predict(X)
|
|
57
49
|
|
|
58
50
|
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
(
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
51
|
+
class XGBoostRegressorInference(ModelInference):
|
|
52
|
+
"""Model inference for XGBoost regressors."""
|
|
53
|
+
|
|
54
|
+
def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
|
|
55
|
+
import xgboost # pyright: ignore[reportMissingImports]
|
|
56
|
+
|
|
57
|
+
model = xgboost.XGBRegressor()
|
|
58
|
+
model.load_model(path)
|
|
59
|
+
return model
|
|
60
|
+
|
|
61
|
+
def predict(self, model: Any, X: Any) -> Any:
|
|
62
|
+
return model.predict(X)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class PyTorchInference(ModelInference):
|
|
66
|
+
"""Model inference for PyTorch models."""
|
|
67
|
+
|
|
68
|
+
def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
|
|
69
|
+
import torch # pyright: ignore[reportMissingImports]
|
|
70
|
+
|
|
71
|
+
torch.set_grad_enabled(False)
|
|
72
|
+
|
|
73
|
+
# Load the model
|
|
74
|
+
model = torch.jit.load(path)
|
|
75
|
+
|
|
76
|
+
# If resource_hint is "gpu", move model to GPU
|
|
77
|
+
if resource_hint == "gpu" and torch.cuda.is_available():
|
|
78
|
+
device = torch.device("cuda")
|
|
79
|
+
model = model.to(device)
|
|
80
|
+
model.input_to_tensor = lambda X: torch.from_numpy(X).float().to(device)
|
|
81
|
+
else:
|
|
82
|
+
model.input_to_tensor = lambda X: torch.from_numpy(X).float()
|
|
83
|
+
|
|
84
|
+
return model
|
|
85
|
+
|
|
86
|
+
def predict(self, model: Any, X: Any) -> Any:
|
|
87
|
+
outputs = model(model.input_to_tensor(X))
|
|
88
|
+
result = outputs.detach().cpu().numpy().astype("float64")
|
|
89
|
+
result = result.squeeze()
|
|
90
|
+
|
|
91
|
+
# Convert 0-dimensional array to scalar, or ensure we have a proper 1D array
|
|
92
|
+
if result.ndim == 0:
|
|
93
|
+
return result.item()
|
|
94
|
+
|
|
95
|
+
return result
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class SklearnInference(ModelInference):
|
|
99
|
+
"""Model inference for scikit-learn models."""
|
|
100
|
+
|
|
101
|
+
def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
|
|
102
|
+
import joblib # pyright: ignore[reportMissingImports]
|
|
103
|
+
|
|
104
|
+
return joblib.load(path)
|
|
105
|
+
|
|
106
|
+
def predict(self, model: Any, X: Any) -> Any:
|
|
107
|
+
return model.predict(X)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class TensorFlowInference(ModelInference):
|
|
111
|
+
"""Model inference for TensorFlow models."""
|
|
112
|
+
|
|
113
|
+
def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
|
|
114
|
+
import tensorflow # pyright: ignore[reportMissingImports]
|
|
115
|
+
|
|
116
|
+
return tensorflow.keras.models.load_model(path)
|
|
117
|
+
|
|
118
|
+
def predict(self, model: Any, X: Any) -> Any:
|
|
119
|
+
return model.predict(X)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class LightGBMInference(ModelInference):
|
|
123
|
+
"""Model inference for LightGBM models."""
|
|
124
|
+
|
|
125
|
+
def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
|
|
126
|
+
import lightgbm # pyright: ignore[reportMissingImports]
|
|
127
|
+
|
|
128
|
+
return lightgbm.Booster(model_file=path)
|
|
129
|
+
|
|
130
|
+
def predict(self, model: Any, X: Any) -> Any:
|
|
131
|
+
return model.predict(X)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class CatBoostInference(ModelInference):
|
|
135
|
+
"""Model inference for CatBoost models."""
|
|
136
|
+
|
|
137
|
+
def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
|
|
138
|
+
import catboost # pyright: ignore[reportMissingImports]
|
|
139
|
+
|
|
140
|
+
return catboost.CatBoost().load_model(path)
|
|
141
|
+
|
|
142
|
+
def predict(self, model: Any, X: Any) -> Any:
|
|
143
|
+
return model.predict(X)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
class ONNXInference(ModelInference):
|
|
147
|
+
"""Model inference for ONNX models with struct input/output support."""
|
|
148
|
+
|
|
149
|
+
def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
|
|
150
|
+
import onnxruntime # pyright: ignore[reportMissingImports]
|
|
151
|
+
|
|
152
|
+
# Conditionally add CUDAExecutionProvider based on resource_hint
|
|
153
|
+
providers = (
|
|
154
|
+
["CUDAExecutionProvider", "CPUExecutionProvider"] if resource_hint == "gpu" else ["CPUExecutionProvider"]
|
|
155
|
+
)
|
|
156
|
+
return onnxruntime.InferenceSession(path, providers=providers)
|
|
157
|
+
|
|
158
|
+
def prepare_input(self, feature_table: Any) -> Any:
|
|
159
|
+
"""Convert PyArrow table to struct array for ONNX models."""
|
|
160
|
+
import pyarrow as pa
|
|
161
|
+
|
|
162
|
+
# Get arrays for each column, combining chunks if necessary
|
|
163
|
+
arrays = []
|
|
164
|
+
for i in range(feature_table.num_columns):
|
|
165
|
+
col = feature_table.column(i)
|
|
166
|
+
if isinstance(col, pa.ChunkedArray):
|
|
167
|
+
arrays.append(col.combine_chunks())
|
|
168
|
+
else:
|
|
169
|
+
arrays.append(col)
|
|
170
|
+
|
|
171
|
+
# Create fields from schema, preserving original field names
|
|
172
|
+
# Field names should match ONNX input names exactly
|
|
173
|
+
fields = []
|
|
174
|
+
for field in feature_table.schema:
|
|
175
|
+
fields.append(pa.field(field.name, field.type))
|
|
176
|
+
|
|
177
|
+
# Create struct array where each row is a struct with named fields
|
|
178
|
+
return pa.StructArray.from_arrays(arrays, fields=fields)
|
|
179
|
+
|
|
180
|
+
def extract_output(self, result: Any, output_feature_name: str) -> Any:
|
|
181
|
+
"""Extract single field from ONNX struct output."""
|
|
182
|
+
import pyarrow as pa
|
|
183
|
+
|
|
184
|
+
if not isinstance(result, (pa.StructArray, pa.ChunkedArray)):
|
|
185
|
+
return result
|
|
186
|
+
|
|
187
|
+
struct_type = result.type if isinstance(result, pa.StructArray) else result.chunk(0).type
|
|
188
|
+
|
|
189
|
+
# Find matching field by name, or use first field
|
|
190
|
+
field_index = None
|
|
191
|
+
for i, field in enumerate(struct_type):
|
|
192
|
+
if field.name == output_feature_name:
|
|
193
|
+
field_index = i
|
|
194
|
+
break
|
|
195
|
+
|
|
196
|
+
return result.field(field_index if field_index is not None else 0)
|
|
197
|
+
|
|
198
|
+
def predict(self, model: Any, X: Any) -> Any:
|
|
199
|
+
"""Run ONNX inference with struct input/output."""
|
|
200
|
+
# Get ONNX model input/output names
|
|
201
|
+
input_names = [inp.name for inp in model.get_inputs()]
|
|
202
|
+
output_names = [out.name for out in model.get_outputs()]
|
|
203
|
+
|
|
204
|
+
# Convert struct input to ONNX input dict
|
|
205
|
+
input_dict = self._struct_to_inputs(X, input_names)
|
|
206
|
+
|
|
207
|
+
# Run ONNX inference
|
|
208
|
+
outputs = model.run(output_names, input_dict)
|
|
209
|
+
|
|
210
|
+
# Always return outputs as struct array
|
|
211
|
+
return self._outputs_to_struct(output_names, outputs)
|
|
212
|
+
|
|
213
|
+
def _struct_to_inputs(self, struct_array: Any, input_names: list) -> dict:
|
|
214
|
+
"""Extract ONNX inputs from struct array by matching field names.
|
|
215
|
+
|
|
216
|
+
Struct field names must match ONNX input names (supports list/Tensor types).
|
|
217
|
+
If ONNX expects a single input but struct has multiple scalar fields,
|
|
218
|
+
stack them into a 2D array.
|
|
219
|
+
"""
|
|
220
|
+
import numpy as np
|
|
221
|
+
import pyarrow as pa
|
|
222
|
+
|
|
223
|
+
if isinstance(struct_array, pa.ChunkedArray):
|
|
224
|
+
struct_array = struct_array.combine_chunks()
|
|
225
|
+
|
|
226
|
+
input_dict = {}
|
|
227
|
+
struct_fields = {field.name: i for i, field in enumerate(struct_array.type)}
|
|
228
|
+
|
|
229
|
+
# Check if struct field names match ONNX input names
|
|
230
|
+
fields_match = all(input_name in struct_fields for input_name in input_names)
|
|
231
|
+
|
|
232
|
+
if not fields_match:
|
|
233
|
+
# Special case 1: ONNX expects single input and struct has single field
|
|
234
|
+
# Use that field regardless of name mismatch
|
|
235
|
+
if len(input_names) == 1 and len(struct_fields) == 1:
|
|
236
|
+
field_data = struct_array.field(0)
|
|
237
|
+
input_dict[input_names[0]] = self._arrow_to_numpy(field_data)
|
|
238
|
+
return input_dict
|
|
239
|
+
|
|
240
|
+
# Special case 2: ONNX expects single input, but struct has multiple scalar fields
|
|
241
|
+
# Stack them into a 2D array [batch_size, num_fields]
|
|
242
|
+
if len(input_names) == 1 and len(struct_fields) > 1:
|
|
243
|
+
# Check if all fields are scalar (not nested lists)
|
|
244
|
+
all_scalar = all(
|
|
245
|
+
not pa.types.is_list(struct_array.type[i].type)
|
|
246
|
+
and not pa.types.is_large_list(struct_array.type[i].type)
|
|
247
|
+
for i in range(len(struct_array.type))
|
|
248
|
+
)
|
|
249
|
+
|
|
250
|
+
if all_scalar:
|
|
251
|
+
# Stack all fields into a single 2D array
|
|
252
|
+
columns = []
|
|
253
|
+
for i in range(len(struct_array.type)):
|
|
254
|
+
field_data = struct_array.field(i)
|
|
255
|
+
col_array = self._arrow_to_numpy(field_data)
|
|
256
|
+
columns.append(col_array)
|
|
257
|
+
|
|
258
|
+
# Stack columns horizontally to create [batch_size, num_features]
|
|
259
|
+
stacked = np.column_stack(columns)
|
|
260
|
+
input_dict[input_names[0]] = stacked
|
|
261
|
+
return input_dict
|
|
262
|
+
|
|
263
|
+
raise ValueError(
|
|
264
|
+
f"ONNX inputs {input_names} not found in struct fields {list(struct_fields.keys())}. "
|
|
265
|
+
+ "Struct field names must match ONNX input names."
|
|
266
|
+
)
|
|
267
|
+
|
|
268
|
+
# Direct mapping: struct fields match ONNX inputs (for Tensor/list types or named inputs)
|
|
269
|
+
for input_name in input_names:
|
|
270
|
+
field_data = struct_array.field(struct_fields[input_name])
|
|
271
|
+
input_dict[input_name] = self._arrow_to_numpy(field_data)
|
|
272
|
+
|
|
273
|
+
return input_dict
|
|
274
|
+
|
|
275
|
+
def _arrow_to_numpy(self, arrow_array: Any) -> Any:
|
|
276
|
+
"""Convert Arrow array (including nested lists) to dense numpy array."""
|
|
277
|
+
import numpy as np
|
|
278
|
+
import pyarrow as pa
|
|
279
|
+
|
|
280
|
+
if isinstance(arrow_array, pa.ChunkedArray):
|
|
281
|
+
arrow_array = arrow_array.combine_chunks()
|
|
282
|
+
|
|
283
|
+
# Convert to Python list, then numpy - handles all cases (nested lists, flat arrays, etc.)
|
|
284
|
+
return np.array(arrow_array.to_pylist(), dtype=np.float32)
|
|
285
|
+
|
|
286
|
+
def _outputs_to_struct(self, output_names: list, outputs: list) -> Any:
|
|
287
|
+
"""Convert ONNX outputs to PyArrow struct array."""
|
|
288
|
+
import pyarrow as pa
|
|
289
|
+
|
|
290
|
+
if not outputs:
|
|
291
|
+
raise ValueError("ONNX model returned no outputs")
|
|
292
|
+
|
|
293
|
+
# Convert each output to Arrow array with proper type
|
|
294
|
+
fields = []
|
|
295
|
+
arrays = []
|
|
296
|
+
|
|
297
|
+
for name, output_array in zip(output_names, outputs):
|
|
298
|
+
arrow_array = self._numpy_to_arrow_array(output_array)
|
|
299
|
+
fields.append(pa.field(name, arrow_array.type))
|
|
300
|
+
arrays.append(arrow_array)
|
|
301
|
+
|
|
302
|
+
return pa.StructArray.from_arrays(arrays, fields=fields)
|
|
303
|
+
|
|
304
|
+
def _numpy_to_arrow_array(self, arr: Any) -> Any:
|
|
305
|
+
"""Convert numpy array to PyArrow array (possibly nested list)."""
|
|
306
|
+
import pyarrow as pa
|
|
307
|
+
|
|
308
|
+
# PyArrow can infer the correct nested list type from Python lists
|
|
309
|
+
# Shape (batch, dim1, dim2, ...) -> list[list[...]]
|
|
310
|
+
return pa.array(arr.tolist())
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
class ModelInferenceRegistry:
    """Registry mapping model configurations to inference implementations.

    Entries are keyed by the exact ``(model_type, encoding, model_class)``
    triple. Lookups are exact-match only: asking for a specific model class
    does not fall back to the entry registered with ``model_class=None``.
    Use ``register_for_all_classes`` when a single implementation should
    answer for every class variant.
    """

    def __init__(self):
        super().__init__()
        # (model_type, encoding, model_class) -> inference implementation.
        self._registry: Dict[Tuple[ModelType, ModelEncoding, Optional[ModelClass]], ModelInference] = {}

    def register(
        self,
        model_type: ModelType,
        encoding: ModelEncoding,
        model_class: Optional[ModelClass],
        inference: ModelInference,
    ) -> None:
        """Register a model inference implementation.

        Parameters
        ----------
        model_type
            The model framework component of the registry key.
        encoding
            The serialization format component of the registry key.
        model_class
            The model class component of the registry key; ``None`` is a
            distinct key of its own, not a wildcard.
        inference
            The implementation to store. An existing entry under the same
            key is replaced.
        """
        self._registry[(model_type, encoding, model_class)] = inference

    def register_for_all_classes(
        self,
        model_type: ModelType,
        encoding: ModelEncoding,
        inference: ModelInference,
    ) -> None:
        """Register inference for None, CLASSIFICATION, and REGRESSION variants.

        The same ``inference`` object is stored under all three keys.
        """
        self.register(model_type, encoding, None, inference)
        self.register(model_type, encoding, ModelClass.CLASSIFICATION, inference)
        self.register(model_type, encoding, ModelClass.REGRESSION, inference)

    def get(
        self,
        model_type: ModelType,
        encoding: ModelEncoding,
        model_class: Optional[ModelClass] = None,
    ) -> Optional[ModelInference]:
        """Get a model inference implementation from the registry.

        Returns
        -------
        Optional[ModelInference]
            The implementation registered under exactly this
            ``(model_type, encoding, model_class)`` key, or ``None`` when no
            such entry exists.
        """
        return self._registry.get((model_type, encoding, model_class), None)

    def get_loader(
        self,
        model_type: ModelType,
        encoding: ModelEncoding,
        model_class: Optional[ModelClass] = None,
    ):
        """Get the load_model function for a given configuration.

        Returns
        -------
        The ``load_model`` attribute of the registered implementation, or
        ``None`` when nothing is registered for this configuration.
        """
        inference = self.get(model_type, encoding, model_class)
        # Test for absence explicitly: a registered implementation that is
        # falsy (defines __len__/__bool__) must not be mistaken for a missing one.
        if inference is None:
            return None
        return inference.load_model

    def get_predictor(
        self,
        model_type: ModelType,
        encoding: ModelEncoding,
        model_class: Optional[ModelClass] = None,
    ):
        """Get the predict function for a given configuration.

        Returns
        -------
        The ``predict`` attribute of the registered implementation, or
        ``None`` when nothing is registered for this configuration.
        """
        inference = self.get(model_type, encoding, model_class)
        # Same explicit None test as get_loader, for the same reason.
        if inference is None:
            return None
        return inference.predict
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
# Process-wide registry shared by importers of this module.
MODEL_REGISTRY = ModelInferenceRegistry()

# Frameworks whose single implementation serves every model-class variant.
MODEL_REGISTRY.register_for_all_classes(
    model_type=ModelType.PYTORCH,
    encoding=ModelEncoding.PICKLE,
    inference=PyTorchInference(),
)
MODEL_REGISTRY.register_for_all_classes(
    model_type=ModelType.SKLEARN,
    encoding=ModelEncoding.PICKLE,
    inference=SklearnInference(),
)
MODEL_REGISTRY.register_for_all_classes(
    model_type=ModelType.TENSORFLOW,
    encoding=ModelEncoding.HDF5,
    inference=TensorFlowInference(),
)
MODEL_REGISTRY.register_for_all_classes(
    model_type=ModelType.LIGHTGBM,
    encoding=ModelEncoding.TEXT,
    inference=LightGBMInference(),
)
MODEL_REGISTRY.register_for_all_classes(
    model_type=ModelType.CATBOOST,
    encoding=ModelEncoding.CBM,
    inference=CatBoostInference(),
)
MODEL_REGISTRY.register_for_all_classes(
    model_type=ModelType.ONNX,
    encoding=ModelEncoding.PROTOBUF,
    inference=ONNXInference(),
)

# XGBoost is registered per class: classification gets its own implementation,
# while the unspecified (None) and regression keys each get a separate
# regressor instance.
MODEL_REGISTRY.register(
    model_type=ModelType.XGBOOST,
    encoding=ModelEncoding.JSON,
    model_class=None,
    inference=XGBoostRegressorInference(),
)
MODEL_REGISTRY.register(
    model_type=ModelType.XGBOOST,
    encoding=ModelEncoding.JSON,
    model_class=ModelClass.CLASSIFICATION,
    inference=XGBoostClassifierInference(),
)
MODEL_REGISTRY.register(
    model_type=ModelType.XGBOOST,
    encoding=ModelEncoding.JSON,
    model_class=ModelClass.REGRESSION,
    inference=XGBoostRegressorInference(),
)
|
chalk/ml/model_reference.py
CHANGED
|
@@ -3,12 +3,22 @@ from __future__ import annotations
|
|
|
3
3
|
import inspect
|
|
4
4
|
import os
|
|
5
5
|
from datetime import datetime
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
6
7
|
|
|
7
8
|
from chalk.ml.model_version import ModelVersion
|
|
8
|
-
from chalk.ml.utils import
|
|
9
|
+
from chalk.ml.utils import (
|
|
10
|
+
ModelClass,
|
|
11
|
+
get_model_spec,
|
|
12
|
+
get_registry_metadata_file,
|
|
13
|
+
model_encoding_from_proto,
|
|
14
|
+
model_type_from_proto,
|
|
15
|
+
)
|
|
9
16
|
from chalk.utils.object_inspect import get_source_object_starting
|
|
10
17
|
from chalk.utils.source_parsing import should_skip_source_code_parsing
|
|
11
18
|
|
|
19
|
+
if TYPE_CHECKING:
|
|
20
|
+
from chalk.features.resolver import ResourceHint
|
|
21
|
+
|
|
12
22
|
|
|
13
23
|
class ModelReference:
|
|
14
24
|
def __init__(
|
|
@@ -18,6 +28,8 @@ class ModelReference:
|
|
|
18
28
|
version: int | None = None,
|
|
19
29
|
alias: str | None = None,
|
|
20
30
|
as_of_date: datetime | None = None,
|
|
31
|
+
resource_hint: "ResourceHint | None" = None,
|
|
32
|
+
resource_group: str | None = None,
|
|
21
33
|
):
|
|
22
34
|
"""Specifies the model version that should be loaded into the deployment.
|
|
23
35
|
|
|
@@ -68,6 +80,8 @@ class ModelReference:
|
|
|
68
80
|
self.as_of_date = as_of_date
|
|
69
81
|
self.alias = alias
|
|
70
82
|
self.identifier = identifier
|
|
83
|
+
self.resource_hint = resource_hint
|
|
84
|
+
self.resource_group = resource_group
|
|
71
85
|
|
|
72
86
|
self.filename = filename
|
|
73
87
|
self.source_line_start = source_line_start
|
|
@@ -89,7 +103,8 @@ class ModelReference:
|
|
|
89
103
|
MODEL_REFERENCE_REGISTRY[(name, identifier)] = self
|
|
90
104
|
|
|
91
105
|
# Only load model if the metadata file exists, which only happens in deployed environments
|
|
92
|
-
|
|
106
|
+
registry_metadata_file = get_registry_metadata_file()
|
|
107
|
+
if registry_metadata_file is not None and os.path.exists(registry_metadata_file):
|
|
93
108
|
model_artifact_metadata = get_model_spec(model_name=name, identifier=identifier)
|
|
94
109
|
|
|
95
110
|
mv = ModelVersion(
|
|
@@ -100,6 +115,11 @@ class ModelReference:
|
|
|
100
115
|
identifier=identifier,
|
|
101
116
|
model_type=model_type_from_proto(model_artifact_metadata.spec.model_type),
|
|
102
117
|
model_encoding=model_encoding_from_proto(model_artifact_metadata.spec.model_encoding),
|
|
118
|
+
model_class=ModelClass(model_artifact_metadata.spec.model_class)
|
|
119
|
+
if model_artifact_metadata.spec.model_class
|
|
120
|
+
else None,
|
|
121
|
+
resource_hint=resource_hint,
|
|
122
|
+
resource_group=resource_group,
|
|
103
123
|
)
|
|
104
124
|
|
|
105
125
|
from chalk.features.hooks import before_all
|
|
@@ -107,14 +127,22 @@ class ModelReference:
|
|
|
107
127
|
def hook():
|
|
108
128
|
mv.load_model()
|
|
109
129
|
|
|
110
|
-
before_all(hook)
|
|
130
|
+
before_all(hook, resource_hint=resource_hint, resource_group=resource_group)
|
|
111
131
|
|
|
112
132
|
self.model_version = mv
|
|
113
133
|
else:
|
|
114
|
-
self.model_version = ModelVersion(
|
|
134
|
+
self.model_version = ModelVersion(
|
|
135
|
+
name=name, identifier=identifier, resource_hint=resource_hint, resource_group=resource_group
|
|
136
|
+
)
|
|
115
137
|
|
|
116
138
|
@classmethod
|
|
117
|
-
def as_of(
|
|
139
|
+
def as_of(
|
|
140
|
+
cls,
|
|
141
|
+
name: str,
|
|
142
|
+
when: datetime,
|
|
143
|
+
resource_hint: "ResourceHint | None" = None,
|
|
144
|
+
resource_group: str | None = None,
|
|
145
|
+
) -> ModelVersion:
|
|
118
146
|
"""Creates a ModelReference for a specific point in time.
|
|
119
147
|
|
|
120
148
|
Parameters
|
|
@@ -123,6 +151,11 @@ class ModelReference:
|
|
|
123
151
|
The name of the model.
|
|
124
152
|
when
|
|
125
153
|
The datetime to use for creating the model version identifier.
|
|
154
|
+
resource_hint
|
|
155
|
+
Whether this model loading is bound by CPU, I/O, or GPU.
|
|
156
|
+
resource_group
|
|
157
|
+
The resource group for the model: this is used to isolate execution
|
|
158
|
+
onto a separate pod (or set of nodes), such as on a GPU-enabled node.
|
|
126
159
|
|
|
127
160
|
Returns
|
|
128
161
|
-------
|
|
@@ -134,13 +167,20 @@ class ModelReference:
|
|
|
134
167
|
>>> import datetime
|
|
135
168
|
>>> timestamp = datetime.datetime(2023, 10, 15, 14, 30, 0)
|
|
136
169
|
>>> model = ModelReference.as_of("fraud_model", timestamp)
|
|
170
|
+
>>> model = ModelReference.as_of("fraud_model", timestamp, resource_hint="gpu", resource_group="gpu-group")
|
|
137
171
|
"""
|
|
138
172
|
|
|
139
|
-
mr = ModelReference(name=name, as_of_date=when)
|
|
173
|
+
mr = ModelReference(name=name, as_of_date=when, resource_hint=resource_hint, resource_group=resource_group)
|
|
140
174
|
return mr.model_version
|
|
141
175
|
|
|
142
176
|
@classmethod
|
|
143
|
-
def from_version(
|
|
177
|
+
def from_version(
|
|
178
|
+
cls,
|
|
179
|
+
name: str,
|
|
180
|
+
version: int,
|
|
181
|
+
resource_hint: "ResourceHint | None" = None,
|
|
182
|
+
resource_group: str | None = None,
|
|
183
|
+
) -> ModelVersion:
|
|
144
184
|
"""Creates a ModelReference using a numeric version identifier.
|
|
145
185
|
|
|
146
186
|
Parameters
|
|
@@ -149,6 +189,11 @@ class ModelReference:
|
|
|
149
189
|
The name of the model.
|
|
150
190
|
version
|
|
151
191
|
The version number. Must be a non-negative integer.
|
|
192
|
+
resource_hint
|
|
193
|
+
Whether this model loading is bound by CPU, I/O, or GPU.
|
|
194
|
+
resource_group
|
|
195
|
+
The resource group for the model: this is used to isolate execution
|
|
196
|
+
onto a separate pod (or set of nodes), such as on a GPU-enabled node.
|
|
152
197
|
|
|
153
198
|
Returns
|
|
154
199
|
-------
|
|
@@ -163,15 +208,22 @@ class ModelReference:
|
|
|
163
208
|
Examples
|
|
164
209
|
--------
|
|
165
210
|
>>> model = ModelReference.from_version("fraud_model", 1)
|
|
211
|
+
>>> model = ModelReference.from_version("fraud_model", 1, resource_hint="gpu", resource_group="gpu-group")
|
|
166
212
|
"""
|
|
167
213
|
if version < 0:
|
|
168
214
|
raise ValueError("Version number must be a non-negative integer.")
|
|
169
215
|
|
|
170
|
-
mr = ModelReference(name=name, version=version)
|
|
216
|
+
mr = ModelReference(name=name, version=version, resource_hint=resource_hint, resource_group=resource_group)
|
|
171
217
|
return mr.model_version
|
|
172
218
|
|
|
173
219
|
@classmethod
|
|
174
|
-
def from_alias(
|
|
220
|
+
def from_alias(
|
|
221
|
+
cls,
|
|
222
|
+
name: str,
|
|
223
|
+
alias: str,
|
|
224
|
+
resource_hint: "ResourceHint | None" = None,
|
|
225
|
+
resource_group: str | None = None,
|
|
226
|
+
) -> ModelVersion:
|
|
175
227
|
"""Creates a ModelReference using an alias identifier.
|
|
176
228
|
|
|
177
229
|
Parameters
|
|
@@ -180,6 +232,11 @@ class ModelReference:
|
|
|
180
232
|
The name of the model.
|
|
181
233
|
alias
|
|
182
234
|
The alias string. Must be non-empty.
|
|
235
|
+
resource_hint
|
|
236
|
+
Whether this model loading is bound by CPU, I/O, or GPU.
|
|
237
|
+
resource_group
|
|
238
|
+
The resource group for the model: this is used to isolate execution
|
|
239
|
+
onto a separate pod (or set of nodes), such as on a GPU-enabled node.
|
|
183
240
|
|
|
184
241
|
Returns
|
|
185
242
|
-------
|
|
@@ -194,11 +251,12 @@ class ModelReference:
|
|
|
194
251
|
Examples
|
|
195
252
|
--------
|
|
196
253
|
>>> model = ModelReference.from_alias("fraud_model", "latest")
|
|
254
|
+
>>> model = ModelReference.from_alias("fraud_model", "latest", resource_hint="gpu", resource_group="gpu-group")
|
|
197
255
|
"""
|
|
198
256
|
if not alias:
|
|
199
257
|
raise ValueError("Alias must be a non-empty string.")
|
|
200
258
|
|
|
201
|
-
mr = ModelReference(name=name, alias=alias)
|
|
259
|
+
mr = ModelReference(name=name, alias=alias, resource_hint=resource_hint, resource_group=resource_group)
|
|
202
260
|
return mr.model_version
|
|
203
261
|
|
|
204
262
|
|