chalkpy 2.89.22__py3-none-any.whl → 2.95.3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only.
- chalk/__init__.py +2 -1
- chalk/_gen/chalk/arrow/v1/arrow_pb2.py +7 -5
- chalk/_gen/chalk/arrow/v1/arrow_pb2.pyi +6 -0
- chalk/_gen/chalk/artifacts/v1/chart_pb2.py +36 -33
- chalk/_gen/chalk/artifacts/v1/chart_pb2.pyi +41 -1
- chalk/_gen/chalk/artifacts/v1/cron_query_pb2.py +8 -7
- chalk/_gen/chalk/artifacts/v1/cron_query_pb2.pyi +5 -0
- chalk/_gen/chalk/common/v1/offline_query_pb2.py +19 -13
- chalk/_gen/chalk/common/v1/offline_query_pb2.pyi +37 -0
- chalk/_gen/chalk/common/v1/online_query_pb2.py +54 -54
- chalk/_gen/chalk/common/v1/online_query_pb2.pyi +13 -1
- chalk/_gen/chalk/common/v1/script_task_pb2.py +13 -11
- chalk/_gen/chalk/common/v1/script_task_pb2.pyi +19 -1
- chalk/_gen/chalk/dataframe/__init__.py +0 -0
- chalk/_gen/chalk/dataframe/v1/__init__.py +0 -0
- chalk/_gen/chalk/dataframe/v1/dataframe_pb2.py +48 -0
- chalk/_gen/chalk/dataframe/v1/dataframe_pb2.pyi +123 -0
- chalk/_gen/chalk/dataframe/v1/dataframe_pb2_grpc.py +4 -0
- chalk/_gen/chalk/dataframe/v1/dataframe_pb2_grpc.pyi +4 -0
- chalk/_gen/chalk/graph/v1/graph_pb2.py +150 -149
- chalk/_gen/chalk/graph/v1/graph_pb2.pyi +25 -0
- chalk/_gen/chalk/graph/v1/sources_pb2.py +94 -84
- chalk/_gen/chalk/graph/v1/sources_pb2.pyi +56 -0
- chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2.py +79 -0
- chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2.pyi +377 -0
- chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2_grpc.py +4 -0
- chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2_grpc.pyi +4 -0
- chalk/_gen/chalk/kubernetes/v1/scaledobject_pb2.py +43 -7
- chalk/_gen/chalk/kubernetes/v1/scaledobject_pb2.pyi +252 -2
- chalk/_gen/chalk/protosql/v1/sql_service_pb2.py +54 -27
- chalk/_gen/chalk/protosql/v1/sql_service_pb2.pyi +131 -3
- chalk/_gen/chalk/protosql/v1/sql_service_pb2_grpc.py +45 -0
- chalk/_gen/chalk/protosql/v1/sql_service_pb2_grpc.pyi +14 -0
- chalk/_gen/chalk/python/v1/types_pb2.py +14 -14
- chalk/_gen/chalk/python/v1/types_pb2.pyi +8 -0
- chalk/_gen/chalk/server/v1/benchmark_pb2.py +76 -0
- chalk/_gen/chalk/server/v1/benchmark_pb2.pyi +156 -0
- chalk/_gen/chalk/server/v1/benchmark_pb2_grpc.py +258 -0
- chalk/_gen/chalk/server/v1/benchmark_pb2_grpc.pyi +84 -0
- chalk/_gen/chalk/server/v1/billing_pb2.py +40 -38
- chalk/_gen/chalk/server/v1/billing_pb2.pyi +17 -1
- chalk/_gen/chalk/server/v1/branches_pb2.py +45 -0
- chalk/_gen/chalk/server/v1/branches_pb2.pyi +80 -0
- chalk/_gen/chalk/server/v1/branches_pb2_grpc.pyi +36 -0
- chalk/_gen/chalk/server/v1/builder_pb2.py +372 -272
- chalk/_gen/chalk/server/v1/builder_pb2.pyi +479 -12
- chalk/_gen/chalk/server/v1/builder_pb2_grpc.py +360 -0
- chalk/_gen/chalk/server/v1/builder_pb2_grpc.pyi +96 -0
- chalk/_gen/chalk/server/v1/chart_pb2.py +10 -10
- chalk/_gen/chalk/server/v1/chart_pb2.pyi +18 -2
- chalk/_gen/chalk/server/v1/clickhouse_pb2.py +42 -0
- chalk/_gen/chalk/server/v1/clickhouse_pb2.pyi +17 -0
- chalk/_gen/chalk/server/v1/clickhouse_pb2_grpc.py +78 -0
- chalk/_gen/chalk/server/v1/clickhouse_pb2_grpc.pyi +38 -0
- chalk/_gen/chalk/server/v1/cloud_components_pb2.py +153 -107
- chalk/_gen/chalk/server/v1/cloud_components_pb2.pyi +146 -4
- chalk/_gen/chalk/server/v1/cloud_components_pb2_grpc.py +180 -0
- chalk/_gen/chalk/server/v1/cloud_components_pb2_grpc.pyi +48 -0
- chalk/_gen/chalk/server/v1/cloud_credentials_pb2.py +11 -3
- chalk/_gen/chalk/server/v1/cloud_credentials_pb2.pyi +20 -0
- chalk/_gen/chalk/server/v1/cloud_credentials_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/cloud_credentials_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2.py +59 -35
- chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2.pyi +127 -1
- chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2_grpc.py +135 -0
- chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2_grpc.pyi +36 -0
- chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2.py +90 -0
- chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2.pyi +264 -0
- chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2_grpc.py +170 -0
- chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2_grpc.pyi +62 -0
- chalk/_gen/chalk/server/v1/datasets_pb2.py +36 -24
- chalk/_gen/chalk/server/v1/datasets_pb2.pyi +71 -2
- chalk/_gen/chalk/server/v1/datasets_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/datasets_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/deploy_pb2.py +9 -3
- chalk/_gen/chalk/server/v1/deploy_pb2.pyi +12 -0
- chalk/_gen/chalk/server/v1/deploy_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/deploy_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/deployment_pb2.py +20 -15
- chalk/_gen/chalk/server/v1/deployment_pb2.pyi +25 -0
- chalk/_gen/chalk/server/v1/environment_pb2.py +25 -15
- chalk/_gen/chalk/server/v1/environment_pb2.pyi +93 -1
- chalk/_gen/chalk/server/v1/eventbus_pb2.py +44 -0
- chalk/_gen/chalk/server/v1/eventbus_pb2.pyi +64 -0
- chalk/_gen/chalk/server/v1/eventbus_pb2_grpc.py +4 -0
- chalk/_gen/chalk/server/v1/eventbus_pb2_grpc.pyi +4 -0
- chalk/_gen/chalk/server/v1/files_pb2.py +65 -0
- chalk/_gen/chalk/server/v1/files_pb2.pyi +167 -0
- chalk/_gen/chalk/server/v1/files_pb2_grpc.py +4 -0
- chalk/_gen/chalk/server/v1/files_pb2_grpc.pyi +4 -0
- chalk/_gen/chalk/server/v1/graph_pb2.py +41 -3
- chalk/_gen/chalk/server/v1/graph_pb2.pyi +191 -0
- chalk/_gen/chalk/server/v1/graph_pb2_grpc.py +92 -0
- chalk/_gen/chalk/server/v1/graph_pb2_grpc.pyi +32 -0
- chalk/_gen/chalk/server/v1/incident_pb2.py +57 -0
- chalk/_gen/chalk/server/v1/incident_pb2.pyi +165 -0
- chalk/_gen/chalk/server/v1/incident_pb2_grpc.py +4 -0
- chalk/_gen/chalk/server/v1/incident_pb2_grpc.pyi +4 -0
- chalk/_gen/chalk/server/v1/indexing_job_pb2.py +44 -0
- chalk/_gen/chalk/server/v1/indexing_job_pb2.pyi +38 -0
- chalk/_gen/chalk/server/v1/indexing_job_pb2_grpc.py +78 -0
- chalk/_gen/chalk/server/v1/indexing_job_pb2_grpc.pyi +38 -0
- chalk/_gen/chalk/server/v1/integrations_pb2.py +11 -9
- chalk/_gen/chalk/server/v1/integrations_pb2.pyi +34 -2
- chalk/_gen/chalk/server/v1/kube_pb2.py +29 -19
- chalk/_gen/chalk/server/v1/kube_pb2.pyi +28 -0
- chalk/_gen/chalk/server/v1/kube_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/kube_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/log_pb2.py +21 -3
- chalk/_gen/chalk/server/v1/log_pb2.pyi +68 -0
- chalk/_gen/chalk/server/v1/log_pb2_grpc.py +90 -0
- chalk/_gen/chalk/server/v1/log_pb2_grpc.pyi +24 -0
- chalk/_gen/chalk/server/v1/metadataplanejobqueue_pb2.py +73 -0
- chalk/_gen/chalk/server/v1/metadataplanejobqueue_pb2.pyi +212 -0
- chalk/_gen/chalk/server/v1/metadataplanejobqueue_pb2_grpc.py +217 -0
- chalk/_gen/chalk/server/v1/metadataplanejobqueue_pb2_grpc.pyi +74 -0
- chalk/_gen/chalk/server/v1/model_registry_pb2.py +10 -10
- chalk/_gen/chalk/server/v1/model_registry_pb2.pyi +4 -1
- chalk/_gen/chalk/server/v1/monitoring_pb2.py +84 -75
- chalk/_gen/chalk/server/v1/monitoring_pb2.pyi +1 -0
- chalk/_gen/chalk/server/v1/monitoring_pb2_grpc.py +136 -0
- chalk/_gen/chalk/server/v1/monitoring_pb2_grpc.pyi +38 -0
- chalk/_gen/chalk/server/v1/offline_queries_pb2.py +32 -10
- chalk/_gen/chalk/server/v1/offline_queries_pb2.pyi +73 -0
- chalk/_gen/chalk/server/v1/offline_queries_pb2_grpc.py +90 -0
- chalk/_gen/chalk/server/v1/offline_queries_pb2_grpc.pyi +24 -0
- chalk/_gen/chalk/server/v1/plandebug_pb2.py +53 -0
- chalk/_gen/chalk/server/v1/plandebug_pb2.pyi +86 -0
- chalk/_gen/chalk/server/v1/plandebug_pb2_grpc.py +168 -0
- chalk/_gen/chalk/server/v1/plandebug_pb2_grpc.pyi +60 -0
- chalk/_gen/chalk/server/v1/queries_pb2.py +76 -48
- chalk/_gen/chalk/server/v1/queries_pb2.pyi +155 -2
- chalk/_gen/chalk/server/v1/queries_pb2_grpc.py +180 -0
- chalk/_gen/chalk/server/v1/queries_pb2_grpc.pyi +48 -0
- chalk/_gen/chalk/server/v1/scheduled_query_pb2.py +4 -2
- chalk/_gen/chalk/server/v1/scheduled_query_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/scheduled_query_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/scheduled_query_run_pb2.py +12 -6
- chalk/_gen/chalk/server/v1/scheduled_query_run_pb2.pyi +75 -2
- chalk/_gen/chalk/server/v1/scheduler_pb2.py +24 -12
- chalk/_gen/chalk/server/v1/scheduler_pb2.pyi +61 -1
- chalk/_gen/chalk/server/v1/scheduler_pb2_grpc.py +90 -0
- chalk/_gen/chalk/server/v1/scheduler_pb2_grpc.pyi +24 -0
- chalk/_gen/chalk/server/v1/script_tasks_pb2.py +26 -14
- chalk/_gen/chalk/server/v1/script_tasks_pb2.pyi +33 -3
- chalk/_gen/chalk/server/v1/script_tasks_pb2_grpc.py +90 -0
- chalk/_gen/chalk/server/v1/script_tasks_pb2_grpc.pyi +24 -0
- chalk/_gen/chalk/server/v1/sql_interface_pb2.py +75 -0
- chalk/_gen/chalk/server/v1/sql_interface_pb2.pyi +142 -0
- chalk/_gen/chalk/server/v1/sql_interface_pb2_grpc.py +349 -0
- chalk/_gen/chalk/server/v1/sql_interface_pb2_grpc.pyi +114 -0
- chalk/_gen/chalk/server/v1/sql_queries_pb2.py +48 -0
- chalk/_gen/chalk/server/v1/sql_queries_pb2.pyi +150 -0
- chalk/_gen/chalk/server/v1/sql_queries_pb2_grpc.py +123 -0
- chalk/_gen/chalk/server/v1/sql_queries_pb2_grpc.pyi +52 -0
- chalk/_gen/chalk/server/v1/team_pb2.py +156 -137
- chalk/_gen/chalk/server/v1/team_pb2.pyi +56 -10
- chalk/_gen/chalk/server/v1/team_pb2_grpc.py +90 -0
- chalk/_gen/chalk/server/v1/team_pb2_grpc.pyi +24 -0
- chalk/_gen/chalk/server/v1/topic_pb2.py +5 -3
- chalk/_gen/chalk/server/v1/topic_pb2.pyi +10 -1
- chalk/_gen/chalk/server/v1/trace_pb2.py +50 -28
- chalk/_gen/chalk/server/v1/trace_pb2.pyi +121 -0
- chalk/_gen/chalk/server/v1/trace_pb2_grpc.py +135 -0
- chalk/_gen/chalk/server/v1/trace_pb2_grpc.pyi +42 -0
- chalk/_gen/chalk/server/v1/webhook_pb2.py +9 -3
- chalk/_gen/chalk/server/v1/webhook_pb2.pyi +18 -0
- chalk/_gen/chalk/server/v1/webhook_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/webhook_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/streaming/v1/debug_service_pb2.py +62 -0
- chalk/_gen/chalk/streaming/v1/debug_service_pb2.pyi +75 -0
- chalk/_gen/chalk/streaming/v1/debug_service_pb2_grpc.py +221 -0
- chalk/_gen/chalk/streaming/v1/debug_service_pb2_grpc.pyi +88 -0
- chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2.py +19 -7
- chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2.pyi +96 -3
- chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2_grpc.py +48 -0
- chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2_grpc.pyi +20 -0
- chalk/_gen/chalk/utils/v1/field_change_pb2.py +32 -0
- chalk/_gen/chalk/utils/v1/field_change_pb2.pyi +42 -0
- chalk/_gen/chalk/utils/v1/field_change_pb2_grpc.py +4 -0
- chalk/_gen/chalk/utils/v1/field_change_pb2_grpc.pyi +4 -0
- chalk/_lsp/error_builder.py +11 -0
- chalk/_monitoring/Chart.py +1 -3
- chalk/_version.py +1 -1
- chalk/cli.py +5 -10
- chalk/client/client.py +178 -64
- chalk/client/client_async.py +154 -0
- chalk/client/client_async_impl.py +22 -0
- chalk/client/client_grpc.py +738 -112
- chalk/client/client_impl.py +541 -136
- chalk/client/dataset.py +27 -6
- chalk/client/models.py +99 -2
- chalk/client/serialization/model_serialization.py +126 -10
- chalk/config/project_config.py +1 -1
- chalk/df/LazyFramePlaceholder.py +1154 -0
- chalk/df/ast_parser.py +2 -10
- chalk/features/_class_property.py +7 -0
- chalk/features/_embedding/embedding.py +1 -0
- chalk/features/_embedding/sentence_transformer.py +1 -1
- chalk/features/_encoding/converter.py +83 -2
- chalk/features/_encoding/pyarrow.py +20 -4
- chalk/features/_encoding/rich.py +1 -3
- chalk/features/_tensor.py +1 -2
- chalk/features/dataframe/_filters.py +14 -5
- chalk/features/dataframe/_impl.py +91 -36
- chalk/features/dataframe/_validation.py +11 -7
- chalk/features/feature_field.py +40 -30
- chalk/features/feature_set.py +1 -2
- chalk/features/feature_set_decorator.py +1 -0
- chalk/features/feature_wrapper.py +42 -3
- chalk/features/hooks.py +81 -12
- chalk/features/inference.py +65 -10
- chalk/features/resolver.py +338 -56
- chalk/features/tag.py +1 -3
- chalk/features/underscore_features.py +2 -1
- chalk/functions/__init__.py +456 -21
- chalk/functions/holidays.py +1 -3
- chalk/gitignore/gitignore_parser.py +5 -1
- chalk/importer.py +186 -74
- chalk/ml/__init__.py +6 -2
- chalk/ml/model_hooks.py +368 -51
- chalk/ml/model_reference.py +68 -10
- chalk/ml/model_version.py +34 -21
- chalk/ml/utils.py +143 -40
- chalk/operators/_utils.py +14 -3
- chalk/parsed/_proto/export.py +22 -0
- chalk/parsed/duplicate_input_gql.py +4 -0
- chalk/parsed/expressions.py +1 -3
- chalk/parsed/json_conversions.py +21 -14
- chalk/parsed/to_proto.py +16 -4
- chalk/parsed/user_types_to_json.py +31 -10
- chalk/parsed/validation_from_registries.py +182 -0
- chalk/queries/named_query.py +16 -6
- chalk/queries/scheduled_query.py +13 -1
- chalk/serialization/parsed_annotation.py +25 -12
- chalk/sql/__init__.py +221 -0
- chalk/sql/_internal/integrations/athena.py +6 -1
- chalk/sql/_internal/integrations/bigquery.py +22 -2
- chalk/sql/_internal/integrations/databricks.py +61 -18
- chalk/sql/_internal/integrations/mssql.py +281 -0
- chalk/sql/_internal/integrations/postgres.py +11 -3
- chalk/sql/_internal/integrations/redshift.py +4 -0
- chalk/sql/_internal/integrations/snowflake.py +11 -2
- chalk/sql/_internal/integrations/util.py +2 -1
- chalk/sql/_internal/sql_file_resolver.py +55 -10
- chalk/sql/_internal/sql_source.py +36 -2
- chalk/streams/__init__.py +1 -3
- chalk/streams/_kafka_source.py +5 -1
- chalk/streams/_windows.py +16 -4
- chalk/streams/types.py +1 -2
- chalk/utils/__init__.py +1 -3
- chalk/utils/_otel_version.py +13 -0
- chalk/utils/async_helpers.py +14 -5
- chalk/utils/df_utils.py +2 -2
- chalk/utils/duration.py +1 -3
- chalk/utils/job_log_display.py +538 -0
- chalk/utils/missing_dependency.py +5 -4
- chalk/utils/notebook.py +255 -2
- chalk/utils/pl_helpers.py +190 -37
- chalk/utils/pydanticutil/pydantic_compat.py +1 -2
- chalk/utils/storage_client.py +246 -0
- chalk/utils/threading.py +1 -3
- chalk/utils/tracing.py +194 -86
- {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/METADATA +53 -21
- {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/RECORD +268 -198
- {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/WHEEL +0 -0
- {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/entry_points.txt +0 -0
- {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/top_level.txt +0 -0
chalk/client/dataset.py
CHANGED

@@ -48,6 +48,7 @@ from chalk.integrations.catalogs.base_catalog import BaseCatalog
 from chalk.utils.df_utils import read_parquet
 from chalk.utils.log_with_context import get_logger
 from chalk.utils.missing_dependency import missing_dependency_exception
+from chalk.utils.pl_helpers import apply_compat, polars_group_by_instead_of_groupby
 from chalk.utils.threading import DEFAULT_IO_EXECUTOR
 
 if TYPE_CHECKING:
@@ -541,14 +542,19 @@ def _extract_df_columns(
         for fqn in unique_features
     ]
 
-    df = df.groupby("pkey").agg(cols)
+    if polars_group_by_instead_of_groupby:
+        df = df.group_by("pkey").agg(cols)
+    else:
+        df = df.groupby("pkey").agg(cols)  # pyright: ignore
     decoded_stmts: List[pl.Expr] = []
     for col in df.columns:
         if col == "pkey":
             continue
         else:
             decoded_stmts.append(
-                pl.col(col).apply(_json_decode, return_dtype=Feature.from_root_fqn(col).converter.polars_dtype)
+                apply_compat(
+                    pl.col(col), _json_decode, return_dtype=Feature.from_root_fqn(col).converter.polars_dtype
+                )
             )
     df = df.select(decoded_stmts)
     # it might be a good idea to remember that we used to rename this __id__ column to the primary key
@@ -560,7 +566,13 @@ def _extract_df_columns(
 
     decoded_stmts: List[pl.Expr] = []
     feature_name_to_metadata = None if column_metadata is None else {x.feature_fqn: x for x in column_metadata}
-    for col, dtype in zip(df.columns, df.dtypes):
+    # Use collect_schema().dtypes() for newer Polars versions to avoid performance warning
+    # Fall back to df.dtypes for older versions
+    try:
+        dtypes = df.collect_schema().dtypes()
+    except AttributeError:
+        dtypes = df.dtypes
+    for col, dtype in zip(df.columns, dtypes):
         if version in (
             DatasetVersion.BIGQUERY_JOB_WITH_B32_ENCODED_COLNAMES,
             DatasetVersion.BIGQUERY_JOB_WITH_B32_ENCODED_COLNAMES_V2,
@@ -571,7 +583,7 @@ def _extract_df_columns(
                 # Assuming that the only datetime column is for timestamps
                 decoded_stmts.append(to_utc(df, col, pl.col(col)))
             else:
-                decoded_stmts.append(pl.col(col).apply(_json_decode, return_dtype=dtype))
+                decoded_stmts.append(apply_compat(pl.col(col), _json_decode, return_dtype=dtype))
         elif version in (DatasetVersion.NATIVE_DTYPES, DatasetVersion.NATIVE_COLUMN_NAMES):
             # We already decoded the column names so matching against the fqn
             if col == CHALK_TS_FEATURE or col == OBSERVED_AT_FEATURE:
@@ -1338,12 +1350,14 @@ class DatasetRevisionImpl(DatasetRevision):
                 actual_args.append(
                     DataFrame(
                         pl.DataFrame([pl.Series(col_name, [], dtype=raw_input_df.schema[col_name])])
-                        .explode(
+                        .explode(col_name)
                         .unnest(col_name)
                     )
                 )
             else:
-                actual_args.append(
+                actual_args.append(
+                    DataFrame(has_many_input_df.explode(has_many_input_df.columns).unnest(col_name))
+                )
         else:
             value = args[i]
             if isinstance(input, Feature):
@@ -1363,6 +1377,13 @@ This occurred during the actual execution of resolver {resolver.fqn}.
             )
             raise e
         print(f"resolver_replay: {resolver.fqn} returned {output}")
+        if isinstance(output, DataFrame):
+            try:
+                output = output.to_polars().collect().rows(named=True)
+            except Exception as e:
+                raise RuntimeError(
+                    f'Failed to convert DataFrame output from resolver "{resolver.fqn}" during resolver replay'
+                ) from e
         output_col.append(output)
     return raw_input_df.with_columns(pl.Series(name="__resolver_replay_output__", values=output_col))
 
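The new `apply_compat` and `polars_group_by_instead_of_groupby` imports reflect the Polars 0.19 API renames (`DataFrame.groupby` became `group_by`, and `Expr.apply` became `Expr.map_elements`), which chalkpy now bridges at each call site. The helpers live in chalk/utils/pl_helpers.py, which this diff does not show, so the sketch below is an illustrative assumption of the pattern rather than chalkpy's actual implementation:

```python
# Minimal sketch of a version-gated Polars compatibility shim. The real helpers
# are in chalk/utils/pl_helpers.py (not shown here); the flag and wrapper below
# are assumptions for illustration only.
import polars as pl

# True on Polars >= 0.19, where DataFrame.groupby was renamed to group_by.
GROUP_BY_INSTEAD_OF_GROUPBY = hasattr(pl.DataFrame, "group_by")

def apply_compat(expr: pl.Expr, fn, *, return_dtype) -> pl.Expr:
    """Run a Python function element-wise across old and new Polars APIs."""
    if hasattr(expr, "map_elements"):  # Polars >= 0.19 renamed Expr.apply
        return expr.map_elements(fn, return_dtype=return_dtype)
    return expr.apply(fn, return_dtype=return_dtype)  # older Polars
```

Probing the installed API once with `hasattr` and branching at call sites keeps chalkpy working against both Polars generations without pinning a version.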
chalk/client/models.py
CHANGED

@@ -7,10 +7,9 @@ import traceback
 import uuid
 from datetime import datetime, timedelta
 from enum import Enum, IntEnum
-from typing import TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional, Sequence, Tuple, TypeAlias, Union
 
 import numpy as np
-from typing_extensions import TypeAlias
 
 from chalk.byte_transmit.model import ByteBaseModel, ByteDict
 from chalk.client._internal_models.models import OfflineQueryGivensVersion
@@ -461,6 +460,15 @@ class OfflineQueryInput(BaseModel):
     values: List[List[Any]] # Values should be of type TJSON
 
 
+class OfflineQueryInputSql(BaseModel):
+    """Input to an offline query specified as a ChalkSQL query instead
+    of literal data.
+
+    Alternative to OfflineQueryInput or OfflineQueryInputUri."""
+
+    input_sql: str
+
+
 class OnlineQueryRequest(BaseModel):
     inputs: Mapping[str, Any] # Values should be of type TJSON
     outputs: List[str]
@@ -839,6 +847,7 @@ class CreateOfflineQueryJobRequest(BaseModel):
         None,
         UploadedParquetShardedOfflineQueryInput,
         OfflineQueryInputUri,
+        OfflineQueryInputSql,
     ] = None
     """Any givens"""
 
@@ -1659,6 +1668,7 @@ class PlanQueryResponse(BaseModel):
     output_schema: List[FeatureSchema]
     errors: List[ChalkError]
     structured_plan: Optional[str] = None
+    serialized_plan_proto_bytes: Optional[str] = None
 
 
 class IngestDatasetRequest(BaseModel):
@@ -1783,3 +1793,90 @@ class GetRegisteredModelVersionResponse(BaseModel):
 
 class CreateModelTrainingJobResponse(BaseModel):
     success: bool
+
+
+class ScheduledQueryRunStatus(str, Enum):
+    """Status of a scheduled query run."""
+
+    UNSPECIFIED = "UNSPECIFIED"
+    INITIALIZING = "INITIALIZING"
+    INIT_FAILED = "INIT_FAILED"
+    SKIPPED = "SKIPPED"
+    QUEUED = "QUEUED"
+    WORKING = "WORKING"
+    COMPLETED = "COMPLETED"
+    FAILED = "FAILED"
+    CANCELED = "CANCELED"
+
+
+@dataclasses.dataclass
+class ScheduledQueryRun:
+    """A single scheduled query run."""
+
+    id: int
+    environment_id: str
+    deployment_id: str
+    run_id: str
+    cron_query_id: int
+    cron_query_schedule_id: int
+    cron_name: str
+    gcr_execution_id: str
+    gcr_job_name: str
+    offline_query_id: str
+    created_at: datetime
+    updated_at: datetime
+    status: ScheduledQueryRunStatus
+    blocker_operation_id: str
+
+    @staticmethod
+    def from_proto(proto_run: Any) -> "ScheduledQueryRun":
+        """Convert a proto ScheduledQueryRun to the dataclass version."""
+        from datetime import timezone
+
+        # Map proto status enum to our enum
+        status_map = {
+            0: ScheduledQueryRunStatus.UNSPECIFIED,
+            1: ScheduledQueryRunStatus.INITIALIZING,
+            2: ScheduledQueryRunStatus.INIT_FAILED,
+            3: ScheduledQueryRunStatus.SKIPPED,
+            4: ScheduledQueryRunStatus.QUEUED,
+            5: ScheduledQueryRunStatus.WORKING,
+            6: ScheduledQueryRunStatus.COMPLETED,
+            7: ScheduledQueryRunStatus.FAILED,
+            8: ScheduledQueryRunStatus.CANCELED,
+        }
+
+        # Helper to convert proto Timestamp to datetime
+        def _timestamp_to_datetime(ts: Any) -> datetime:
+            return datetime.fromtimestamp(ts.seconds + ts.nanos / 1e9, tz=timezone.utc)
+
+        return ScheduledQueryRun(
+            id=proto_run.id,
+            environment_id=proto_run.environment_id,
+            deployment_id=proto_run.deployment_id,
+            run_id=proto_run.run_id,
+            cron_query_id=proto_run.cron_query_id,
+            cron_query_schedule_id=proto_run.cron_query_schedule_id,
+            cron_name=proto_run.cron_name,
+            gcr_execution_id=proto_run.gcr_execution_id,
+            gcr_job_name=proto_run.gcr_job_name,
+            offline_query_id=proto_run.offline_query_id,
+            created_at=_timestamp_to_datetime(proto_run.created_at),
+            updated_at=_timestamp_to_datetime(proto_run.updated_at),
+            status=status_map.get(proto_run.status, ScheduledQueryRunStatus.UNSPECIFIED),
+            blocker_operation_id=proto_run.blocker_operation_id,
+        )
+
+
+@dataclasses.dataclass
+class ManualTriggerScheduledQueryResponse:
+    """Response from manually triggering a scheduled query."""
+
+    scheduled_query_run: ScheduledQueryRun
+
+    @staticmethod
+    def from_proto(proto_response: Any) -> "ManualTriggerScheduledQueryResponse":
+        """Convert a proto ManualTriggerScheduledQueryResponse to the dataclass version."""
+        return ManualTriggerScheduledQueryResponse(
+            scheduled_query_run=ScheduledQueryRun.from_proto(proto_response.scheduled_query_run),
+        )
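The `_timestamp_to_datetime` helper above combines the protobuf `Timestamp` fields `seconds` and `nanos` into a single timezone-aware UTC `datetime`. A quick self-contained check of that conversion, assuming the server populates standard `google.protobuf.Timestamp` messages (the specific epoch value is just an example):

```python
# Sanity check of the seconds + nanos -> aware-UTC datetime conversion used by
# ScheduledQueryRun.from_proto; assumes standard protobuf Timestamp messages.
from datetime import datetime, timezone
from google.protobuf.timestamp_pb2 import Timestamp

ts = Timestamp(seconds=1_700_000_000, nanos=500_000_000)  # 2023-11-14T22:13:20.5Z
dt = datetime.fromtimestamp(ts.seconds + ts.nanos / 1e9, tz=timezone.utc)
assert dt == datetime(2023, 11, 14, 22, 13, 20, 500_000, tzinfo=timezone.utc)
```

Note that `nanos / 1e9` goes through float seconds, so the result is precise to microseconds, which is all `datetime` can represent anyway.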
chalk/client/serialization/model_serialization.py
CHANGED

@@ -72,11 +72,20 @@ MODEL_SERIALIZERS = {
         filename="model.cbm",
         encoding=ModelEncoding.CBM,
         serialize_fn=lambda model, path: model.save_model(path),
+        schema_fn=lambda model: ModelAttributeExtractor.infer_catboost_schemas(model),
     ),
     ModelType.ONNX: ModelSerializationConfig(
         filename="model.onnx",
         encoding=ModelEncoding.PROTOBUF,
-        serialize_fn=lambda model, path:
+        serialize_fn=lambda model, path: ModelSerializer.with_import(
+            "onnx",
+            lambda onnx: onnx.save(
+                # Unwrap model if it has a _model attribute (e.g., wrapped ONNX models)
+                model._model if hasattr(model, "_model") else model,
+                path,
+            ),
+            "Please install onnx to save ONNX models.",
+        ),
     ),
 }
 
@@ -85,7 +94,13 @@ class ModelSerializer:
     def __init__(self, model: Any, model_type: Optional[ModelType]):
        self._temp_files: List[str] = []
        self.model = model
-
+        if model_type is not None:
+            self.model_type = model_type
+            self.model_class = None
+        else:
+            model_type, model_class = ModelAttributeExtractor.infer_model_type(model)
+            self.model_type = model_type
+            self.model_class = model_class
 
         if self.model_type is None:
             raise ValueError("Unable to infer model type from object and no type given.")
@@ -125,6 +140,32 @@
         self._temp_files.append(model_path)
         return model_path, serializer_config.encoding
 
+    def serialize_to_path(self, path: str, cleanup: bool = True) -> Tuple[str, ModelEncoding]:
+        assert self.model_type is not None, "Could not determine model type. Please set parameter: model_type."
+        return self.serialize_model_to_path(self.model, self.model_type, path, cleanup)
+
+    def serialize_model_to_path(
+        self,
+        model: Any,
+        model_type: ModelType,
+        path: str,
+        cleanup: bool = True,
+    ) -> tuple[str, ModelEncoding]:
+        if model_type not in MODEL_SERIALIZERS:
+            raise NotImplementedError(f"Unsupported model type: {model_type}")
+        dir = path
+
+        serializer_config = MODEL_SERIALIZERS[model_type]
+
+        file_name = serializer_config.filename
+
+        model_path = os.path.join(dir, file_name)
+
+        serializer_config.serialize_fn(model, model_path)
+        if cleanup:
+            self._temp_files.append(model_path)
+        return model_path, serializer_config.encoding
+
     @staticmethod
     def with_import(module_name: str, func: Callable[[Any], Any], error_msg: str) -> Any:
         try:
@@ -248,7 +289,15 @@
         tensor_schema = _model_artifact_pb2.TensorSchema()
 
         for shape, dtype in tensor_specs:
-            if not isinstance(dtype, pa.DataType):
+            # Handle Chalk Tensor types
+            if hasattr(dtype, "__mro__") and any("Tensor" in base.__name__ for base in dtype.__mro__):
+                # Extract shape and dtype from Tensor type
+                if hasattr(dtype, "shape") and hasattr(dtype, "dtype"):
+                    shape = dtype.shape
+                    pa_dtype = dtype.dtype
+                else:
+                    raise ValueError(f"Tensor type is missing shape or dtype attributes")
+            elif not isinstance(dtype, pa.DataType):
                 if dtype == str:
                     pa_dtype = pa.string()
                 elif dtype == int:
@@ -272,12 +321,73 @@
 
         return tensor_schema
 
+    @staticmethod
+    def convert_onnx_list_schema_to_dict(schema: Any, model: Any, is_input: bool = True) -> Any:
+        """Convert list-based schema to dict-based schema for ONNX models.
+
+        Args:
+            schema: The schema (list or dict)
+            model: The ONNX model (ModelProto or wrapped)
+            is_input: True for input schema, False for output schema
+
+        Returns:
+            Dict-based schema with field names from ONNX model
+        """
+        if not isinstance(schema, list):
+            return schema
+
+        try:
+            import onnx  # type: ignore[reportMissingImports]
+        except ImportError:
+            raise ValueError("onnx package is required to convert list schemas for ONNX models")
+
+        # Unwrap model if needed
+        onnx_model = model._model if hasattr(model, "_model") else model
+
+        if not isinstance(onnx_model, onnx.ModelProto):
+            raise ValueError(
+                f"ONNX models must be registered with tabular schema (dict format). "
+                + f"Use dict format like {{'input': Tensor[...]}} instead of list format."
+            )
+
+        # Get input/output names from ONNX model
+        if is_input:
+            names = [inp.name for inp in onnx_model.graph.input]
+            schema_type = "input"
+        else:
+            names = [out.name for out in onnx_model.graph.output]
+            schema_type = "output"
+
+        if len(names) != len(schema):
+            raise ValueError(f"ONNX model has {len(names)} {schema_type}s but schema has {len(schema)} entries")
+
+        # Convert to dict format
+        return {name: spec for name, spec in zip(names, schema)}
+
     @staticmethod
     def convert_schema(schema: Any) -> Optional[_model_artifact_pb2.ModelSchema]:
         model_schema = _model_artifact_pb2.ModelSchema()
         if schema is not None:
             if isinstance(schema, dict):
-                model_schema.tabular.CopyFrom(ModelSerializer.build_tabular_schema(schema))
+                # Convert Tensor/Vector types to their PyArrow types for tabular schema
+                converted_schema = {}
+                for col_name, dtype in schema.items():
+                    if hasattr(dtype, "__mro__") and any("Tensor" in base.__name__ for base in dtype.__mro__):
+                        # Use Tensor's to_pyarrow_dtype() method to convert to Arrow type
+                        if hasattr(dtype, "to_pyarrow_dtype"):
+                            converted_schema[col_name] = dtype.to_pyarrow_dtype()
+                        else:
+                            raise ValueError(f"Tensor type for '{col_name}' is missing to_pyarrow_dtype method")
+                    elif hasattr(dtype, "__mro__") and any("Vector" in base.__name__ for base in dtype.__mro__):
+                        # Vector already has a .dtype attribute that's a PyArrow type
+                        if hasattr(dtype, "dtype"):
+                            converted_schema[col_name] = dtype.dtype
+                        else:
+                            raise ValueError(f"Vector type for '{col_name}' is missing dtype attribute")
+                    else:
+                        converted_schema[col_name] = dtype
+
+                model_schema.tabular.CopyFrom(ModelSerializer.build_tabular_schema(converted_schema))
             elif isinstance(schema, list):
                 model_schema.tensor.CopyFrom(ModelSerializer.build_tensor_schema(schema))
             else:
@@ -289,21 +399,27 @@ class ModelSerializer:
 
     @staticmethod
     def convert_run_criterion_to_proto(
-        run_name: Optional[str], criterion: Optional[ModelRunCriterion]
+        run_id: Optional[str] = None, run_name: Optional[str] = None, criterion: Optional[ModelRunCriterion] = None
     ) -> Optional[RunCriterion]:
-        if run_name is None:
-            return None
+        if run_id is None and run_name is None:
+            raise ValueError("Please specify either run_id or run_name.")
 
         if criterion is None:
-            return RunCriterion(run_id=run_name)
+            return RunCriterion(run_id=run_id, run_name=run_name)
 
         if criterion.direction == "max":
             return RunCriterion(
-                run_id=
+                run_id=run_id,
+                run_name=run_name,
+                metric=criterion.metric,
+                direction=RunCriterionDirection.RUN_CRITERION_DIRECTION_MAX,
             )
         elif criterion.direction == "min":
             return RunCriterion(
-                run_id=
+                run_id=run_id,
+                run_name=run_name,
+                metric=criterion.metric,
+                direction=RunCriterionDirection.RUN_CRITERION_DIRECTION_MIN,
            )
        else:
            raise ValueError(
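`convert_onnx_list_schema_to_dict` works because an ONNX graph declares a name for each input and output, so a positional (list) schema can be keyed by those names. The core of the conversion is a length check followed by a zip; a tiny self-contained illustration with hypothetical names and specs standing in for the real `onnx_model.graph.input` entries and user-supplied schema:

```python
# Illustration of the list-to-dict conversion above. The names and specs are
# hypothetical stand-ins: in the real method, `names` comes from
# onnx_model.graph.input (or .output) and `schema` is the user's positional list.
names = ["input_ids", "attention_mask"]     # e.g., [inp.name for inp in graph.input]
schema = ["spec_for_ids", "spec_for_mask"]  # user-supplied positional specs

if len(names) != len(schema):
    raise ValueError(f"ONNX model has {len(names)} inputs but schema has {len(schema)} entries")

tabular = {name: spec for name, spec in zip(names, schema)}
assert tabular == {"input_ids": "spec_for_ids", "attention_mask": "spec_for_mask"}
```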
chalk/config/project_config.py
CHANGED

@@ -184,7 +184,7 @@ def _load_project_config_at_path(filename: Path) -> Optional[ProjectSettings]:
     except OSError:
         return None
     except ValueError as e:
-        raise ValueError(f"Failed to load project config (chalkpy=={__version__}): {e}")
+        raise ValueError(f"Failed to load project config from {filename} (chalkpy=={__version__}): {e}") from e
 
 
 def load_project_config() -> Optional[ProjectSettings]:
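Beyond naming the offending file, the added `from e` chains the original parse error to the wrapper, so tracebacks show both the underlying failure and the config context. A minimal demonstration of the chaining behavior (the messages and file name are made up):

```python
# Minimal demonstration of exception chaining with `raise ... from e`: the
# original error is preserved as __cause__ and printed in the traceback.
try:
    try:
        raise ValueError("unexpected key 'projcet'")  # stand-in for the parse error
    except ValueError as e:
        raise ValueError("Failed to load project config from chalk.yaml") from e
except ValueError as wrapped:
    assert isinstance(wrapped.__cause__, ValueError)
    assert "projcet" in str(wrapped.__cause__)
```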