chalkpy 2.89.22__py3-none-any.whl → 2.95.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chalk/__init__.py +2 -1
- chalk/_gen/chalk/arrow/v1/arrow_pb2.py +7 -5
- chalk/_gen/chalk/arrow/v1/arrow_pb2.pyi +6 -0
- chalk/_gen/chalk/artifacts/v1/chart_pb2.py +36 -33
- chalk/_gen/chalk/artifacts/v1/chart_pb2.pyi +41 -1
- chalk/_gen/chalk/artifacts/v1/cron_query_pb2.py +8 -7
- chalk/_gen/chalk/artifacts/v1/cron_query_pb2.pyi +5 -0
- chalk/_gen/chalk/common/v1/offline_query_pb2.py +19 -13
- chalk/_gen/chalk/common/v1/offline_query_pb2.pyi +37 -0
- chalk/_gen/chalk/common/v1/online_query_pb2.py +54 -54
- chalk/_gen/chalk/common/v1/online_query_pb2.pyi +13 -1
- chalk/_gen/chalk/common/v1/script_task_pb2.py +13 -11
- chalk/_gen/chalk/common/v1/script_task_pb2.pyi +19 -1
- chalk/_gen/chalk/dataframe/__init__.py +0 -0
- chalk/_gen/chalk/dataframe/v1/__init__.py +0 -0
- chalk/_gen/chalk/dataframe/v1/dataframe_pb2.py +48 -0
- chalk/_gen/chalk/dataframe/v1/dataframe_pb2.pyi +123 -0
- chalk/_gen/chalk/dataframe/v1/dataframe_pb2_grpc.py +4 -0
- chalk/_gen/chalk/dataframe/v1/dataframe_pb2_grpc.pyi +4 -0
- chalk/_gen/chalk/graph/v1/graph_pb2.py +150 -149
- chalk/_gen/chalk/graph/v1/graph_pb2.pyi +25 -0
- chalk/_gen/chalk/graph/v1/sources_pb2.py +94 -84
- chalk/_gen/chalk/graph/v1/sources_pb2.pyi +56 -0
- chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2.py +79 -0
- chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2.pyi +377 -0
- chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2_grpc.py +4 -0
- chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2_grpc.pyi +4 -0
- chalk/_gen/chalk/kubernetes/v1/scaledobject_pb2.py +43 -7
- chalk/_gen/chalk/kubernetes/v1/scaledobject_pb2.pyi +252 -2
- chalk/_gen/chalk/protosql/v1/sql_service_pb2.py +54 -27
- chalk/_gen/chalk/protosql/v1/sql_service_pb2.pyi +131 -3
- chalk/_gen/chalk/protosql/v1/sql_service_pb2_grpc.py +45 -0
- chalk/_gen/chalk/protosql/v1/sql_service_pb2_grpc.pyi +14 -0
- chalk/_gen/chalk/python/v1/types_pb2.py +14 -14
- chalk/_gen/chalk/python/v1/types_pb2.pyi +8 -0
- chalk/_gen/chalk/server/v1/benchmark_pb2.py +76 -0
- chalk/_gen/chalk/server/v1/benchmark_pb2.pyi +156 -0
- chalk/_gen/chalk/server/v1/benchmark_pb2_grpc.py +258 -0
- chalk/_gen/chalk/server/v1/benchmark_pb2_grpc.pyi +84 -0
- chalk/_gen/chalk/server/v1/billing_pb2.py +40 -38
- chalk/_gen/chalk/server/v1/billing_pb2.pyi +17 -1
- chalk/_gen/chalk/server/v1/branches_pb2.py +45 -0
- chalk/_gen/chalk/server/v1/branches_pb2.pyi +80 -0
- chalk/_gen/chalk/server/v1/branches_pb2_grpc.pyi +36 -0
- chalk/_gen/chalk/server/v1/builder_pb2.py +372 -272
- chalk/_gen/chalk/server/v1/builder_pb2.pyi +479 -12
- chalk/_gen/chalk/server/v1/builder_pb2_grpc.py +360 -0
- chalk/_gen/chalk/server/v1/builder_pb2_grpc.pyi +96 -0
- chalk/_gen/chalk/server/v1/chart_pb2.py +10 -10
- chalk/_gen/chalk/server/v1/chart_pb2.pyi +18 -2
- chalk/_gen/chalk/server/v1/clickhouse_pb2.py +42 -0
- chalk/_gen/chalk/server/v1/clickhouse_pb2.pyi +17 -0
- chalk/_gen/chalk/server/v1/clickhouse_pb2_grpc.py +78 -0
- chalk/_gen/chalk/server/v1/clickhouse_pb2_grpc.pyi +38 -0
- chalk/_gen/chalk/server/v1/cloud_components_pb2.py +153 -107
- chalk/_gen/chalk/server/v1/cloud_components_pb2.pyi +146 -4
- chalk/_gen/chalk/server/v1/cloud_components_pb2_grpc.py +180 -0
- chalk/_gen/chalk/server/v1/cloud_components_pb2_grpc.pyi +48 -0
- chalk/_gen/chalk/server/v1/cloud_credentials_pb2.py +11 -3
- chalk/_gen/chalk/server/v1/cloud_credentials_pb2.pyi +20 -0
- chalk/_gen/chalk/server/v1/cloud_credentials_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/cloud_credentials_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2.py +59 -35
- chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2.pyi +127 -1
- chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2_grpc.py +135 -0
- chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2_grpc.pyi +36 -0
- chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2.py +90 -0
- chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2.pyi +264 -0
- chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2_grpc.py +170 -0
- chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2_grpc.pyi +62 -0
- chalk/_gen/chalk/server/v1/datasets_pb2.py +36 -24
- chalk/_gen/chalk/server/v1/datasets_pb2.pyi +71 -2
- chalk/_gen/chalk/server/v1/datasets_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/datasets_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/deploy_pb2.py +9 -3
- chalk/_gen/chalk/server/v1/deploy_pb2.pyi +12 -0
- chalk/_gen/chalk/server/v1/deploy_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/deploy_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/deployment_pb2.py +20 -15
- chalk/_gen/chalk/server/v1/deployment_pb2.pyi +25 -0
- chalk/_gen/chalk/server/v1/environment_pb2.py +25 -15
- chalk/_gen/chalk/server/v1/environment_pb2.pyi +93 -1
- chalk/_gen/chalk/server/v1/eventbus_pb2.py +44 -0
- chalk/_gen/chalk/server/v1/eventbus_pb2.pyi +64 -0
- chalk/_gen/chalk/server/v1/eventbus_pb2_grpc.py +4 -0
- chalk/_gen/chalk/server/v1/eventbus_pb2_grpc.pyi +4 -0
- chalk/_gen/chalk/server/v1/files_pb2.py +65 -0
- chalk/_gen/chalk/server/v1/files_pb2.pyi +167 -0
- chalk/_gen/chalk/server/v1/files_pb2_grpc.py +4 -0
- chalk/_gen/chalk/server/v1/files_pb2_grpc.pyi +4 -0
- chalk/_gen/chalk/server/v1/graph_pb2.py +41 -3
- chalk/_gen/chalk/server/v1/graph_pb2.pyi +191 -0
- chalk/_gen/chalk/server/v1/graph_pb2_grpc.py +92 -0
- chalk/_gen/chalk/server/v1/graph_pb2_grpc.pyi +32 -0
- chalk/_gen/chalk/server/v1/incident_pb2.py +57 -0
- chalk/_gen/chalk/server/v1/incident_pb2.pyi +165 -0
- chalk/_gen/chalk/server/v1/incident_pb2_grpc.py +4 -0
- chalk/_gen/chalk/server/v1/incident_pb2_grpc.pyi +4 -0
- chalk/_gen/chalk/server/v1/indexing_job_pb2.py +44 -0
- chalk/_gen/chalk/server/v1/indexing_job_pb2.pyi +38 -0
- chalk/_gen/chalk/server/v1/indexing_job_pb2_grpc.py +78 -0
- chalk/_gen/chalk/server/v1/indexing_job_pb2_grpc.pyi +38 -0
- chalk/_gen/chalk/server/v1/integrations_pb2.py +11 -9
- chalk/_gen/chalk/server/v1/integrations_pb2.pyi +34 -2
- chalk/_gen/chalk/server/v1/kube_pb2.py +29 -19
- chalk/_gen/chalk/server/v1/kube_pb2.pyi +28 -0
- chalk/_gen/chalk/server/v1/kube_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/kube_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/log_pb2.py +21 -3
- chalk/_gen/chalk/server/v1/log_pb2.pyi +68 -0
- chalk/_gen/chalk/server/v1/log_pb2_grpc.py +90 -0
- chalk/_gen/chalk/server/v1/log_pb2_grpc.pyi +24 -0
- chalk/_gen/chalk/server/v1/metadataplanejobqueue_pb2.py +73 -0
- chalk/_gen/chalk/server/v1/metadataplanejobqueue_pb2.pyi +212 -0
- chalk/_gen/chalk/server/v1/metadataplanejobqueue_pb2_grpc.py +217 -0
- chalk/_gen/chalk/server/v1/metadataplanejobqueue_pb2_grpc.pyi +74 -0
- chalk/_gen/chalk/server/v1/model_registry_pb2.py +10 -10
- chalk/_gen/chalk/server/v1/model_registry_pb2.pyi +4 -1
- chalk/_gen/chalk/server/v1/monitoring_pb2.py +84 -75
- chalk/_gen/chalk/server/v1/monitoring_pb2.pyi +1 -0
- chalk/_gen/chalk/server/v1/monitoring_pb2_grpc.py +136 -0
- chalk/_gen/chalk/server/v1/monitoring_pb2_grpc.pyi +38 -0
- chalk/_gen/chalk/server/v1/offline_queries_pb2.py +32 -10
- chalk/_gen/chalk/server/v1/offline_queries_pb2.pyi +73 -0
- chalk/_gen/chalk/server/v1/offline_queries_pb2_grpc.py +90 -0
- chalk/_gen/chalk/server/v1/offline_queries_pb2_grpc.pyi +24 -0
- chalk/_gen/chalk/server/v1/plandebug_pb2.py +53 -0
- chalk/_gen/chalk/server/v1/plandebug_pb2.pyi +86 -0
- chalk/_gen/chalk/server/v1/plandebug_pb2_grpc.py +168 -0
- chalk/_gen/chalk/server/v1/plandebug_pb2_grpc.pyi +60 -0
- chalk/_gen/chalk/server/v1/queries_pb2.py +76 -48
- chalk/_gen/chalk/server/v1/queries_pb2.pyi +155 -2
- chalk/_gen/chalk/server/v1/queries_pb2_grpc.py +180 -0
- chalk/_gen/chalk/server/v1/queries_pb2_grpc.pyi +48 -0
- chalk/_gen/chalk/server/v1/scheduled_query_pb2.py +4 -2
- chalk/_gen/chalk/server/v1/scheduled_query_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/scheduled_query_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/scheduled_query_run_pb2.py +12 -6
- chalk/_gen/chalk/server/v1/scheduled_query_run_pb2.pyi +75 -2
- chalk/_gen/chalk/server/v1/scheduler_pb2.py +24 -12
- chalk/_gen/chalk/server/v1/scheduler_pb2.pyi +61 -1
- chalk/_gen/chalk/server/v1/scheduler_pb2_grpc.py +90 -0
- chalk/_gen/chalk/server/v1/scheduler_pb2_grpc.pyi +24 -0
- chalk/_gen/chalk/server/v1/script_tasks_pb2.py +26 -14
- chalk/_gen/chalk/server/v1/script_tasks_pb2.pyi +33 -3
- chalk/_gen/chalk/server/v1/script_tasks_pb2_grpc.py +90 -0
- chalk/_gen/chalk/server/v1/script_tasks_pb2_grpc.pyi +24 -0
- chalk/_gen/chalk/server/v1/sql_interface_pb2.py +75 -0
- chalk/_gen/chalk/server/v1/sql_interface_pb2.pyi +142 -0
- chalk/_gen/chalk/server/v1/sql_interface_pb2_grpc.py +349 -0
- chalk/_gen/chalk/server/v1/sql_interface_pb2_grpc.pyi +114 -0
- chalk/_gen/chalk/server/v1/sql_queries_pb2.py +48 -0
- chalk/_gen/chalk/server/v1/sql_queries_pb2.pyi +150 -0
- chalk/_gen/chalk/server/v1/sql_queries_pb2_grpc.py +123 -0
- chalk/_gen/chalk/server/v1/sql_queries_pb2_grpc.pyi +52 -0
- chalk/_gen/chalk/server/v1/team_pb2.py +156 -137
- chalk/_gen/chalk/server/v1/team_pb2.pyi +56 -10
- chalk/_gen/chalk/server/v1/team_pb2_grpc.py +90 -0
- chalk/_gen/chalk/server/v1/team_pb2_grpc.pyi +24 -0
- chalk/_gen/chalk/server/v1/topic_pb2.py +5 -3
- chalk/_gen/chalk/server/v1/topic_pb2.pyi +10 -1
- chalk/_gen/chalk/server/v1/trace_pb2.py +50 -28
- chalk/_gen/chalk/server/v1/trace_pb2.pyi +121 -0
- chalk/_gen/chalk/server/v1/trace_pb2_grpc.py +135 -0
- chalk/_gen/chalk/server/v1/trace_pb2_grpc.pyi +42 -0
- chalk/_gen/chalk/server/v1/webhook_pb2.py +9 -3
- chalk/_gen/chalk/server/v1/webhook_pb2.pyi +18 -0
- chalk/_gen/chalk/server/v1/webhook_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/webhook_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/streaming/v1/debug_service_pb2.py +62 -0
- chalk/_gen/chalk/streaming/v1/debug_service_pb2.pyi +75 -0
- chalk/_gen/chalk/streaming/v1/debug_service_pb2_grpc.py +221 -0
- chalk/_gen/chalk/streaming/v1/debug_service_pb2_grpc.pyi +88 -0
- chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2.py +19 -7
- chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2.pyi +96 -3
- chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2_grpc.py +48 -0
- chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2_grpc.pyi +20 -0
- chalk/_gen/chalk/utils/v1/field_change_pb2.py +32 -0
- chalk/_gen/chalk/utils/v1/field_change_pb2.pyi +42 -0
- chalk/_gen/chalk/utils/v1/field_change_pb2_grpc.py +4 -0
- chalk/_gen/chalk/utils/v1/field_change_pb2_grpc.pyi +4 -0
- chalk/_lsp/error_builder.py +11 -0
- chalk/_monitoring/Chart.py +1 -3
- chalk/_version.py +1 -1
- chalk/cli.py +5 -10
- chalk/client/client.py +178 -64
- chalk/client/client_async.py +154 -0
- chalk/client/client_async_impl.py +22 -0
- chalk/client/client_grpc.py +738 -112
- chalk/client/client_impl.py +541 -136
- chalk/client/dataset.py +27 -6
- chalk/client/models.py +99 -2
- chalk/client/serialization/model_serialization.py +126 -10
- chalk/config/project_config.py +1 -1
- chalk/df/LazyFramePlaceholder.py +1154 -0
- chalk/df/ast_parser.py +2 -10
- chalk/features/_class_property.py +7 -0
- chalk/features/_embedding/embedding.py +1 -0
- chalk/features/_embedding/sentence_transformer.py +1 -1
- chalk/features/_encoding/converter.py +83 -2
- chalk/features/_encoding/pyarrow.py +20 -4
- chalk/features/_encoding/rich.py +1 -3
- chalk/features/_tensor.py +1 -2
- chalk/features/dataframe/_filters.py +14 -5
- chalk/features/dataframe/_impl.py +91 -36
- chalk/features/dataframe/_validation.py +11 -7
- chalk/features/feature_field.py +40 -30
- chalk/features/feature_set.py +1 -2
- chalk/features/feature_set_decorator.py +1 -0
- chalk/features/feature_wrapper.py +42 -3
- chalk/features/hooks.py +81 -12
- chalk/features/inference.py +65 -10
- chalk/features/resolver.py +338 -56
- chalk/features/tag.py +1 -3
- chalk/features/underscore_features.py +2 -1
- chalk/functions/__init__.py +456 -21
- chalk/functions/holidays.py +1 -3
- chalk/gitignore/gitignore_parser.py +5 -1
- chalk/importer.py +186 -74
- chalk/ml/__init__.py +6 -2
- chalk/ml/model_hooks.py +368 -51
- chalk/ml/model_reference.py +68 -10
- chalk/ml/model_version.py +34 -21
- chalk/ml/utils.py +143 -40
- chalk/operators/_utils.py +14 -3
- chalk/parsed/_proto/export.py +22 -0
- chalk/parsed/duplicate_input_gql.py +4 -0
- chalk/parsed/expressions.py +1 -3
- chalk/parsed/json_conversions.py +21 -14
- chalk/parsed/to_proto.py +16 -4
- chalk/parsed/user_types_to_json.py +31 -10
- chalk/parsed/validation_from_registries.py +182 -0
- chalk/queries/named_query.py +16 -6
- chalk/queries/scheduled_query.py +13 -1
- chalk/serialization/parsed_annotation.py +25 -12
- chalk/sql/__init__.py +221 -0
- chalk/sql/_internal/integrations/athena.py +6 -1
- chalk/sql/_internal/integrations/bigquery.py +22 -2
- chalk/sql/_internal/integrations/databricks.py +61 -18
- chalk/sql/_internal/integrations/mssql.py +281 -0
- chalk/sql/_internal/integrations/postgres.py +11 -3
- chalk/sql/_internal/integrations/redshift.py +4 -0
- chalk/sql/_internal/integrations/snowflake.py +11 -2
- chalk/sql/_internal/integrations/util.py +2 -1
- chalk/sql/_internal/sql_file_resolver.py +55 -10
- chalk/sql/_internal/sql_source.py +36 -2
- chalk/streams/__init__.py +1 -3
- chalk/streams/_kafka_source.py +5 -1
- chalk/streams/_windows.py +16 -4
- chalk/streams/types.py +1 -2
- chalk/utils/__init__.py +1 -3
- chalk/utils/_otel_version.py +13 -0
- chalk/utils/async_helpers.py +14 -5
- chalk/utils/df_utils.py +2 -2
- chalk/utils/duration.py +1 -3
- chalk/utils/job_log_display.py +538 -0
- chalk/utils/missing_dependency.py +5 -4
- chalk/utils/notebook.py +255 -2
- chalk/utils/pl_helpers.py +190 -37
- chalk/utils/pydanticutil/pydantic_compat.py +1 -2
- chalk/utils/storage_client.py +246 -0
- chalk/utils/threading.py +1 -3
- chalk/utils/tracing.py +194 -86
- {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/METADATA +53 -21
- {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/RECORD +268 -198
- {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/WHEEL +0 -0
- {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/entry_points.txt +0 -0
- {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/top_level.txt +0 -0
chalk/features/feature_field.py
CHANGED
|
@@ -4,11 +4,12 @@ import copy
|
|
|
4
4
|
import dataclasses
|
|
5
5
|
import functools
|
|
6
6
|
import inspect
|
|
7
|
+
import itertools
|
|
7
8
|
import os
|
|
8
9
|
import re
|
|
9
10
|
import weakref
|
|
10
11
|
from collections.abc import Mapping, MutableMapping
|
|
11
|
-
from datetime import
|
|
12
|
+
from datetime import datetime, timedelta
|
|
12
13
|
from typing import (
|
|
13
14
|
TYPE_CHECKING,
|
|
14
15
|
Any,
|
|
@@ -28,7 +29,6 @@ from typing import (
|
|
|
28
29
|
cast,
|
|
29
30
|
)
|
|
30
31
|
|
|
31
|
-
import itertools
|
|
32
32
|
import numpy as np
|
|
33
33
|
import pyarrow as pa
|
|
34
34
|
|
|
@@ -39,11 +39,11 @@ from chalk.features._encoding.converter import FeatureConverter, JSONCodec, TDec
|
|
|
39
39
|
from chalk.features._encoding.primitive import TPrimitive
|
|
40
40
|
from chalk.features.feature_set import CURRENT_FEATURE_REGISTRY, FeatureRegistryProtocol
|
|
41
41
|
from chalk.features.feature_wrapper import FeatureWrapper, NearestNeighborException
|
|
42
|
-
from chalk.features.filter import Filter, TimeDelta
|
|
42
|
+
from chalk.features.filter import ClauseJoinWithAndException, Filter, TimeDelta
|
|
43
43
|
from chalk.features.tag import Tags
|
|
44
44
|
from chalk.features.underscore import Underscore
|
|
45
45
|
from chalk.serialization.parsed_annotation import ParsedAnnotation
|
|
46
|
-
from chalk.utils.collections import
|
|
46
|
+
from chalk.utils.collections import FrozenOrderedSet, OrderedSet, ensure_tuple, get_unique_item
|
|
47
47
|
from chalk.utils.duration import CHALK_MAX_TIMEDELTA, Duration, parse_chalk_duration
|
|
48
48
|
from chalk.utils.import_utils import get_type_checking_imports
|
|
49
49
|
from chalk.utils.json import JSON, pyarrow_json_type
|
|
@@ -54,7 +54,7 @@ from chalk.utils.pydanticutil.pydantic_compat import (
|
|
|
54
54
|
is_pydantic_basemodel_instance,
|
|
55
55
|
parse_pydantic_model,
|
|
56
56
|
)
|
|
57
|
-
from chalk.utils.string import
|
|
57
|
+
from chalk.utils.string import oxford_comma_list, to_snake_case
|
|
58
58
|
|
|
59
59
|
if TYPE_CHECKING:
|
|
60
60
|
from google.protobuf.message import Message as ProtobufMessage
|
|
@@ -68,9 +68,9 @@ _TPrim = TypeVar("_TPrim", bound=TPrimitive)
|
|
|
68
68
|
_logger = get_logger(__name__)
|
|
69
69
|
|
|
70
70
|
from chalk.features.feature_cache_strategy import (
|
|
71
|
-
CacheStrategy,
|
|
72
|
-
CacheNullsType,
|
|
73
71
|
CacheDefaultsType,
|
|
72
|
+
CacheNullsType,
|
|
73
|
+
CacheStrategy,
|
|
74
74
|
get_cache_strategy_from_cache_settings,
|
|
75
75
|
)
|
|
76
76
|
|
|
@@ -252,6 +252,7 @@ class Feature(Generic[_TPrim, _TRich]):
|
|
|
252
252
|
"tags",
|
|
253
253
|
"underlying",
|
|
254
254
|
"underscore_expression",
|
|
255
|
+
"offline_underscore_expression",
|
|
255
256
|
"version",
|
|
256
257
|
"window_duration",
|
|
257
258
|
"window_durations",
|
|
@@ -276,6 +277,7 @@ class Feature(Generic[_TPrim, _TRich]):
|
|
|
276
277
|
primary: bool | None = None,
|
|
277
278
|
default: _TRich | ellipsis = ...,
|
|
278
279
|
underscore_expression: Underscore | None = None,
|
|
280
|
+
offline_underscore_expression: Underscore | None = None,
|
|
279
281
|
max_staleness: Duration | None | ellipsis = ...,
|
|
280
282
|
online_store_max_items: int | None = None,
|
|
281
283
|
cache_strategy: CacheStrategy = CacheStrategy.ALL_WITH_BOTH_UNSET,
|
|
@@ -386,6 +388,7 @@ class Feature(Generic[_TPrim, _TRich]):
|
|
|
386
388
|
self._primary = primary
|
|
387
389
|
self._primary_feature: Optional[Feature] = None
|
|
388
390
|
self.underscore_expression: Underscore | None = underscore_expression
|
|
391
|
+
self.offline_underscore_expression: Underscore | None = offline_underscore_expression
|
|
389
392
|
self.is_distance_pseudofeature = is_distance_pseudofeature
|
|
390
393
|
|
|
391
394
|
self._raw_max_staleness = max_staleness
|
|
@@ -503,7 +506,7 @@ class Feature(Generic[_TPrim, _TRich]):
|
|
|
503
506
|
|
|
504
507
|
@property
|
|
505
508
|
def converter(self) -> FeatureConverter:
|
|
506
|
-
from chalk.features import DataFrame,
|
|
509
|
+
from chalk.features import DataFrame, Tensor, Vector
|
|
507
510
|
|
|
508
511
|
self._converter_entered += 1
|
|
509
512
|
|
|
@@ -923,7 +926,7 @@ class Feature(Generic[_TPrim, _TRich]):
|
|
|
923
926
|
|
|
924
927
|
def __repr__(self):
|
|
925
928
|
try:
|
|
926
|
-
root_fqn=self.root_fqn
|
|
929
|
+
root_fqn = self.root_fqn
|
|
927
930
|
except:
|
|
928
931
|
# self.root_fqn is a property, if it failed then just return the object repr
|
|
929
932
|
return object.__repr__(self)
|
|
@@ -1367,7 +1370,7 @@ class Feature(Generic[_TPrim, _TRich]):
|
|
|
1367
1370
|
),
|
|
1368
1371
|
label="invalid nearest neighbor join",
|
|
1369
1372
|
range=self.lsp_error_builder.property_value_range(self.attribute_name)
|
|
1370
|
-
|
|
1373
|
+
or self.lsp_error_builder.property_range(self.attribute_name),
|
|
1371
1374
|
code="32",
|
|
1372
1375
|
raise_error=TypeError,
|
|
1373
1376
|
)
|
|
@@ -1387,7 +1390,7 @@ class Feature(Generic[_TPrim, _TRich]):
|
|
|
1387
1390
|
),
|
|
1388
1391
|
label="invalid join",
|
|
1389
1392
|
range=self.lsp_error_builder.property_value_range(self.attribute_name)
|
|
1390
|
-
|
|
1393
|
+
or self.lsp_error_builder.property_range(self.attribute_name),
|
|
1391
1394
|
code="32",
|
|
1392
1395
|
raise_error=TypeError,
|
|
1393
1396
|
)
|
|
@@ -1400,7 +1403,7 @@ class Feature(Generic[_TPrim, _TRich]):
|
|
|
1400
1403
|
),
|
|
1401
1404
|
label="invalid join",
|
|
1402
1405
|
range=self.lsp_error_builder.property_value_range(self.attribute_name)
|
|
1403
|
-
|
|
1406
|
+
or self.lsp_error_builder.property_range(self.attribute_name),
|
|
1404
1407
|
code="32",
|
|
1405
1408
|
raise_error=TypeError,
|
|
1406
1409
|
)
|
|
@@ -1424,18 +1427,21 @@ class Feature(Generic[_TPrim, _TRich]):
|
|
|
1424
1427
|
)
|
|
1425
1428
|
|
|
1426
1429
|
if not self.is_has_many and not self.is_has_one:
|
|
1427
|
-
|
|
1428
|
-
|
|
1429
|
-
|
|
1430
|
-
|
|
1431
|
-
|
|
1432
|
-
|
|
1433
|
-
|
|
1434
|
-
|
|
1435
|
-
|
|
1436
|
-
|
|
1437
|
-
|
|
1438
|
-
|
|
1430
|
+
# Check if user tried to use DataFrame (even if validation failed)
|
|
1431
|
+
# Use is_dataframe_annotation() to detect DataFrame types without triggering validation errors
|
|
1432
|
+
if not self.typ.is_dataframe_annotation():
|
|
1433
|
+
assert self.features_cls is not None
|
|
1434
|
+
self.lsp_error_builder.add_diagnostic(
|
|
1435
|
+
message=(
|
|
1436
|
+
f"The attribute '{self.features_cls.__name__}.{self.attribute_name}' "
|
|
1437
|
+
f"has a join filter ({join}) but its type annotation '{self.typ}' is not a feature class or DataFrame that links to another feature class."
|
|
1438
|
+
),
|
|
1439
|
+
label="invalid join",
|
|
1440
|
+
range=self.lsp_error_builder.property_value_range(self.attribute_name)
|
|
1441
|
+
or self.lsp_error_builder.property_range(self.attribute_name),
|
|
1442
|
+
code="37",
|
|
1443
|
+
raise_error=TypeError,
|
|
1444
|
+
)
|
|
1439
1445
|
if self._join_type == "has_one":
|
|
1440
1446
|
if self.is_has_many:
|
|
1441
1447
|
assert self.features_cls is not None
|
|
@@ -1748,6 +1754,7 @@ def feature(
|
|
|
1748
1754
|
default: Union[_TRich, ellipsis] = ...,
|
|
1749
1755
|
underscore: Optional[Underscore] = None, # Deprecated. Prefer `expression`.
|
|
1750
1756
|
expression: Optional[Underscore] = None,
|
|
1757
|
+
offline_expression: Optional[Underscore] = None,
|
|
1751
1758
|
offline_ttl: Optional[Union[ellipsis, Duration]] = ...,
|
|
1752
1759
|
deprecated: bool = False,
|
|
1753
1760
|
store_online: bool = True,
|
|
@@ -1811,6 +1818,8 @@ def feature(
|
|
|
1811
1818
|
... total: int = feature(expression=_.subtotal + _.tax, default=0)
|
|
1812
1819
|
|
|
1813
1820
|
See more at https://docs.chalk.ai/docs/expression
|
|
1821
|
+
offline_expression
|
|
1822
|
+
Defines an alternate expression to compute the feature during offline queries.
|
|
1814
1823
|
dtype
|
|
1815
1824
|
The backing `pyarrow.DataType` for the feature. This parameter can
|
|
1816
1825
|
be used to control the storage format of data. For example, if you
|
|
@@ -2058,10 +2067,7 @@ def feature(
|
|
|
2058
2067
|
if not isinstance(value, Feature): # pyright: ignore[reportUnnecessaryIsInstance]
|
|
2059
2068
|
raise ValueError(f"When `versions` is provided, the values must be features, but `{value}` was given.")
|
|
2060
2069
|
|
|
2061
|
-
cache_strategy = get_cache_strategy_from_cache_settings(
|
|
2062
|
-
cache_nulls=cache_nulls,
|
|
2063
|
-
cache_defaults=cache_defaults
|
|
2064
|
-
)
|
|
2070
|
+
cache_strategy = get_cache_strategy_from_cache_settings(cache_nulls=cache_nulls, cache_defaults=cache_defaults)
|
|
2065
2071
|
|
|
2066
2072
|
return cast(
|
|
2067
2073
|
_TRich,
|
|
@@ -2108,6 +2114,7 @@ def feature(
|
|
|
2108
2114
|
),
|
|
2109
2115
|
default=default,
|
|
2110
2116
|
underscore_expression=expression if expression is not None else underscore,
|
|
2117
|
+
offline_underscore_expression=offline_expression,
|
|
2111
2118
|
offline_ttl=offline_ttl,
|
|
2112
2119
|
is_deprecated=deprecated,
|
|
2113
2120
|
store_online=store_online,
|
|
@@ -2173,7 +2180,7 @@ def has_one(f: Callable[[], Any]) -> Any:
|
|
|
2173
2180
|
def has_many(
|
|
2174
2181
|
f: Callable[[], Any],
|
|
2175
2182
|
max_staleness: Union[Duration, None, ellipsis] = ...,
|
|
2176
|
-
online_store_max_items: int | None = None
|
|
2183
|
+
online_store_max_items: int | None = None,
|
|
2177
2184
|
) -> Any:
|
|
2178
2185
|
"""Specify a feature that represents a one-to-many relationship.
|
|
2179
2186
|
|
|
@@ -2192,6 +2199,7 @@ def has_many(
|
|
|
2192
2199
|
The maximum number of items to cache for the joined feature. The
|
|
2193
2200
|
items in the joined feature aggregate, storing the latest values
|
|
2194
2201
|
of the joined feature for each primary key in the joined feature.
|
|
2202
|
+
|
|
2195
2203
|
Examples
|
|
2196
2204
|
--------
|
|
2197
2205
|
>>> from chalk.features import DataFrame, features, has_many
|
|
@@ -2207,7 +2215,9 @@ def has_many(
|
|
|
2207
2215
|
... lambda: User.id == Card.user_id
|
|
2208
2216
|
... )
|
|
2209
2217
|
"""
|
|
2210
|
-
return Feature(
|
|
2218
|
+
return Feature(
|
|
2219
|
+
join=f, max_staleness=max_staleness, online_store_max_items=online_store_max_items, join_type="has_many"
|
|
2220
|
+
)
|
|
2211
2221
|
|
|
2212
2222
|
|
|
2213
2223
|
__all__ = (
|
chalk/features/feature_set.py
CHANGED
|
@@ -19,11 +19,10 @@ from typing import (
|
|
|
19
19
|
Set,
|
|
20
20
|
Tuple,
|
|
21
21
|
Type,
|
|
22
|
+
TypeGuard,
|
|
22
23
|
cast,
|
|
23
24
|
)
|
|
24
25
|
|
|
25
|
-
from typing_extensions import TypeGuard
|
|
26
|
-
|
|
27
26
|
from chalk._lsp._class_finder import get_class_ast
|
|
28
27
|
from chalk.features.feature_wrapper import FeatureWrapper, unwrap_feature
|
|
29
28
|
from chalk.utils import notebook
|
chalk/features/feature_set_decorator.py
CHANGED
|
@@ -141,6 +141,7 @@ def features(
|
|
|
141
141
|
The `cache_nulls` and `cache_defaults` options can be used together on the same feature with the
|
|
142
142
|
following exceptions: if `cache_nulls=False`, then `cache_defaults` cannot be `"evict_defaults"`, and if
|
|
143
143
|
`cache_nulls="evict_defaults"`, then `cache_defaults` cannot be `False`.
|
|
144
|
+
|
|
144
145
|
Other Parameters
|
|
145
146
|
----------------
|
|
146
147
|
cls
|
chalk/features/feature_wrapper.py
CHANGED
|
@@ -11,6 +11,7 @@ from chalk.features._chalkop import op, Aggregation
|
|
|
11
11
|
from chalk.features.filter import Filter
|
|
12
12
|
from chalk.serialization.parsed_annotation import ParsedAnnotation
|
|
13
13
|
from chalk.utils.collections import ensure_tuple
|
|
14
|
+
from chalk.utils.notebook import is_notebook
|
|
14
15
|
|
|
15
16
|
if TYPE_CHECKING:
|
|
16
17
|
from chalk.features.feature_field import Feature
|
|
@@ -22,16 +23,43 @@ class NearestNeighborException(ValueError):
|
|
|
22
23
|
...
|
|
23
24
|
|
|
24
25
|
|
|
26
|
+
class UnresolvedFeature:
|
|
27
|
+
"""Fallback for features that can't be resolved in notebook environments.
|
|
28
|
+
|
|
29
|
+
This allows notebooks to work even when the feature registry is stale or incomplete.
|
|
30
|
+
The server will validate the feature exists when the query is executed.
|
|
31
|
+
"""
|
|
32
|
+
__slots__ = ("fqn",)
|
|
33
|
+
|
|
34
|
+
def __init__(self, fqn: str):
|
|
35
|
+
self.fqn = fqn
|
|
36
|
+
super().__init__()
|
|
37
|
+
|
|
38
|
+
def __str__(self):
|
|
39
|
+
return self.fqn
|
|
40
|
+
|
|
41
|
+
def __repr__(self):
|
|
42
|
+
return f"UnresolvedFeature({self.fqn!r})"
|
|
43
|
+
|
|
44
|
+
def __hash__(self):
|
|
45
|
+
return hash(self.fqn)
|
|
46
|
+
|
|
47
|
+
def __eq__(self, other: object):
|
|
48
|
+
if isinstance(other, UnresolvedFeature):
|
|
49
|
+
return self.fqn == other.fqn
|
|
50
|
+
return False
|
|
51
|
+
|
|
52
|
+
|
|
25
53
|
class _MarkedUnderlyingFeature:
|
|
26
54
|
__slots__ = ("_fn", "_source", "_debug_info")
|
|
27
55
|
|
|
28
|
-
def __init__(self, fn: Callable[[], Feature | Filter | type[DataFrame] | FeatureWrapper | Aggregation],
|
|
56
|
+
def __init__(self, fn: Callable[[], Feature | Filter | type[DataFrame] | FeatureWrapper | Aggregation | UnresolvedFeature],
|
|
29
57
|
debug_info: Any = None) -> None:
|
|
30
58
|
super().__init__()
|
|
31
59
|
self._fn = fn
|
|
32
60
|
self._debug_info = debug_info
|
|
33
61
|
|
|
34
|
-
def __call__(self, *args: Any, **kwds: Any) -> Feature | Filter | type[DataFrame] | FeatureWrapper | Aggregation:
|
|
62
|
+
def __call__(self, *args: Any, **kwds: Any) -> Feature | Filter | type[DataFrame] | FeatureWrapper | Aggregation | UnresolvedFeature:
|
|
35
63
|
return self._fn()
|
|
36
64
|
|
|
37
65
|
|
|
@@ -51,7 +79,7 @@ class FeatureWrapper:
|
|
|
51
79
|
super().__init__()
|
|
52
80
|
self._chalk_underlying = underlying
|
|
53
81
|
|
|
54
|
-
def _chalk_get_underlying(self) -> Feature | Aggregation | Filter | type[DataFrame]:
|
|
82
|
+
def _chalk_get_underlying(self) -> Feature | Aggregation | Filter | type[DataFrame] | UnresolvedFeature:
|
|
55
83
|
if isinstance(self._chalk_underlying, _MarkedUnderlyingFeature):
|
|
56
84
|
self._chalk_underlying = self._chalk_underlying()
|
|
57
85
|
if isinstance(self._chalk_underlying, FeatureWrapper):
|
|
@@ -303,6 +331,12 @@ class FeatureWrapper:
|
|
|
303
331
|
if f.attribute_name == item:
|
|
304
332
|
return FeatureWrapper(underlying.copy_with_path(f))
|
|
305
333
|
|
|
334
|
+
if is_notebook():
|
|
335
|
+
# Construct FQN by preserving the path from the underlying feature
|
|
336
|
+
# If underlying has a path, we need to include it in the FQN
|
|
337
|
+
fqn = f"{underlying.root_fqn}.{item}"
|
|
338
|
+
return UnresolvedFeature(fqn)
|
|
339
|
+
|
|
306
340
|
assert underlying.features_cls is not None
|
|
307
341
|
underlying.features_cls.__chalk_error_builder__.invalid_attribute(
|
|
308
342
|
root_feature_str=joined_class.namespace,
|
|
@@ -314,6 +348,11 @@ class FeatureWrapper:
|
|
|
314
348
|
)
|
|
315
349
|
assert False, "unreachable"
|
|
316
350
|
|
|
351
|
+
# If in notebook, fallback to constructing FQN string instead of raising error
|
|
352
|
+
if is_notebook():
|
|
353
|
+
fqn = f"{underlying.fqn}.{item}"
|
|
354
|
+
return UnresolvedFeature(fqn)
|
|
355
|
+
|
|
317
356
|
assert underlying.features_cls is not None
|
|
318
357
|
underlying.features_cls.__chalk_error_builder__.invalid_attribute(
|
|
319
358
|
root_feature_str=underlying.fqn,
|
chalk/features/hooks.py
CHANGED
|
@@ -4,14 +4,15 @@ import asyncio
|
|
|
4
4
|
import inspect
|
|
5
5
|
import logging
|
|
6
6
|
import time # Added for measuring duration
|
|
7
|
-
from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union, overload
|
|
8
|
-
|
|
9
|
-
from typing_extensions import TypeAlias
|
|
7
|
+
from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, Set, Tuple, TypeAlias, Union, overload
|
|
10
8
|
|
|
11
9
|
from chalk.features.tag import Environments
|
|
12
10
|
from chalk.utils.collections import ensure_tuple
|
|
13
11
|
from chalk.utils.log_with_context import get_logger
|
|
14
12
|
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from chalk.features.resolver import ResourceHint
|
|
15
|
+
|
|
15
16
|
HookFn: TypeAlias = Callable[[], Any]
|
|
16
17
|
|
|
17
18
|
|
|
@@ -45,21 +46,31 @@ class Hook:
|
|
|
45
46
|
venv: Optional[str]
|
|
46
47
|
fn: HookFn
|
|
47
48
|
filename: str
|
|
49
|
+
resource_hint: Optional["ResourceHint"]
|
|
50
|
+
resource_group: Optional[str]
|
|
48
51
|
|
|
49
52
|
def __init__(
|
|
50
|
-
self,
|
|
53
|
+
self,
|
|
54
|
+
fn: HookFn,
|
|
55
|
+
filename: str,
|
|
56
|
+
environment: Optional[Environments] = None,
|
|
57
|
+
venv: Optional[str] = None,
|
|
58
|
+
resource_hint: Optional["ResourceHint"] = None,
|
|
59
|
+
resource_group: Optional[str] = None,
|
|
51
60
|
):
|
|
52
61
|
super().__init__()
|
|
53
62
|
self.fn = fn
|
|
54
63
|
self.filename = filename
|
|
55
64
|
self.environment = None if environment is None else ensure_tuple(environment)
|
|
56
65
|
self.venv = venv
|
|
66
|
+
self.resource_hint = resource_hint
|
|
67
|
+
self.resource_group = resource_group
|
|
57
68
|
|
|
58
69
|
def __call__(self):
|
|
59
70
|
return self.fn()
|
|
60
71
|
|
|
61
72
|
def __repr__(self):
|
|
62
|
-
return f'Hook(filename={self.filename}, fn={self.fn.__name__}", environment={str(self.environment)}, venv={self.venv})'
|
|
73
|
+
return f'Hook(filename={self.filename}, fn={self.fn.__name__}", environment={str(self.environment)}, venv={self.venv}, resource_hint={self.resource_hint}, resource_group={self.resource_group})'
|
|
63
74
|
|
|
64
75
|
@classmethod
|
|
65
76
|
async def async_run_all_before_all(cls, environment: str, venv: Optional[str] = None) -> None:
|
|
@@ -77,17 +88,46 @@ def before_all(fn: HookFn, /) -> Hook:
|
|
|
77
88
|
|
|
78
89
|
@overload
|
|
79
90
|
def before_all(
|
|
80
|
-
fn:
|
|
91
|
+
fn: HookFn,
|
|
92
|
+
/,
|
|
93
|
+
environment: Optional[Environments] = None,
|
|
94
|
+
venv: Optional[str] = None,
|
|
95
|
+
resource_hint: Optional["ResourceHint"] = None,
|
|
96
|
+
resource_group: Optional[str] = None,
|
|
97
|
+
) -> Hook:
|
|
98
|
+
...
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
@overload
|
|
102
|
+
def before_all(
|
|
103
|
+
fn: None = None,
|
|
104
|
+
/,
|
|
105
|
+
environment: Optional[Environments] = None,
|
|
106
|
+
venv: Optional[str] = None,
|
|
107
|
+
resource_hint: Optional["ResourceHint"] = None,
|
|
108
|
+
resource_group: Optional[str] = None,
|
|
81
109
|
) -> Callable[[HookFn], Hook]:
|
|
82
110
|
...
|
|
83
111
|
|
|
84
112
|
|
|
85
113
|
def before_all(
|
|
86
|
-
fn: Optional[HookFn] = None,
|
|
114
|
+
fn: Optional[HookFn] = None,
|
|
115
|
+
/,
|
|
116
|
+
environment: Optional[Environments] = None,
|
|
117
|
+
venv: Optional[str] = None,
|
|
118
|
+
resource_hint: Optional["ResourceHint"] = None,
|
|
119
|
+
resource_group: Optional[str] = None,
|
|
87
120
|
) -> Union[Hook, Callable[[HookFn], Hook]]:
|
|
88
121
|
def decorator(f: HookFn):
|
|
89
122
|
caller_filename = inspect.getsourcefile(f) or "unknown_file"
|
|
90
|
-
hook = Hook(
|
|
123
|
+
hook = Hook(
|
|
124
|
+
fn=f,
|
|
125
|
+
filename=caller_filename,
|
|
126
|
+
environment=environment,
|
|
127
|
+
venv=venv,
|
|
128
|
+
resource_hint=resource_hint,
|
|
129
|
+
resource_group=resource_group,
|
|
130
|
+
)
|
|
91
131
|
Hook.before_all.add(hook)
|
|
92
132
|
return hook
|
|
93
133
|
|
|
@@ -95,23 +135,52 @@ def before_all(
|
|
|
95
135
|
|
|
96
136
|
|
|
97
137
|
@overload
|
|
98
|
-
def after_all(fn: HookFn,
|
|
138
|
+
def after_all(fn: HookFn, /) -> Hook:
|
|
139
|
+
...
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
@overload
|
|
143
|
+
def after_all(
|
|
144
|
+
fn: HookFn,
|
|
145
|
+
/,
|
|
146
|
+
environment: Optional[Environments] = None,
|
|
147
|
+
venv: Optional[str] = None,
|
|
148
|
+
resource_hint: Optional["ResourceHint"] = None,
|
|
149
|
+
resource_group: Optional[str] = None,
|
|
150
|
+
) -> Hook:
|
|
99
151
|
...
|
|
100
152
|
|
|
101
153
|
|
|
102
154
|
@overload
|
|
103
155
|
def after_all(
|
|
104
|
-
fn: None = None,
|
|
156
|
+
fn: None = None,
|
|
157
|
+
/,
|
|
158
|
+
environment: Optional[Environments] = None,
|
|
159
|
+
venv: Optional[str] = None,
|
|
160
|
+
resource_hint: Optional["ResourceHint"] = None,
|
|
161
|
+
resource_group: Optional[str] = None,
|
|
105
162
|
) -> Callable[[HookFn], Hook]:
|
|
106
163
|
...
|
|
107
164
|
|
|
108
165
|
|
|
109
166
|
def after_all(
|
|
110
|
-
fn: Optional[HookFn] = None,
|
|
167
|
+
fn: Optional[HookFn] = None,
|
|
168
|
+
/,
|
|
169
|
+
environment: Optional[Environments] = None,
|
|
170
|
+
venv: Optional[str] = None,
|
|
171
|
+
resource_hint: Optional["ResourceHint"] = None,
|
|
172
|
+
resource_group: Optional[str] = None,
|
|
111
173
|
) -> Union[Hook, Callable[[HookFn], Hook]]:
|
|
112
174
|
def decorator(f: HookFn):
|
|
113
175
|
caller_filename = inspect.getsourcefile(f) or "unknown_file"
|
|
114
|
-
hook = Hook(
|
|
176
|
+
hook = Hook(
|
|
177
|
+
fn=f,
|
|
178
|
+
filename=caller_filename,
|
|
179
|
+
environment=environment,
|
|
180
|
+
venv=venv,
|
|
181
|
+
resource_hint=resource_hint,
|
|
182
|
+
resource_group=resource_group,
|
|
183
|
+
)
|
|
115
184
|
Hook.after_all.add(hook)
|
|
116
185
|
return hook
|
|
117
186
|
|
chalk/features/inference.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Optional
|
|
1
|
+
from typing import Callable, Optional
|
|
2
2
|
|
|
3
3
|
from chalk._lsp.error_builder import get_resolver_error_builder
|
|
4
4
|
from chalk.features import DataFrame
|
|
@@ -11,8 +11,62 @@ from chalk.ml.model_version import ModelVersion
|
|
|
11
11
|
from chalk.utils.collections import ensure_tuple
|
|
12
12
|
|
|
13
13
|
|
|
14
|
+
def build_inference_function(
    model_version: ModelVersion, pkey: Feature, output_features: Optional[Feature | list[Feature]] = None
) -> Callable[[DataFrame], DataFrame]:
    """Create the closure that runs ``model_version`` over a DataFrame of inputs.

    The returned function selects every column except the primary key,
    converts that selection to a PyArrow table, hands it to
    ``model_version.predictor.prepare_input`` and feeds the prepared input to
    ``model_version.predict``.

    Parameters
    ----------
    model_version
        The model version whose ``predictor`` and ``predict`` are used.
    pkey
        The primary key feature; its string form names the column that is
        left out of the model input.
    output_features
        Optional output feature(s) to attach predictions to. A single Feature
        is treated as a one-element list. For each feature, the last
        dot-separated segment of its string form is passed to
        ``model_version.predictor.extract_output`` to obtain that column's data.

    Returns
    -------
    Callable[[DataFrame], DataFrame]
        Function taking the input DataFrame. With ``output_features`` set it
        returns ``inp[pkey]`` extended with one column per output feature;
        otherwise it returns the raw result of ``model_version.predict``.
    """
    pkey_column = str(pkey)

    def fn(inp: DataFrame):
        # Everything except the primary key column is model input.
        non_key_columns = [column for column in inp.columns if column != pkey_column]
        arrow_features = inp[non_key_columns].to_pyarrow()

        # Input preparation is delegated to the model's predictor.
        prepared_input = model_version.predictor.prepare_input(arrow_features)
        prediction = model_version.predict(prepared_input)

        # Without output features the caller receives the prediction untouched.
        if output_features is None:
            return prediction

        targets = output_features if isinstance(output_features, list) else [output_features]

        # Output extraction is delegated to the predictor, keyed by the feature's short name.
        predicted_columns = {
            target: model_version.predictor.extract_output(prediction, str(target).split(".")[-1])
            for target in targets
        }
        return inp[pkey_column].with_columns(predicted_columns)

    return fn
|
|
66
|
+
|
|
67
|
+
|
|
14
68
|
def generate_inference_resolver(
|
|
15
|
-
inputs: list[Underscore], model_version: ModelVersion, resource_hint: Optional[ResourceHint] = None
|
|
69
|
+
inputs: list[Underscore] | Underscore, model_version: ModelVersion, resource_hint: Optional[ResourceHint] = None
|
|
16
70
|
) -> Feature:
|
|
17
71
|
output_feature = Feature()
|
|
18
72
|
previous_hook = output_feature.hook
|
|
@@ -22,20 +76,21 @@ def generate_inference_resolver(
|
|
|
22
76
|
previous_hook(features)
|
|
23
77
|
|
|
24
78
|
pkey = features.__chalk_primary__
|
|
25
|
-
|
|
79
|
+
if pkey is None:
|
|
80
|
+
raise ValueError(f"Feature class {features} does not have a primary key defined")
|
|
26
81
|
|
|
27
82
|
def resolver_factory():
|
|
28
|
-
|
|
29
|
-
result = model_version.predict(inp[[c for c in inp.columns if c != pkey_string]].to_pyarrow())
|
|
30
|
-
return inp[pkey_string].with_columns({output_feature: result})
|
|
31
|
-
|
|
83
|
+
# Use the extracted build_inference_function
|
|
32
84
|
cleaned_inputs = []
|
|
33
|
-
|
|
85
|
+
inputs_list = inputs if isinstance(inputs, list) else [inputs]
|
|
86
|
+
for i in inputs_list:
|
|
34
87
|
try:
|
|
35
88
|
cleaned_inputs.append(Feature.from_root_fqn(output_feature.namespace + str(i)[1:]))
|
|
36
89
|
except Exception as e:
|
|
37
90
|
raise ValueError(f"Could not find feature for input {i}: {e}")
|
|
38
91
|
|
|
92
|
+
fn = build_inference_function(model_version, pkey, output_feature)
|
|
93
|
+
|
|
39
94
|
identifier = model_version.identifier or ""
|
|
40
95
|
model_reference = MODEL_REFERENCE_REGISTRY.get((model_version.name, identifier), None)
|
|
41
96
|
if model_reference is not None:
|
|
@@ -44,7 +99,7 @@ def generate_inference_resolver(
|
|
|
44
99
|
return OnlineResolver(
|
|
45
100
|
function_definition="",
|
|
46
101
|
filename="",
|
|
47
|
-
fqn=f"{model_version.name}__{output_feature.namespace}
|
|
102
|
+
fqn=f"{model_version.name}__{output_feature.namespace}_{output_feature.name}",
|
|
48
103
|
doc=None,
|
|
49
104
|
inputs=[DataFrame[[pkey, *ensure_tuple(cleaned_inputs)]]],
|
|
50
105
|
state=None,
|
|
@@ -58,7 +113,7 @@ def generate_inference_resolver(
|
|
|
58
113
|
when=None,
|
|
59
114
|
tags=None,
|
|
60
115
|
owner=None,
|
|
61
|
-
resource_hint=resource_hint,
|
|
116
|
+
resource_hint=resource_hint or model_version.resource_hint,
|
|
62
117
|
data_sources=None,
|
|
63
118
|
is_sql_file_resolver=False,
|
|
64
119
|
source_line=None,
|