chalkpy 2.89.22__py3-none-any.whl → 2.95.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chalk/__init__.py +2 -1
- chalk/_gen/chalk/arrow/v1/arrow_pb2.py +7 -5
- chalk/_gen/chalk/arrow/v1/arrow_pb2.pyi +6 -0
- chalk/_gen/chalk/artifacts/v1/chart_pb2.py +36 -33
- chalk/_gen/chalk/artifacts/v1/chart_pb2.pyi +41 -1
- chalk/_gen/chalk/artifacts/v1/cron_query_pb2.py +8 -7
- chalk/_gen/chalk/artifacts/v1/cron_query_pb2.pyi +5 -0
- chalk/_gen/chalk/common/v1/offline_query_pb2.py +19 -13
- chalk/_gen/chalk/common/v1/offline_query_pb2.pyi +37 -0
- chalk/_gen/chalk/common/v1/online_query_pb2.py +54 -54
- chalk/_gen/chalk/common/v1/online_query_pb2.pyi +13 -1
- chalk/_gen/chalk/common/v1/script_task_pb2.py +13 -11
- chalk/_gen/chalk/common/v1/script_task_pb2.pyi +19 -1
- chalk/_gen/chalk/dataframe/__init__.py +0 -0
- chalk/_gen/chalk/dataframe/v1/__init__.py +0 -0
- chalk/_gen/chalk/dataframe/v1/dataframe_pb2.py +48 -0
- chalk/_gen/chalk/dataframe/v1/dataframe_pb2.pyi +123 -0
- chalk/_gen/chalk/dataframe/v1/dataframe_pb2_grpc.py +4 -0
- chalk/_gen/chalk/dataframe/v1/dataframe_pb2_grpc.pyi +4 -0
- chalk/_gen/chalk/graph/v1/graph_pb2.py +150 -149
- chalk/_gen/chalk/graph/v1/graph_pb2.pyi +25 -0
- chalk/_gen/chalk/graph/v1/sources_pb2.py +94 -84
- chalk/_gen/chalk/graph/v1/sources_pb2.pyi +56 -0
- chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2.py +79 -0
- chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2.pyi +377 -0
- chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2_grpc.py +4 -0
- chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2_grpc.pyi +4 -0
- chalk/_gen/chalk/kubernetes/v1/scaledobject_pb2.py +43 -7
- chalk/_gen/chalk/kubernetes/v1/scaledobject_pb2.pyi +252 -2
- chalk/_gen/chalk/protosql/v1/sql_service_pb2.py +54 -27
- chalk/_gen/chalk/protosql/v1/sql_service_pb2.pyi +131 -3
- chalk/_gen/chalk/protosql/v1/sql_service_pb2_grpc.py +45 -0
- chalk/_gen/chalk/protosql/v1/sql_service_pb2_grpc.pyi +14 -0
- chalk/_gen/chalk/python/v1/types_pb2.py +14 -14
- chalk/_gen/chalk/python/v1/types_pb2.pyi +8 -0
- chalk/_gen/chalk/server/v1/benchmark_pb2.py +76 -0
- chalk/_gen/chalk/server/v1/benchmark_pb2.pyi +156 -0
- chalk/_gen/chalk/server/v1/benchmark_pb2_grpc.py +258 -0
- chalk/_gen/chalk/server/v1/benchmark_pb2_grpc.pyi +84 -0
- chalk/_gen/chalk/server/v1/billing_pb2.py +40 -38
- chalk/_gen/chalk/server/v1/billing_pb2.pyi +17 -1
- chalk/_gen/chalk/server/v1/branches_pb2.py +45 -0
- chalk/_gen/chalk/server/v1/branches_pb2.pyi +80 -0
- chalk/_gen/chalk/server/v1/branches_pb2_grpc.pyi +36 -0
- chalk/_gen/chalk/server/v1/builder_pb2.py +372 -272
- chalk/_gen/chalk/server/v1/builder_pb2.pyi +479 -12
- chalk/_gen/chalk/server/v1/builder_pb2_grpc.py +360 -0
- chalk/_gen/chalk/server/v1/builder_pb2_grpc.pyi +96 -0
- chalk/_gen/chalk/server/v1/chart_pb2.py +10 -10
- chalk/_gen/chalk/server/v1/chart_pb2.pyi +18 -2
- chalk/_gen/chalk/server/v1/clickhouse_pb2.py +42 -0
- chalk/_gen/chalk/server/v1/clickhouse_pb2.pyi +17 -0
- chalk/_gen/chalk/server/v1/clickhouse_pb2_grpc.py +78 -0
- chalk/_gen/chalk/server/v1/clickhouse_pb2_grpc.pyi +38 -0
- chalk/_gen/chalk/server/v1/cloud_components_pb2.py +153 -107
- chalk/_gen/chalk/server/v1/cloud_components_pb2.pyi +146 -4
- chalk/_gen/chalk/server/v1/cloud_components_pb2_grpc.py +180 -0
- chalk/_gen/chalk/server/v1/cloud_components_pb2_grpc.pyi +48 -0
- chalk/_gen/chalk/server/v1/cloud_credentials_pb2.py +11 -3
- chalk/_gen/chalk/server/v1/cloud_credentials_pb2.pyi +20 -0
- chalk/_gen/chalk/server/v1/cloud_credentials_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/cloud_credentials_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2.py +59 -35
- chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2.pyi +127 -1
- chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2_grpc.py +135 -0
- chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2_grpc.pyi +36 -0
- chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2.py +90 -0
- chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2.pyi +264 -0
- chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2_grpc.py +170 -0
- chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2_grpc.pyi +62 -0
- chalk/_gen/chalk/server/v1/datasets_pb2.py +36 -24
- chalk/_gen/chalk/server/v1/datasets_pb2.pyi +71 -2
- chalk/_gen/chalk/server/v1/datasets_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/datasets_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/deploy_pb2.py +9 -3
- chalk/_gen/chalk/server/v1/deploy_pb2.pyi +12 -0
- chalk/_gen/chalk/server/v1/deploy_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/deploy_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/deployment_pb2.py +20 -15
- chalk/_gen/chalk/server/v1/deployment_pb2.pyi +25 -0
- chalk/_gen/chalk/server/v1/environment_pb2.py +25 -15
- chalk/_gen/chalk/server/v1/environment_pb2.pyi +93 -1
- chalk/_gen/chalk/server/v1/eventbus_pb2.py +44 -0
- chalk/_gen/chalk/server/v1/eventbus_pb2.pyi +64 -0
- chalk/_gen/chalk/server/v1/eventbus_pb2_grpc.py +4 -0
- chalk/_gen/chalk/server/v1/eventbus_pb2_grpc.pyi +4 -0
- chalk/_gen/chalk/server/v1/files_pb2.py +65 -0
- chalk/_gen/chalk/server/v1/files_pb2.pyi +167 -0
- chalk/_gen/chalk/server/v1/files_pb2_grpc.py +4 -0
- chalk/_gen/chalk/server/v1/files_pb2_grpc.pyi +4 -0
- chalk/_gen/chalk/server/v1/graph_pb2.py +41 -3
- chalk/_gen/chalk/server/v1/graph_pb2.pyi +191 -0
- chalk/_gen/chalk/server/v1/graph_pb2_grpc.py +92 -0
- chalk/_gen/chalk/server/v1/graph_pb2_grpc.pyi +32 -0
- chalk/_gen/chalk/server/v1/incident_pb2.py +57 -0
- chalk/_gen/chalk/server/v1/incident_pb2.pyi +165 -0
- chalk/_gen/chalk/server/v1/incident_pb2_grpc.py +4 -0
- chalk/_gen/chalk/server/v1/incident_pb2_grpc.pyi +4 -0
- chalk/_gen/chalk/server/v1/indexing_job_pb2.py +44 -0
- chalk/_gen/chalk/server/v1/indexing_job_pb2.pyi +38 -0
- chalk/_gen/chalk/server/v1/indexing_job_pb2_grpc.py +78 -0
- chalk/_gen/chalk/server/v1/indexing_job_pb2_grpc.pyi +38 -0
- chalk/_gen/chalk/server/v1/integrations_pb2.py +11 -9
- chalk/_gen/chalk/server/v1/integrations_pb2.pyi +34 -2
- chalk/_gen/chalk/server/v1/kube_pb2.py +29 -19
- chalk/_gen/chalk/server/v1/kube_pb2.pyi +28 -0
- chalk/_gen/chalk/server/v1/kube_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/kube_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/log_pb2.py +21 -3
- chalk/_gen/chalk/server/v1/log_pb2.pyi +68 -0
- chalk/_gen/chalk/server/v1/log_pb2_grpc.py +90 -0
- chalk/_gen/chalk/server/v1/log_pb2_grpc.pyi +24 -0
- chalk/_gen/chalk/server/v1/metadataplanejobqueue_pb2.py +73 -0
- chalk/_gen/chalk/server/v1/metadataplanejobqueue_pb2.pyi +212 -0
- chalk/_gen/chalk/server/v1/metadataplanejobqueue_pb2_grpc.py +217 -0
- chalk/_gen/chalk/server/v1/metadataplanejobqueue_pb2_grpc.pyi +74 -0
- chalk/_gen/chalk/server/v1/model_registry_pb2.py +10 -10
- chalk/_gen/chalk/server/v1/model_registry_pb2.pyi +4 -1
- chalk/_gen/chalk/server/v1/monitoring_pb2.py +84 -75
- chalk/_gen/chalk/server/v1/monitoring_pb2.pyi +1 -0
- chalk/_gen/chalk/server/v1/monitoring_pb2_grpc.py +136 -0
- chalk/_gen/chalk/server/v1/monitoring_pb2_grpc.pyi +38 -0
- chalk/_gen/chalk/server/v1/offline_queries_pb2.py +32 -10
- chalk/_gen/chalk/server/v1/offline_queries_pb2.pyi +73 -0
- chalk/_gen/chalk/server/v1/offline_queries_pb2_grpc.py +90 -0
- chalk/_gen/chalk/server/v1/offline_queries_pb2_grpc.pyi +24 -0
- chalk/_gen/chalk/server/v1/plandebug_pb2.py +53 -0
- chalk/_gen/chalk/server/v1/plandebug_pb2.pyi +86 -0
- chalk/_gen/chalk/server/v1/plandebug_pb2_grpc.py +168 -0
- chalk/_gen/chalk/server/v1/plandebug_pb2_grpc.pyi +60 -0
- chalk/_gen/chalk/server/v1/queries_pb2.py +76 -48
- chalk/_gen/chalk/server/v1/queries_pb2.pyi +155 -2
- chalk/_gen/chalk/server/v1/queries_pb2_grpc.py +180 -0
- chalk/_gen/chalk/server/v1/queries_pb2_grpc.pyi +48 -0
- chalk/_gen/chalk/server/v1/scheduled_query_pb2.py +4 -2
- chalk/_gen/chalk/server/v1/scheduled_query_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/scheduled_query_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/scheduled_query_run_pb2.py +12 -6
- chalk/_gen/chalk/server/v1/scheduled_query_run_pb2.pyi +75 -2
- chalk/_gen/chalk/server/v1/scheduler_pb2.py +24 -12
- chalk/_gen/chalk/server/v1/scheduler_pb2.pyi +61 -1
- chalk/_gen/chalk/server/v1/scheduler_pb2_grpc.py +90 -0
- chalk/_gen/chalk/server/v1/scheduler_pb2_grpc.pyi +24 -0
- chalk/_gen/chalk/server/v1/script_tasks_pb2.py +26 -14
- chalk/_gen/chalk/server/v1/script_tasks_pb2.pyi +33 -3
- chalk/_gen/chalk/server/v1/script_tasks_pb2_grpc.py +90 -0
- chalk/_gen/chalk/server/v1/script_tasks_pb2_grpc.pyi +24 -0
- chalk/_gen/chalk/server/v1/sql_interface_pb2.py +75 -0
- chalk/_gen/chalk/server/v1/sql_interface_pb2.pyi +142 -0
- chalk/_gen/chalk/server/v1/sql_interface_pb2_grpc.py +349 -0
- chalk/_gen/chalk/server/v1/sql_interface_pb2_grpc.pyi +114 -0
- chalk/_gen/chalk/server/v1/sql_queries_pb2.py +48 -0
- chalk/_gen/chalk/server/v1/sql_queries_pb2.pyi +150 -0
- chalk/_gen/chalk/server/v1/sql_queries_pb2_grpc.py +123 -0
- chalk/_gen/chalk/server/v1/sql_queries_pb2_grpc.pyi +52 -0
- chalk/_gen/chalk/server/v1/team_pb2.py +156 -137
- chalk/_gen/chalk/server/v1/team_pb2.pyi +56 -10
- chalk/_gen/chalk/server/v1/team_pb2_grpc.py +90 -0
- chalk/_gen/chalk/server/v1/team_pb2_grpc.pyi +24 -0
- chalk/_gen/chalk/server/v1/topic_pb2.py +5 -3
- chalk/_gen/chalk/server/v1/topic_pb2.pyi +10 -1
- chalk/_gen/chalk/server/v1/trace_pb2.py +50 -28
- chalk/_gen/chalk/server/v1/trace_pb2.pyi +121 -0
- chalk/_gen/chalk/server/v1/trace_pb2_grpc.py +135 -0
- chalk/_gen/chalk/server/v1/trace_pb2_grpc.pyi +42 -0
- chalk/_gen/chalk/server/v1/webhook_pb2.py +9 -3
- chalk/_gen/chalk/server/v1/webhook_pb2.pyi +18 -0
- chalk/_gen/chalk/server/v1/webhook_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/webhook_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/streaming/v1/debug_service_pb2.py +62 -0
- chalk/_gen/chalk/streaming/v1/debug_service_pb2.pyi +75 -0
- chalk/_gen/chalk/streaming/v1/debug_service_pb2_grpc.py +221 -0
- chalk/_gen/chalk/streaming/v1/debug_service_pb2_grpc.pyi +88 -0
- chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2.py +19 -7
- chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2.pyi +96 -3
- chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2_grpc.py +48 -0
- chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2_grpc.pyi +20 -0
- chalk/_gen/chalk/utils/v1/field_change_pb2.py +32 -0
- chalk/_gen/chalk/utils/v1/field_change_pb2.pyi +42 -0
- chalk/_gen/chalk/utils/v1/field_change_pb2_grpc.py +4 -0
- chalk/_gen/chalk/utils/v1/field_change_pb2_grpc.pyi +4 -0
- chalk/_lsp/error_builder.py +11 -0
- chalk/_monitoring/Chart.py +1 -3
- chalk/_version.py +1 -1
- chalk/cli.py +5 -10
- chalk/client/client.py +178 -64
- chalk/client/client_async.py +154 -0
- chalk/client/client_async_impl.py +22 -0
- chalk/client/client_grpc.py +738 -112
- chalk/client/client_impl.py +541 -136
- chalk/client/dataset.py +27 -6
- chalk/client/models.py +99 -2
- chalk/client/serialization/model_serialization.py +126 -10
- chalk/config/project_config.py +1 -1
- chalk/df/LazyFramePlaceholder.py +1154 -0
- chalk/df/ast_parser.py +2 -10
- chalk/features/_class_property.py +7 -0
- chalk/features/_embedding/embedding.py +1 -0
- chalk/features/_embedding/sentence_transformer.py +1 -1
- chalk/features/_encoding/converter.py +83 -2
- chalk/features/_encoding/pyarrow.py +20 -4
- chalk/features/_encoding/rich.py +1 -3
- chalk/features/_tensor.py +1 -2
- chalk/features/dataframe/_filters.py +14 -5
- chalk/features/dataframe/_impl.py +91 -36
- chalk/features/dataframe/_validation.py +11 -7
- chalk/features/feature_field.py +40 -30
- chalk/features/feature_set.py +1 -2
- chalk/features/feature_set_decorator.py +1 -0
- chalk/features/feature_wrapper.py +42 -3
- chalk/features/hooks.py +81 -12
- chalk/features/inference.py +65 -10
- chalk/features/resolver.py +338 -56
- chalk/features/tag.py +1 -3
- chalk/features/underscore_features.py +2 -1
- chalk/functions/__init__.py +456 -21
- chalk/functions/holidays.py +1 -3
- chalk/gitignore/gitignore_parser.py +5 -1
- chalk/importer.py +186 -74
- chalk/ml/__init__.py +6 -2
- chalk/ml/model_hooks.py +368 -51
- chalk/ml/model_reference.py +68 -10
- chalk/ml/model_version.py +34 -21
- chalk/ml/utils.py +143 -40
- chalk/operators/_utils.py +14 -3
- chalk/parsed/_proto/export.py +22 -0
- chalk/parsed/duplicate_input_gql.py +4 -0
- chalk/parsed/expressions.py +1 -3
- chalk/parsed/json_conversions.py +21 -14
- chalk/parsed/to_proto.py +16 -4
- chalk/parsed/user_types_to_json.py +31 -10
- chalk/parsed/validation_from_registries.py +182 -0
- chalk/queries/named_query.py +16 -6
- chalk/queries/scheduled_query.py +13 -1
- chalk/serialization/parsed_annotation.py +25 -12
- chalk/sql/__init__.py +221 -0
- chalk/sql/_internal/integrations/athena.py +6 -1
- chalk/sql/_internal/integrations/bigquery.py +22 -2
- chalk/sql/_internal/integrations/databricks.py +61 -18
- chalk/sql/_internal/integrations/mssql.py +281 -0
- chalk/sql/_internal/integrations/postgres.py +11 -3
- chalk/sql/_internal/integrations/redshift.py +4 -0
- chalk/sql/_internal/integrations/snowflake.py +11 -2
- chalk/sql/_internal/integrations/util.py +2 -1
- chalk/sql/_internal/sql_file_resolver.py +55 -10
- chalk/sql/_internal/sql_source.py +36 -2
- chalk/streams/__init__.py +1 -3
- chalk/streams/_kafka_source.py +5 -1
- chalk/streams/_windows.py +16 -4
- chalk/streams/types.py +1 -2
- chalk/utils/__init__.py +1 -3
- chalk/utils/_otel_version.py +13 -0
- chalk/utils/async_helpers.py +14 -5
- chalk/utils/df_utils.py +2 -2
- chalk/utils/duration.py +1 -3
- chalk/utils/job_log_display.py +538 -0
- chalk/utils/missing_dependency.py +5 -4
- chalk/utils/notebook.py +255 -2
- chalk/utils/pl_helpers.py +190 -37
- chalk/utils/pydanticutil/pydantic_compat.py +1 -2
- chalk/utils/storage_client.py +246 -0
- chalk/utils/threading.py +1 -3
- chalk/utils/tracing.py +194 -86
- {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/METADATA +53 -21
- {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/RECORD +268 -198
- {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/WHEEL +0 -0
- {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/entry_points.txt +0 -0
- {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/top_level.txt +0 -0
chalk/ml/model_version.py
CHANGED
|
@@ -1,10 +1,13 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
from datetime import datetime
|
|
4
|
-
from typing import
|
|
4
|
+
from typing import TYPE_CHECKING, Any
|
|
5
5
|
|
|
6
|
-
from chalk.ml.model_hooks import
|
|
7
|
-
from chalk.ml.utils import ModelEncoding, ModelType
|
|
6
|
+
from chalk.ml.model_hooks import MODEL_REGISTRY
|
|
7
|
+
from chalk.ml.utils import ModelClass, ModelEncoding, ModelType
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from chalk.features.resolver import ResourceHint
|
|
8
11
|
|
|
9
12
|
|
|
10
13
|
class ModelVersion:
|
|
@@ -18,8 +21,10 @@ class ModelVersion:
|
|
|
18
21
|
identifier: str | None = None,
|
|
19
22
|
model_type: ModelType | None = None,
|
|
20
23
|
model_encoding: ModelEncoding | None = None,
|
|
21
|
-
model_class:
|
|
24
|
+
model_class: ModelClass | None = None,
|
|
22
25
|
filename: str | None = None,
|
|
26
|
+
resource_hint: "ResourceHint | None" = None,
|
|
27
|
+
resource_group: str | None = None,
|
|
23
28
|
):
|
|
24
29
|
"""Specifies the model version that should be loaded into the deployment.
|
|
25
30
|
|
|
@@ -41,9 +46,11 @@ class ModelVersion:
|
|
|
41
46
|
self.model_encoding = model_encoding
|
|
42
47
|
self.model_class = model_class
|
|
43
48
|
self.filename = filename
|
|
49
|
+
self.resource_hint: "ResourceHint | None" = resource_hint
|
|
50
|
+
self.resource_group = resource_group
|
|
44
51
|
|
|
45
52
|
self._model = None
|
|
46
|
-
self.
|
|
53
|
+
self._predictor = None
|
|
47
54
|
|
|
48
55
|
def get_model_file(self) -> str | None:
|
|
49
56
|
"""Returns the filename of the model."""
|
|
@@ -54,26 +61,19 @@ class ModelVersion:
|
|
|
54
61
|
def load_model(self):
|
|
55
62
|
"""Loads the model from the specified filename using the appropriate hook."""
|
|
56
63
|
if self.model_type and self.model_encoding:
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
64
|
+
model = MODEL_REGISTRY.get(
|
|
65
|
+
model_type=self.model_type, encoding=self.model_encoding, model_class=self.model_class
|
|
66
|
+
)
|
|
67
|
+
if model is not None and self.filename is not None:
|
|
68
|
+
self._model = model.load_model(self.filename, resource_hint=self.resource_hint)
|
|
60
69
|
else:
|
|
61
70
|
raise ValueError(
|
|
62
|
-
f"No load function defined for type {self.model_type} and
|
|
71
|
+
f"No load function defined for type {self.model_type}, encoding {self.model_encoding}, and class {self.model_class}"
|
|
63
72
|
)
|
|
64
73
|
|
|
65
|
-
def predict(self, X:
|
|
66
|
-
"""
|
|
67
|
-
|
|
68
|
-
if self._predict_fn is None:
|
|
69
|
-
if self.model_type is None or self.model_encoding is None:
|
|
70
|
-
raise ValueError("Model type and encoding must be specified to use predict.")
|
|
71
|
-
self._predict_fn = PREDICT_HOOKS.get((self.model_type, self.model_encoding, self.model_class), None)
|
|
72
|
-
if self._predict_fn is None:
|
|
73
|
-
raise ValueError(
|
|
74
|
-
f"No predict function defined for type {self.model_type} and extension {self.model_encoding}"
|
|
75
|
-
)
|
|
76
|
-
return self._predict_fn(self.model, X)
|
|
74
|
+
def predict(self, X: Any):
|
|
75
|
+
"""Runs prediction using the loaded model."""
|
|
76
|
+
return self.predictor.predict(self.model, X)
|
|
77
77
|
|
|
78
78
|
@property
|
|
79
79
|
def model(self) -> Any:
|
|
@@ -82,3 +82,16 @@ class ModelVersion:
|
|
|
82
82
|
self.load_model()
|
|
83
83
|
|
|
84
84
|
return self._model
|
|
85
|
+
|
|
86
|
+
@property
|
|
87
|
+
def predictor(self) -> Any:
|
|
88
|
+
"""Returns the predictor instance, initializing it if needed."""
|
|
89
|
+
if self._predictor is None:
|
|
90
|
+
if self.model_type is None or self.model_encoding is None:
|
|
91
|
+
raise ValueError("Model type and encoding must be specified to use predictor.")
|
|
92
|
+
self._predictor = MODEL_REGISTRY.get(
|
|
93
|
+
model_type=self.model_type, encoding=self.model_encoding, model_class=self.model_class
|
|
94
|
+
)
|
|
95
|
+
if self._predictor is None:
|
|
96
|
+
raise ValueError(f"No predictor defined for type {self.model_type} and encoding {self.model_encoding}")
|
|
97
|
+
return self._predictor
|
chalk/ml/utils.py
CHANGED
|
@@ -2,23 +2,33 @@ import os
|
|
|
2
2
|
from dataclasses import dataclass
|
|
3
3
|
from enum import Enum
|
|
4
4
|
from functools import cache
|
|
5
|
-
from typing import Literal, Mapping, Tuple
|
|
5
|
+
from typing import Literal, Mapping, Optional, Tuple
|
|
6
6
|
|
|
7
7
|
import pyarrow as pa
|
|
8
8
|
|
|
9
9
|
import chalk._gen.chalk.models.v1.model_artifact_pb2 as pb
|
|
10
10
|
import chalk._gen.chalk.models.v1.model_version_pb2 as mv_pb
|
|
11
11
|
|
|
12
|
-
|
|
12
|
+
|
|
13
|
+
def get_registry_metadata_file() -> Optional[str]:
|
|
14
|
+
branch_root = os.getenv("CHALK_MODEL_REGISTRY_BRANCH_METADATA_ROOT", None)
|
|
15
|
+
if os.getenv("IS_BRANCH", None) is not None and branch_root is not None:
|
|
16
|
+
return os.path.join(branch_root, os.getenv("CHALK_DEPLOYMENT_ID", "") + ".bin")
|
|
17
|
+
return os.getenv("CHALK_MODEL_REGISTRY_METADATA_FILENAME", None)
|
|
18
|
+
|
|
19
|
+
|
|
13
20
|
CHALK_MODEL_REGISTRY_ROOT = os.getenv("CHALK_MODEL_REGISTRY_ROOT", "/models")
|
|
14
21
|
|
|
15
22
|
MODEL_METADATA_PREFIX = "__chalk_model__"
|
|
16
23
|
|
|
17
24
|
MODEL_TRAIN_METADATA_RUN_NAME = f"{MODEL_METADATA_PREFIX}run_name__"
|
|
25
|
+
MODEL_TRAIN_RUN_NAME_ENV_VAR = "CHALK_MODEL_TRAIN_RUN_NAME"
|
|
26
|
+
|
|
27
|
+
MODEL_TRAIN_METADATA_RUN_ID = f"{MODEL_METADATA_PREFIX}run_id__"
|
|
18
28
|
|
|
19
29
|
|
|
20
30
|
def get_model_metadata_run_name_from_env():
|
|
21
|
-
return os.getenv(
|
|
31
|
+
return os.getenv(MODEL_TRAIN_RUN_NAME_ENV_VAR, "")
|
|
22
32
|
|
|
23
33
|
|
|
24
34
|
class ModelType(str, Enum):
|
|
@@ -42,6 +52,14 @@ class ModelEncoding(str, Enum):
|
|
|
42
52
|
SAFETENSOR = "MODEL_ENCODING_SAFETENSORS"
|
|
43
53
|
|
|
44
54
|
|
|
55
|
+
class ModelClass(str, Enum):
|
|
56
|
+
CLASSIFICATION = "classification"
|
|
57
|
+
REGRESSION = "regression"
|
|
58
|
+
CLUSTERING = "clustering"
|
|
59
|
+
DIMENSIONALITY_REDUCTION = "dimensionality_reduction"
|
|
60
|
+
EMBEDDING = "embedding"
|
|
61
|
+
|
|
62
|
+
|
|
45
63
|
@dataclass
|
|
46
64
|
class ModelRunCriterion:
|
|
47
65
|
direction: Literal["max", "min"]
|
|
@@ -66,11 +84,12 @@ def load_model_map() -> Mapping[Tuple[str, str], LoadedModel]:
|
|
|
66
84
|
model_map: dict[Tuple[str, str], LoadedModel] = {}
|
|
67
85
|
|
|
68
86
|
try:
|
|
69
|
-
|
|
70
|
-
|
|
87
|
+
registry_metadata_file = get_registry_metadata_file()
|
|
88
|
+
if registry_metadata_file is not None:
|
|
89
|
+
with open(registry_metadata_file, "rb") as f:
|
|
71
90
|
mms.ParseFromString(f.read())
|
|
72
91
|
except FileNotFoundError:
|
|
73
|
-
raise FileNotFoundError(f"Model registry metadata file not found: {
|
|
92
|
+
raise FileNotFoundError(f"Model registry metadata file not found: {registry_metadata_file}")
|
|
74
93
|
except Exception as e:
|
|
75
94
|
raise RuntimeError(f"Failed to load model map: {e}")
|
|
76
95
|
|
|
@@ -309,13 +328,104 @@ class ModelAttributeExtractor:
|
|
|
309
328
|
return input_schema, output_schema
|
|
310
329
|
|
|
311
330
|
@staticmethod
|
|
312
|
-
def
|
|
331
|
+
def infer_catboost_schemas(
|
|
332
|
+
model: Any,
|
|
333
|
+
) -> Tuple[Optional[List[Tuple[List[int], Any]]], Optional[List[Tuple[List[int], Any]]]]:
|
|
334
|
+
input_schema: Optional[List[Tuple[List[int], Any]]] = None
|
|
335
|
+
output_schema: Optional[List[Tuple[List[int], Any]]] = None
|
|
336
|
+
|
|
337
|
+
try:
|
|
338
|
+
n_features = None
|
|
339
|
+
|
|
340
|
+
# CatBoost uses feature_names_ or can query from get_feature_importance
|
|
341
|
+
if hasattr(model, "feature_names_") and model.feature_names_ is not None:
|
|
342
|
+
n_features = len(model.feature_names_)
|
|
343
|
+
elif hasattr(model, "n_features_in_"):
|
|
344
|
+
n_features = model.n_features_in_
|
|
345
|
+
elif hasattr(model, "get_feature_importance"):
|
|
346
|
+
# Try to get feature count from the model's tree structure
|
|
347
|
+
try:
|
|
348
|
+
feature_importances = model.get_feature_importance()
|
|
349
|
+
if feature_importances is not None:
|
|
350
|
+
n_features = len(feature_importances)
|
|
351
|
+
except Exception:
|
|
352
|
+
pass
|
|
353
|
+
|
|
354
|
+
if n_features is not None:
|
|
355
|
+
input_schema = [([n_features], pa.float64())]
|
|
356
|
+
|
|
357
|
+
# Determine output schema based on model type
|
|
358
|
+
# CatBoost has is_fitted() and can check the model type
|
|
359
|
+
if hasattr(model, "_estimator_type"):
|
|
360
|
+
if model._estimator_type == "classifier":
|
|
361
|
+
n_classes = None
|
|
362
|
+
if hasattr(model, "classes_") and model.classes_ is not None:
|
|
363
|
+
n_classes = len(model.classes_)
|
|
364
|
+
|
|
365
|
+
if n_classes is not None:
|
|
366
|
+
if n_classes == 2:
|
|
367
|
+
output_schema = [([1], pa.float64())]
|
|
368
|
+
else:
|
|
369
|
+
output_schema = [([n_classes], pa.float64())]
|
|
370
|
+
else:
|
|
371
|
+
output_schema = [([1], pa.float64())]
|
|
372
|
+
|
|
373
|
+
elif model._estimator_type == "regressor":
|
|
374
|
+
output_schema = [([1], pa.float64())]
|
|
375
|
+
else:
|
|
376
|
+
# Check class name as fallback
|
|
377
|
+
class_name = model.__class__.__name__
|
|
378
|
+
if "Classifier" in class_name:
|
|
379
|
+
n_classes = None
|
|
380
|
+
if hasattr(model, "classes_") and model.classes_ is not None:
|
|
381
|
+
n_classes = len(model.classes_)
|
|
382
|
+
|
|
383
|
+
if n_classes is not None:
|
|
384
|
+
if n_classes == 2:
|
|
385
|
+
output_schema = [([1], pa.float64())]
|
|
386
|
+
else:
|
|
387
|
+
output_schema = [([n_classes], pa.float64())]
|
|
388
|
+
else:
|
|
389
|
+
output_schema = [([1], pa.float64())]
|
|
390
|
+
elif "Regressor" in class_name:
|
|
391
|
+
output_schema = [([1], pa.float64())]
|
|
392
|
+
else:
|
|
393
|
+
# Default to single output
|
|
394
|
+
output_schema = [([1], pa.float64())]
|
|
395
|
+
|
|
396
|
+
except Exception:
|
|
397
|
+
pass
|
|
398
|
+
|
|
399
|
+
return input_schema, output_schema
|
|
400
|
+
|
|
401
|
+
@staticmethod
|
|
402
|
+
def infer_model_type(model: Any) -> Tuple[Optional[ModelType], Optional[ModelClass]]:
|
|
403
|
+
# ONNX - check early since ONNX models are commonly wrapped
|
|
404
|
+
try:
|
|
405
|
+
import onnx # pyright: ignore[reportMissingImports]
|
|
406
|
+
|
|
407
|
+
if isinstance(model, onnx.ModelProto):
|
|
408
|
+
return ModelType.ONNX, None
|
|
409
|
+
# Check if model has a wrapped ONNX ModelProto (e.g., model._model)
|
|
410
|
+
if hasattr(model, "_model") and isinstance(model._model, onnx.ModelProto):
|
|
411
|
+
return ModelType.ONNX, None
|
|
412
|
+
except ImportError:
|
|
413
|
+
pass
|
|
414
|
+
|
|
415
|
+
try:
|
|
416
|
+
import onnxruntime # pyright: ignore[reportMissingImports]
|
|
417
|
+
|
|
418
|
+
if isinstance(model, onnxruntime.InferenceSession):
|
|
419
|
+
return ModelType.ONNX, None
|
|
420
|
+
except ImportError:
|
|
421
|
+
pass
|
|
422
|
+
|
|
313
423
|
# PYTORCH
|
|
314
424
|
try:
|
|
315
425
|
import torch.nn as nn # pyright: ignore[reportMissingImports]
|
|
316
426
|
|
|
317
427
|
if isinstance(model, nn.Module):
|
|
318
|
-
return ModelType.PYTORCH
|
|
428
|
+
return ModelType.PYTORCH, None
|
|
319
429
|
except ImportError:
|
|
320
430
|
pass
|
|
321
431
|
|
|
@@ -323,11 +433,16 @@ class ModelAttributeExtractor:
|
|
|
323
433
|
try:
|
|
324
434
|
import xgboost as xgb # pyright: ignore[reportMissingImports]
|
|
325
435
|
|
|
436
|
+
if isinstance(model, xgb.XGBClassifier):
|
|
437
|
+
return ModelType.XGBOOST, ModelClass.CLASSIFICATION
|
|
438
|
+
if isinstance(model, xgb.XGBRegressor):
|
|
439
|
+
return ModelType.XGBOOST, ModelClass.REGRESSION
|
|
440
|
+
|
|
326
441
|
if isinstance(model, (xgb.XGBModel, xgb.Booster)):
|
|
327
|
-
return ModelType.XGBOOST
|
|
442
|
+
return ModelType.XGBOOST, None
|
|
328
443
|
# Also check for XGBoost sklearn API
|
|
329
444
|
if hasattr(model, "__class__") and "xgboost" in model.__class__.__module__:
|
|
330
|
-
return ModelType.XGBOOST
|
|
445
|
+
return ModelType.XGBOOST, None
|
|
331
446
|
except ImportError:
|
|
332
447
|
pass
|
|
333
448
|
|
|
@@ -336,9 +451,9 @@ class ModelAttributeExtractor:
|
|
|
336
451
|
import lightgbm as lgb # pyright: ignore[reportMissingImports]
|
|
337
452
|
|
|
338
453
|
if isinstance(model, (lgb.LGBMModel, lgb.Booster)):
|
|
339
|
-
return ModelType.LIGHTGBM
|
|
454
|
+
return ModelType.LIGHTGBM, None
|
|
340
455
|
if hasattr(model, "__class__") and "lightgbm" in model.__class__.__module__:
|
|
341
|
-
return ModelType.LIGHTGBM
|
|
456
|
+
return ModelType.LIGHTGBM, None
|
|
342
457
|
except ImportError:
|
|
343
458
|
pass
|
|
344
459
|
|
|
@@ -346,32 +461,20 @@ class ModelAttributeExtractor:
|
|
|
346
461
|
try:
|
|
347
462
|
import catboost as cb # pyright: ignore[reportMissingImports]
|
|
348
463
|
|
|
349
|
-
# CatBoost
|
|
350
|
-
if hasattr(model, "__class__") and "catboost" in model.__class__.__module__:
|
|
351
|
-
return ModelType.CATBOOST
|
|
352
|
-
# Common CatBoost classes
|
|
464
|
+
# Common CatBoost classes - check specific types first
|
|
353
465
|
try:
|
|
354
|
-
if isinstance(model,
|
|
355
|
-
return ModelType.CATBOOST
|
|
466
|
+
if isinstance(model, cb.CatBoostClassifier):
|
|
467
|
+
return ModelType.CATBOOST, ModelClass.CLASSIFICATION
|
|
468
|
+
if isinstance(model, cb.CatBoostRegressor):
|
|
469
|
+
return ModelType.CATBOOST, ModelClass.REGRESSION
|
|
470
|
+
|
|
471
|
+
if isinstance(model, (cb.CatBoost)):
|
|
472
|
+
return ModelType.CATBOOST, None
|
|
356
473
|
except (AttributeError, NameError):
|
|
357
474
|
pass
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
# ONNX - do we need this one?
|
|
362
|
-
try:
|
|
363
|
-
import onnx # pyright: ignore[reportMissingImports]
|
|
364
|
-
|
|
365
|
-
if isinstance(model, onnx.ModelProto):
|
|
366
|
-
return ModelType.ONNX
|
|
367
|
-
except ImportError:
|
|
368
|
-
pass
|
|
369
|
-
|
|
370
|
-
try:
|
|
371
|
-
import onnxruntime # pyright: ignore[reportMissingImports]
|
|
372
|
-
|
|
373
|
-
if isinstance(model, onnxruntime.InferenceSession):
|
|
374
|
-
return ModelType.ONNX
|
|
475
|
+
# CatBoost has various model classes - generic fallback
|
|
476
|
+
if hasattr(model, "__class__") and "catboost" in model.__class__.__module__:
|
|
477
|
+
return ModelType.CATBOOST, None
|
|
375
478
|
except ImportError:
|
|
376
479
|
pass
|
|
377
480
|
|
|
@@ -380,10 +483,10 @@ class ModelAttributeExtractor:
|
|
|
380
483
|
import sklearn.base # pyright: ignore[reportMissingImports]
|
|
381
484
|
|
|
382
485
|
if isinstance(model, sklearn.base.BaseEstimator):
|
|
383
|
-
return ModelType.SKLEARN
|
|
486
|
+
return ModelType.SKLEARN, None
|
|
384
487
|
|
|
385
488
|
if hasattr(model, "__class__") and "sklearn" in model.__class__.__module__:
|
|
386
|
-
return ModelType.SKLEARN
|
|
489
|
+
return ModelType.SKLEARN, None
|
|
387
490
|
except ImportError:
|
|
388
491
|
pass
|
|
389
492
|
|
|
@@ -392,10 +495,10 @@ class ModelAttributeExtractor:
|
|
|
392
495
|
import tensorflow as tf # pyright: ignore[reportMissingImports]
|
|
393
496
|
|
|
394
497
|
if isinstance(model, tf.keras.Model):
|
|
395
|
-
return ModelType.TENSORFLOW
|
|
498
|
+
return ModelType.TENSORFLOW, None
|
|
396
499
|
if hasattr(model, "__class__") and "tensorflow" in model.__class__.__module__:
|
|
397
|
-
return ModelType.TENSORFLOW
|
|
500
|
+
return ModelType.TENSORFLOW, None
|
|
398
501
|
except ImportError:
|
|
399
502
|
pass
|
|
400
503
|
|
|
401
|
-
return None
|
|
504
|
+
return None, None
|
chalk/operators/_utils.py
CHANGED
|
@@ -9,6 +9,7 @@ import pyarrow
|
|
|
9
9
|
from chalk import DataFrame, Features, StaticOperator
|
|
10
10
|
from chalk._gen.chalk.expression.v1 import expression_pb2 as expr_pb
|
|
11
11
|
from chalk.client import ChalkError, ChalkException, ErrorCode, ErrorCodeCategory
|
|
12
|
+
from chalk.df.LazyFramePlaceholder import LazyFramePlaceholder
|
|
12
13
|
from chalk.features.feature_field import Feature
|
|
13
14
|
|
|
14
15
|
|
|
@@ -79,7 +80,7 @@ def static_resolver_to_operator(
|
|
|
79
80
|
fn: Callable,
|
|
80
81
|
inputs: Sequence[Union[Feature, type[DataFrame]]],
|
|
81
82
|
output: Optional[type[Features]],
|
|
82
|
-
) -> StaticOperator | DfPlaceholder | ChalkDataFrame:
|
|
83
|
+
) -> StaticOperator | DfPlaceholder | ChalkDataFrame | LazyFramePlaceholder:
|
|
83
84
|
if output is None:
|
|
84
85
|
raise _GetStaticOperatorError(
|
|
85
86
|
resolver_fqn=fqn,
|
|
@@ -96,8 +97,14 @@ def static_resolver_to_operator(
|
|
|
96
97
|
message="Static resolver must take no arguments and have exactly one DataFrame output",
|
|
97
98
|
underlying_exception=None,
|
|
98
99
|
)
|
|
100
|
+
|
|
99
101
|
try:
|
|
100
|
-
placeholder_inputs = [
|
|
102
|
+
placeholder_inputs = [
|
|
103
|
+
LazyFramePlaceholder.named_table(
|
|
104
|
+
name=f"resolver_df_input_{input_index}", schema=pyarrow.schema(schema_for_input(input_type))
|
|
105
|
+
)
|
|
106
|
+
for input_index, input_type in enumerate(inputs)
|
|
107
|
+
]
|
|
101
108
|
static_operator = fn(*placeholder_inputs)
|
|
102
109
|
except Exception as e:
|
|
103
110
|
# Weird hacky way to return a placeholder even if the resolver fails.
|
|
@@ -108,9 +115,13 @@ def static_resolver_to_operator(
|
|
|
108
115
|
)
|
|
109
116
|
else:
|
|
110
117
|
if (
|
|
111
|
-
not isinstance(static_operator, (StaticOperator, DfPlaceholder))
|
|
118
|
+
not isinstance(static_operator, (StaticOperator, DfPlaceholder, LazyFramePlaceholder))
|
|
112
119
|
and not static_operator.__class__.__name__ == "ChalkDataFrame"
|
|
113
120
|
and not static_operator.__class__.__name__ == "LazyFrame"
|
|
121
|
+
and not (
|
|
122
|
+
static_operator.__class__.__name__ == "DataFrame"
|
|
123
|
+
and static_operator.__class__.__module__ == "chalkdf.dataframe"
|
|
124
|
+
)
|
|
114
125
|
):
|
|
115
126
|
raise _GetStaticOperatorError(
|
|
116
127
|
resolver_fqn=fqn,
|
chalk/parsed/_proto/export.py
CHANGED
|
@@ -27,6 +27,7 @@ from chalk.parsed._proto.utils import (
|
|
|
27
27
|
convert_failed_import_to_gql,
|
|
28
28
|
convert_failed_import_to_proto,
|
|
29
29
|
datetime_to_proto_timestamp,
|
|
30
|
+
timedelta_to_proto_duration,
|
|
30
31
|
)
|
|
31
32
|
from chalk.parsed._proto.validation import validate_artifacts
|
|
32
33
|
from chalk.parsed.to_proto import ToProtoConverter
|
|
@@ -145,6 +146,24 @@ def export_from_registry() -> export_pb.Export:
|
|
|
145
146
|
"""
|
|
146
147
|
failed_protos: List[export_pb.FailedImport] = []
|
|
147
148
|
|
|
149
|
+
# Validate registries BEFORE conversion to catch errors early
|
|
150
|
+
# This ensures parity with GQL validation path
|
|
151
|
+
from chalk.parsed.validation_from_registries import validate_all_from_registries
|
|
152
|
+
|
|
153
|
+
try:
|
|
154
|
+
validate_all_from_registries(
|
|
155
|
+
features_registry=FeatureSetBase.registry,
|
|
156
|
+
resolver_registry=RESOLVER_REGISTRY,
|
|
157
|
+
)
|
|
158
|
+
except Exception as e:
|
|
159
|
+
# If validation fails, add to failed_protos but continue
|
|
160
|
+
# to allow other validation to complete
|
|
161
|
+
from chalk._lsp.error_builder import LSPErrorBuilder
|
|
162
|
+
|
|
163
|
+
if not LSPErrorBuilder.promote_exception(e):
|
|
164
|
+
# Not an LSP error, so log it as a failed import
|
|
165
|
+
failed_protos.append(build_failed_import(e, "validation"))
|
|
166
|
+
|
|
148
167
|
graph_res = ToProtoConverter.convert_graph(
|
|
149
168
|
features_registry=FeatureSetBase.registry,
|
|
150
169
|
resolver_registry=RESOLVER_REGISTRY.get_all_resolvers(),
|
|
@@ -193,6 +212,9 @@ def export_from_registry() -> export_pb.Export:
|
|
|
193
212
|
file_name=cron.filename,
|
|
194
213
|
resource_group=cron.resource_group,
|
|
195
214
|
planner_options=cron.planner_options,
|
|
215
|
+
completion_deadline=timedelta_to_proto_duration(cron.completion_deadline)
|
|
216
|
+
if cron.completion_deadline is not None
|
|
217
|
+
else cron.completion_deadline,
|
|
196
218
|
)
|
|
197
219
|
)
|
|
198
220
|
|
|
@@ -271,6 +271,7 @@ class UpsertCronQueryGQL:
|
|
|
271
271
|
upperBound: Optional[datetime] # deprecated: can't use datetime
|
|
272
272
|
tags: Optional[List[str]]
|
|
273
273
|
requiredResolverTags: Optional[List[str]]
|
|
274
|
+
datasetName: Optional[str] = None
|
|
274
275
|
storeOnline: Optional[bool] = True # None = True
|
|
275
276
|
storeOffline: Optional[bool] = True # None = True
|
|
276
277
|
incrementalSources: Optional[List[str]] = None
|
|
@@ -278,6 +279,9 @@ class UpsertCronQueryGQL:
|
|
|
278
279
|
upperBoundStr: Optional[str] = None
|
|
279
280
|
resourceGroup: Optional[str] = None
|
|
280
281
|
plannerOptions: Optional[Dict[str, str]] = None
|
|
282
|
+
completionDeadline: Optional[str] = None
|
|
283
|
+
numShards: Optional[int] = None
|
|
284
|
+
numWorkers: Optional[int] = None
|
|
281
285
|
|
|
282
286
|
|
|
283
287
|
@dataclasses_json.dataclass_json
|
chalk/parsed/expressions.py
CHANGED
chalk/parsed/json_conversions.py
CHANGED
|
@@ -418,6 +418,7 @@ def convert_type_to_gql(
|
|
|
418
418
|
),
|
|
419
419
|
lowerBound=None,
|
|
420
420
|
upperBound=None,
|
|
421
|
+
datasetName=t.dataset_name,
|
|
421
422
|
lowerBoundStr=datetime.isoformat(t.lower_bound) if t.lower_bound is not None else None,
|
|
422
423
|
upperBoundStr=datetime.isoformat(t.upper_bound) if t.upper_bound is not None else None,
|
|
423
424
|
tags=list(t.tags) if t.tags is not None else None,
|
|
@@ -427,6 +428,9 @@ def convert_type_to_gql(
|
|
|
427
428
|
incrementalSources=None if t.incremental_resolvers is None else list(t.incremental_resolvers),
|
|
428
429
|
resourceGroup=t.resource_group,
|
|
429
430
|
plannerOptions=t.planner_options,
|
|
431
|
+
completionDeadline=None if t.completion_deadline is None else timedelta_to_duration(t.completion_deadline),
|
|
432
|
+
numShards=t.num_shards,
|
|
433
|
+
numWorkers=t.num_workers,
|
|
430
434
|
)
|
|
431
435
|
|
|
432
436
|
if isinstance(t, NamedQuery):
|
|
@@ -538,20 +542,23 @@ def convert_type_to_gql(
|
|
|
538
542
|
)
|
|
539
543
|
|
|
540
544
|
elif t.join is not None:
|
|
541
|
-
#
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
545
|
+
# Check if user tried to use DataFrame (even if validation failed)
|
|
546
|
+
# Use is_dataframe_annotation() to detect DataFrame types without triggering validation errors
|
|
547
|
+
if not t.typ.is_dataframe_annotation():
|
|
548
|
+
# If a has_one/has_many has an incorrect type annotation
|
|
549
|
+
builder = t.features_cls.__chalk_error_builder__
|
|
550
|
+
builder.add_diagnostic(
|
|
551
|
+
range=builder.annotation_range(t.attribute_name),
|
|
552
|
+
message=(
|
|
553
|
+
f"The attribute '{t.features_cls.__name__}.{t.attribute_name}' "
|
|
554
|
+
f"has a join filter ({t.join}) but its type annotation is not a feature class or "
|
|
555
|
+
f"DataFrame ({t.typ})."
|
|
556
|
+
),
|
|
557
|
+
label="Incorrect join type annotation",
|
|
558
|
+
raise_error=TypeError,
|
|
559
|
+
code="34",
|
|
560
|
+
code_href="https://docs.chalk.ai/docs/has-many",
|
|
561
|
+
)
|
|
555
562
|
|
|
556
563
|
elif t.is_feature_time:
|
|
557
564
|
feature_time_kind_gql = UpsertFeatureTimeKindGQL()
|
chalk/parsed/to_proto.py
CHANGED
|
@@ -21,6 +21,7 @@ from chalk._gen.chalk.graph.v1 import graph_pb2 as pb
|
|
|
21
21
|
from chalk._gen.chalk.graph.v2 import sources_pb2 as sources_pb
|
|
22
22
|
from chalk._gen.chalk.lsp.v1.lsp_pb2 import Location, Position, Range
|
|
23
23
|
from chalk._validation.feature_validation import FeatureValidation
|
|
24
|
+
from chalk.df.LazyFramePlaceholder import LazyFramePlaceholder
|
|
24
25
|
from chalk.features import (
|
|
25
26
|
CacheStrategy,
|
|
26
27
|
Feature,
|
|
@@ -899,7 +900,7 @@ class ToProtoConverter:
|
|
|
899
900
|
else None,
|
|
900
901
|
backfill_schedule=mat.backfill_schedule,
|
|
901
902
|
approx_top_k_arg_k=aggregation_kwargs.get("k")
|
|
902
|
-
if mat.aggregation in ("approx_top_k", "min_by_n", "max_by_n")
|
|
903
|
+
if mat.aggregation in ("approx_top_k", "approx_percentile", "min_by_n", "max_by_n")
|
|
903
904
|
else None,
|
|
904
905
|
),
|
|
905
906
|
tags=f.tags,
|
|
@@ -995,7 +996,7 @@ class ToProtoConverter:
|
|
|
995
996
|
else None,
|
|
996
997
|
continuous_resolver=wmp.continuous_resolver,
|
|
997
998
|
approx_top_k_arg_k=aggregation_kwargs.get("k")
|
|
998
|
-
if wmp.aggregation in ("approx_top_k", "min_by_n", "max_by_n")
|
|
999
|
+
if wmp.aggregation in ("approx_top_k", "approx_percentile", "min_by_n", "max_by_n")
|
|
999
1000
|
else None,
|
|
1000
1001
|
)
|
|
1001
1002
|
if wmp is not None
|
|
@@ -1025,6 +1026,9 @@ class ToProtoConverter:
|
|
|
1025
1026
|
expression=ToProtoConverter.convert_underscore(f.underscore_expression)
|
|
1026
1027
|
if f.underscore_expression is not None
|
|
1027
1028
|
else None,
|
|
1029
|
+
offline_expression=ToProtoConverter.convert_underscore(f.offline_underscore_expression)
|
|
1030
|
+
if f.offline_underscore_expression is not None
|
|
1031
|
+
else None,
|
|
1028
1032
|
expression_definition_location=ToProtoConverter.convert_expression_definition_location(
|
|
1029
1033
|
f.underscore_expression
|
|
1030
1034
|
)
|
|
@@ -1147,9 +1151,13 @@ class ToProtoConverter:
|
|
|
1147
1151
|
raise ValueError(f"Unsupported resource hint: {r.resource_hint}")
|
|
1148
1152
|
|
|
1149
1153
|
static_operation = None
|
|
1154
|
+
static_operation_dataframe = None
|
|
1150
1155
|
if r.static:
|
|
1151
1156
|
static_operator = static_resolver_to_operator(fqn=r.fqn, fn=r.fn, inputs=r.inputs, output=r.output)
|
|
1152
|
-
|
|
1157
|
+
if isinstance(static_operator, LazyFramePlaceholder):
|
|
1158
|
+
static_operation_dataframe = static_operator._to_proto() # pyright: ignore[reportPrivateUsage]
|
|
1159
|
+
else:
|
|
1160
|
+
static_operation = static_operator._to_proto() # pyright: ignore[reportPrivateUsage]
|
|
1153
1161
|
|
|
1154
1162
|
function_reference_proto = ToProtoConverter.create_function_reference(
|
|
1155
1163
|
r.fn,
|
|
@@ -1158,7 +1166,9 @@ class ToProtoConverter:
|
|
|
1158
1166
|
filename=r.filename,
|
|
1159
1167
|
source_line=r.source_line,
|
|
1160
1168
|
)
|
|
1161
|
-
|
|
1169
|
+
postprocessing_underscore_expr: expr_pb.LogicalExprNode | None = None
|
|
1170
|
+
if isinstance(r.postprocessing, Underscore):
|
|
1171
|
+
postprocessing_underscore_expr = r.postprocessing._to_proto() # pyright: ignore[reportPrivateUsage]
|
|
1162
1172
|
return pb.Resolver(
|
|
1163
1173
|
fqn=r.fqn,
|
|
1164
1174
|
kind=(
|
|
@@ -1186,9 +1196,11 @@ class ToProtoConverter:
|
|
|
1186
1196
|
unique_on=tuple(x.root_fqn for x in r.unique_on) if r.unique_on is not None else (),
|
|
1187
1197
|
partitioned_by=(x.root_fqn for x in r.partitioned_by) if r.partitioned_by is not None else (),
|
|
1188
1198
|
static_operation=static_operation,
|
|
1199
|
+
static_operation_dataframe=static_operation_dataframe,
|
|
1189
1200
|
sql_settings=ToProtoConverter.convert_sql_settings(r.sql_settings) if r.sql_settings else None,
|
|
1190
1201
|
output_row_order=r.output_row_order,
|
|
1191
1202
|
venv=r.venv,
|
|
1203
|
+
underscore_expr=postprocessing_underscore_expr,
|
|
1192
1204
|
)
|
|
1193
1205
|
|
|
1194
1206
|
@staticmethod
|