chalkpy 2.90.1__py3-none-any.whl → 2.95.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chalk/__init__.py +2 -1
- chalk/_gen/chalk/arrow/v1/arrow_pb2.py +7 -5
- chalk/_gen/chalk/arrow/v1/arrow_pb2.pyi +6 -0
- chalk/_gen/chalk/artifacts/v1/chart_pb2.py +16 -16
- chalk/_gen/chalk/artifacts/v1/chart_pb2.pyi +4 -0
- chalk/_gen/chalk/artifacts/v1/cron_query_pb2.py +8 -7
- chalk/_gen/chalk/artifacts/v1/cron_query_pb2.pyi +5 -0
- chalk/_gen/chalk/common/v1/offline_query_pb2.py +17 -15
- chalk/_gen/chalk/common/v1/offline_query_pb2.pyi +25 -0
- chalk/_gen/chalk/common/v1/script_task_pb2.py +3 -3
- chalk/_gen/chalk/common/v1/script_task_pb2.pyi +2 -0
- chalk/_gen/chalk/dataframe/__init__.py +0 -0
- chalk/_gen/chalk/dataframe/v1/__init__.py +0 -0
- chalk/_gen/chalk/dataframe/v1/dataframe_pb2.py +48 -0
- chalk/_gen/chalk/dataframe/v1/dataframe_pb2.pyi +123 -0
- chalk/_gen/chalk/dataframe/v1/dataframe_pb2_grpc.py +4 -0
- chalk/_gen/chalk/dataframe/v1/dataframe_pb2_grpc.pyi +4 -0
- chalk/_gen/chalk/graph/v1/graph_pb2.py +150 -149
- chalk/_gen/chalk/graph/v1/graph_pb2.pyi +25 -0
- chalk/_gen/chalk/graph/v1/sources_pb2.py +94 -84
- chalk/_gen/chalk/graph/v1/sources_pb2.pyi +56 -0
- chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2.py +79 -0
- chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2.pyi +377 -0
- chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2_grpc.py +4 -0
- chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2_grpc.pyi +4 -0
- chalk/_gen/chalk/kubernetes/v1/scaledobject_pb2.py +43 -7
- chalk/_gen/chalk/kubernetes/v1/scaledobject_pb2.pyi +252 -2
- chalk/_gen/chalk/protosql/v1/sql_service_pb2.py +54 -27
- chalk/_gen/chalk/protosql/v1/sql_service_pb2.pyi +131 -3
- chalk/_gen/chalk/protosql/v1/sql_service_pb2_grpc.py +45 -0
- chalk/_gen/chalk/protosql/v1/sql_service_pb2_grpc.pyi +14 -0
- chalk/_gen/chalk/python/v1/types_pb2.py +14 -14
- chalk/_gen/chalk/python/v1/types_pb2.pyi +8 -0
- chalk/_gen/chalk/server/v1/benchmark_pb2.py +76 -0
- chalk/_gen/chalk/server/v1/benchmark_pb2.pyi +156 -0
- chalk/_gen/chalk/server/v1/benchmark_pb2_grpc.py +258 -0
- chalk/_gen/chalk/server/v1/benchmark_pb2_grpc.pyi +84 -0
- chalk/_gen/chalk/server/v1/billing_pb2.py +40 -38
- chalk/_gen/chalk/server/v1/billing_pb2.pyi +17 -1
- chalk/_gen/chalk/server/v1/branches_pb2.py +45 -0
- chalk/_gen/chalk/server/v1/branches_pb2.pyi +80 -0
- chalk/_gen/chalk/server/v1/branches_pb2_grpc.pyi +36 -0
- chalk/_gen/chalk/server/v1/builder_pb2.py +358 -288
- chalk/_gen/chalk/server/v1/builder_pb2.pyi +360 -10
- chalk/_gen/chalk/server/v1/builder_pb2_grpc.py +225 -0
- chalk/_gen/chalk/server/v1/builder_pb2_grpc.pyi +60 -0
- chalk/_gen/chalk/server/v1/chart_pb2.py +10 -10
- chalk/_gen/chalk/server/v1/chart_pb2.pyi +18 -2
- chalk/_gen/chalk/server/v1/clickhouse_pb2.py +42 -0
- chalk/_gen/chalk/server/v1/clickhouse_pb2.pyi +17 -0
- chalk/_gen/chalk/server/v1/clickhouse_pb2_grpc.py +78 -0
- chalk/_gen/chalk/server/v1/clickhouse_pb2_grpc.pyi +38 -0
- chalk/_gen/chalk/server/v1/cloud_components_pb2.py +141 -119
- chalk/_gen/chalk/server/v1/cloud_components_pb2.pyi +106 -4
- chalk/_gen/chalk/server/v1/cloud_components_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/cloud_components_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/cloud_credentials_pb2.py +11 -3
- chalk/_gen/chalk/server/v1/cloud_credentials_pb2.pyi +20 -0
- chalk/_gen/chalk/server/v1/cloud_credentials_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/cloud_credentials_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2.py +52 -38
- chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2.pyi +62 -1
- chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2_grpc.py +90 -0
- chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2_grpc.pyi +24 -0
- chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2.py +90 -0
- chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2.pyi +264 -0
- chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2_grpc.py +170 -0
- chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2_grpc.pyi +62 -0
- chalk/_gen/chalk/server/v1/deploy_pb2.py +9 -3
- chalk/_gen/chalk/server/v1/deploy_pb2.pyi +12 -0
- chalk/_gen/chalk/server/v1/deploy_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/deploy_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/deployment_pb2.py +6 -6
- chalk/_gen/chalk/server/v1/deployment_pb2.pyi +20 -0
- chalk/_gen/chalk/server/v1/environment_pb2.py +14 -12
- chalk/_gen/chalk/server/v1/environment_pb2.pyi +19 -0
- chalk/_gen/chalk/server/v1/eventbus_pb2.py +4 -2
- chalk/_gen/chalk/server/v1/files_pb2.py +65 -0
- chalk/_gen/chalk/server/v1/files_pb2.pyi +167 -0
- chalk/_gen/chalk/server/v1/files_pb2_grpc.py +4 -0
- chalk/_gen/chalk/server/v1/files_pb2_grpc.pyi +4 -0
- chalk/_gen/chalk/server/v1/graph_pb2.py +38 -26
- chalk/_gen/chalk/server/v1/graph_pb2.pyi +58 -0
- chalk/_gen/chalk/server/v1/graph_pb2_grpc.py +47 -0
- chalk/_gen/chalk/server/v1/graph_pb2_grpc.pyi +18 -0
- chalk/_gen/chalk/server/v1/incident_pb2.py +23 -21
- chalk/_gen/chalk/server/v1/incident_pb2.pyi +15 -1
- chalk/_gen/chalk/server/v1/indexing_job_pb2.py +44 -0
- chalk/_gen/chalk/server/v1/indexing_job_pb2.pyi +38 -0
- chalk/_gen/chalk/server/v1/indexing_job_pb2_grpc.py +78 -0
- chalk/_gen/chalk/server/v1/indexing_job_pb2_grpc.pyi +38 -0
- chalk/_gen/chalk/server/v1/integrations_pb2.py +11 -9
- chalk/_gen/chalk/server/v1/integrations_pb2.pyi +34 -2
- chalk/_gen/chalk/server/v1/kube_pb2.py +29 -19
- chalk/_gen/chalk/server/v1/kube_pb2.pyi +28 -0
- chalk/_gen/chalk/server/v1/kube_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/kube_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/log_pb2.py +21 -3
- chalk/_gen/chalk/server/v1/log_pb2.pyi +68 -0
- chalk/_gen/chalk/server/v1/log_pb2_grpc.py +90 -0
- chalk/_gen/chalk/server/v1/log_pb2_grpc.pyi +24 -0
- chalk/_gen/chalk/server/v1/model_registry_pb2.py +10 -10
- chalk/_gen/chalk/server/v1/model_registry_pb2.pyi +4 -1
- chalk/_gen/chalk/server/v1/plandebug_pb2.py +53 -0
- chalk/_gen/chalk/server/v1/plandebug_pb2.pyi +86 -0
- chalk/_gen/chalk/server/v1/plandebug_pb2_grpc.py +168 -0
- chalk/_gen/chalk/server/v1/plandebug_pb2_grpc.pyi +60 -0
- chalk/_gen/chalk/server/v1/queries_pb2.py +66 -66
- chalk/_gen/chalk/server/v1/queries_pb2.pyi +32 -2
- chalk/_gen/chalk/server/v1/scheduled_query_run_pb2.py +12 -12
- chalk/_gen/chalk/server/v1/scheduled_query_run_pb2.pyi +16 -3
- chalk/_gen/chalk/server/v1/scheduler_pb2.py +24 -12
- chalk/_gen/chalk/server/v1/scheduler_pb2.pyi +61 -1
- chalk/_gen/chalk/server/v1/scheduler_pb2_grpc.py +90 -0
- chalk/_gen/chalk/server/v1/scheduler_pb2_grpc.pyi +24 -0
- chalk/_gen/chalk/server/v1/script_tasks_pb2.py +15 -3
- chalk/_gen/chalk/server/v1/script_tasks_pb2.pyi +22 -0
- chalk/_gen/chalk/server/v1/script_tasks_pb2_grpc.py +90 -0
- chalk/_gen/chalk/server/v1/script_tasks_pb2_grpc.pyi +24 -0
- chalk/_gen/chalk/server/v1/sql_interface_pb2.py +75 -0
- chalk/_gen/chalk/server/v1/sql_interface_pb2.pyi +142 -0
- chalk/_gen/chalk/server/v1/sql_interface_pb2_grpc.py +349 -0
- chalk/_gen/chalk/server/v1/sql_interface_pb2_grpc.pyi +114 -0
- chalk/_gen/chalk/server/v1/sql_queries_pb2.py +48 -0
- chalk/_gen/chalk/server/v1/sql_queries_pb2.pyi +150 -0
- chalk/_gen/chalk/server/v1/sql_queries_pb2_grpc.py +123 -0
- chalk/_gen/chalk/server/v1/sql_queries_pb2_grpc.pyi +52 -0
- chalk/_gen/chalk/server/v1/team_pb2.py +154 -141
- chalk/_gen/chalk/server/v1/team_pb2.pyi +30 -2
- chalk/_gen/chalk/server/v1/team_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/team_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/topic_pb2.py +5 -3
- chalk/_gen/chalk/server/v1/topic_pb2.pyi +10 -1
- chalk/_gen/chalk/server/v1/trace_pb2.py +44 -40
- chalk/_gen/chalk/server/v1/trace_pb2.pyi +20 -0
- chalk/_gen/chalk/streaming/v1/debug_service_pb2.py +62 -0
- chalk/_gen/chalk/streaming/v1/debug_service_pb2.pyi +75 -0
- chalk/_gen/chalk/streaming/v1/debug_service_pb2_grpc.py +221 -0
- chalk/_gen/chalk/streaming/v1/debug_service_pb2_grpc.pyi +88 -0
- chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2.py +16 -10
- chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2.pyi +52 -1
- chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2_grpc.py +48 -0
- chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2_grpc.pyi +20 -0
- chalk/_gen/chalk/utils/v1/field_change_pb2.py +32 -0
- chalk/_gen/chalk/utils/v1/field_change_pb2.pyi +42 -0
- chalk/_gen/chalk/utils/v1/field_change_pb2_grpc.py +4 -0
- chalk/_gen/chalk/utils/v1/field_change_pb2_grpc.pyi +4 -0
- chalk/_lsp/error_builder.py +11 -0
- chalk/_version.py +1 -1
- chalk/client/client.py +128 -43
- chalk/client/client_async.py +149 -0
- chalk/client/client_async_impl.py +22 -0
- chalk/client/client_grpc.py +539 -104
- chalk/client/client_impl.py +449 -122
- chalk/client/dataset.py +7 -1
- chalk/client/models.py +98 -0
- chalk/client/serialization/model_serialization.py +92 -9
- chalk/df/LazyFramePlaceholder.py +1154 -0
- chalk/features/_class_property.py +7 -0
- chalk/features/_embedding/embedding.py +1 -0
- chalk/features/_encoding/converter.py +83 -2
- chalk/features/feature_field.py +40 -30
- chalk/features/feature_set_decorator.py +1 -0
- chalk/features/feature_wrapper.py +42 -3
- chalk/features/hooks.py +81 -10
- chalk/features/inference.py +33 -31
- chalk/features/resolver.py +224 -24
- chalk/functions/__init__.py +65 -3
- chalk/gitignore/gitignore_parser.py +5 -1
- chalk/importer.py +142 -68
- chalk/ml/__init__.py +2 -0
- chalk/ml/model_hooks.py +194 -26
- chalk/ml/model_reference.py +56 -8
- chalk/ml/model_version.py +24 -15
- chalk/ml/utils.py +20 -17
- chalk/operators/_utils.py +10 -3
- chalk/parsed/_proto/export.py +22 -0
- chalk/parsed/duplicate_input_gql.py +3 -0
- chalk/parsed/json_conversions.py +20 -14
- chalk/parsed/to_proto.py +16 -4
- chalk/parsed/user_types_to_json.py +31 -10
- chalk/parsed/validation_from_registries.py +182 -0
- chalk/queries/named_query.py +16 -6
- chalk/queries/scheduled_query.py +9 -1
- chalk/serialization/parsed_annotation.py +24 -11
- chalk/sql/__init__.py +18 -0
- chalk/sql/_internal/integrations/databricks.py +55 -17
- chalk/sql/_internal/integrations/mssql.py +127 -62
- chalk/sql/_internal/integrations/redshift.py +4 -0
- chalk/sql/_internal/sql_file_resolver.py +53 -9
- chalk/sql/_internal/sql_source.py +35 -2
- chalk/streams/_kafka_source.py +5 -1
- chalk/streams/_windows.py +15 -2
- chalk/utils/_otel_version.py +13 -0
- chalk/utils/async_helpers.py +2 -2
- chalk/utils/missing_dependency.py +5 -4
- chalk/utils/tracing.py +185 -95
- {chalkpy-2.90.1.dist-info → chalkpy-2.95.3.dist-info}/METADATA +4 -6
- {chalkpy-2.90.1.dist-info → chalkpy-2.95.3.dist-info}/RECORD +202 -146
- {chalkpy-2.90.1.dist-info → chalkpy-2.95.3.dist-info}/WHEEL +0 -0
- {chalkpy-2.90.1.dist-info → chalkpy-2.95.3.dist-info}/entry_points.txt +0 -0
- {chalkpy-2.90.1.dist-info → chalkpy-2.95.3.dist-info}/top_level.txt +0 -0
chalk/importer.py
CHANGED
|
@@ -304,6 +304,19 @@ def _parse_agg_function_call(expr: Underscore | None) -> Tuple[str, Underscore,
|
|
|
304
304
|
f"expecting 'int' type argument for 'k', but received arg of type '{type(call_expr._chalk__kwargs.get('k'))}'"
|
|
305
305
|
)
|
|
306
306
|
opts = FrozenOrderedSet(call_expr._chalk__kwargs.items())
|
|
307
|
+
elif aggregation == "approx_percentile":
|
|
308
|
+
if len(call_expr._chalk__args) > 0:
|
|
309
|
+
raise ChalkParseError("should not have any positional arguments")
|
|
310
|
+
elif {"quantile"} != call_expr._chalk__kwargs.keys():
|
|
311
|
+
raise ChalkParseError("expecting exactly one required keyword argument 'quantile'")
|
|
312
|
+
elif not isinstance(call_expr._chalk__kwargs.get("quantile"), float):
|
|
313
|
+
raise ChalkParseError(
|
|
314
|
+
f"expecting 'float' type argument for 'quantile', but received arg of type '{type(call_expr._chalk__kwargs.get('quantile'))}'"
|
|
315
|
+
)
|
|
316
|
+
# TODO: expand proto definition to accept kwargs that are not necessarily `k`
|
|
317
|
+
quantile = call_expr._chalk__kwargs["quantile"]
|
|
318
|
+
nano_quantile = int(round(quantile * 1_000_000_000))
|
|
319
|
+
opts = FrozenOrderedSet([("k", nano_quantile)])
|
|
307
320
|
elif aggregation in ("min_by_n", "max_by_n"):
|
|
308
321
|
if len(call_expr._chalk__kwargs) > 0:
|
|
309
322
|
raise ChalkParseError("should not have any keyword arguments")
|
|
@@ -433,8 +446,6 @@ def run_post_import_fixups():
|
|
|
433
446
|
# "1m", "2m", materialization={...},
|
|
434
447
|
# expression=_.transactions[_.amount].sum(),
|
|
435
448
|
# )
|
|
436
|
-
assert f.underscore_expression is not None
|
|
437
|
-
assert f.window_materialization is not None
|
|
438
449
|
|
|
439
450
|
try:
|
|
440
451
|
f.window_materialization_parsed = parse_windowed_materialization(f=f)
|
|
@@ -572,39 +583,51 @@ def parse_grouped_window(f: Feature) -> WindowConfigResolved:
|
|
|
572
583
|
aggregation_kwargs=aggregation_kwargs,
|
|
573
584
|
pyarrow_dtype=pyarrow_dtype,
|
|
574
585
|
filters=parsed_filters,
|
|
575
|
-
backfill_resolver=
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
)
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
586
|
+
backfill_resolver=(
|
|
587
|
+
_try_parse_resolver_fqn(
|
|
588
|
+
"backfill_resolver",
|
|
589
|
+
f.window_materialization.get("backfill_resolver", None),
|
|
590
|
+
)
|
|
591
|
+
if isinstance(f.window_materialization, dict)
|
|
592
|
+
else None
|
|
593
|
+
),
|
|
594
|
+
backfill_schedule=(
|
|
595
|
+
f.window_materialization.get("backfill_schedule", None)
|
|
596
|
+
if isinstance(f.window_materialization, dict)
|
|
597
|
+
else None
|
|
598
|
+
),
|
|
599
|
+
backfill_lookback_duration_seconds=(
|
|
600
|
+
_try_parse_duration(
|
|
601
|
+
"backfill_lookback_duration",
|
|
602
|
+
f.window_materialization.get("backfill_lookback_duration", None),
|
|
603
|
+
)
|
|
604
|
+
if isinstance(f.window_materialization, dict)
|
|
605
|
+
else None
|
|
606
|
+
),
|
|
607
|
+
backfill_start_time=(
|
|
608
|
+
_try_parse_datetime(
|
|
609
|
+
"backfill_start_time",
|
|
610
|
+
f.window_materialization.get("backfill_start_time", None),
|
|
611
|
+
)
|
|
612
|
+
if isinstance(f.window_materialization, dict)
|
|
613
|
+
else None
|
|
614
|
+
),
|
|
615
|
+
continuous_resolver=(
|
|
616
|
+
_try_parse_resolver_fqn(
|
|
617
|
+
"continuous_resolver",
|
|
618
|
+
f.window_materialization.get("continuous_resolver", None),
|
|
619
|
+
)
|
|
620
|
+
if isinstance(f.window_materialization, dict)
|
|
621
|
+
else None
|
|
622
|
+
),
|
|
623
|
+
continuous_buffer_duration_seconds=(
|
|
624
|
+
_try_parse_duration(
|
|
625
|
+
"continuous_buffer_duration",
|
|
626
|
+
f.window_materialization.get("continuous_buffer_duration", None),
|
|
627
|
+
)
|
|
628
|
+
if isinstance(f.window_materialization, dict)
|
|
629
|
+
else None
|
|
630
|
+
),
|
|
608
631
|
)
|
|
609
632
|
|
|
610
633
|
return cfg
|
|
@@ -800,39 +823,51 @@ def parse_windowed_materialization(f: Feature) -> WindowConfigResolved | None:
|
|
|
800
823
|
aggregation_kwargs=aggregation_kwargs,
|
|
801
824
|
pyarrow_dtype=f.converter.pyarrow_dtype,
|
|
802
825
|
filters=parsed_filters,
|
|
803
|
-
backfill_resolver=
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
)
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
826
|
+
backfill_resolver=(
|
|
827
|
+
_try_parse_resolver_fqn(
|
|
828
|
+
"backfill_resolver",
|
|
829
|
+
f.window_materialization.get("backfill_resolver", None),
|
|
830
|
+
)
|
|
831
|
+
if isinstance(f.window_materialization, dict)
|
|
832
|
+
else None
|
|
833
|
+
),
|
|
834
|
+
backfill_schedule=(
|
|
835
|
+
f.window_materialization.get("backfill_schedule", None)
|
|
836
|
+
if isinstance(f.window_materialization, dict)
|
|
837
|
+
else None
|
|
838
|
+
),
|
|
839
|
+
backfill_lookback_duration_seconds=(
|
|
840
|
+
_try_parse_duration(
|
|
841
|
+
"backfill_lookback_duration",
|
|
842
|
+
f.window_materialization.get("backfill_lookback_duration", None),
|
|
843
|
+
)
|
|
844
|
+
if isinstance(f.window_materialization, dict)
|
|
845
|
+
else None
|
|
846
|
+
),
|
|
847
|
+
backfill_start_time=(
|
|
848
|
+
_try_parse_datetime(
|
|
849
|
+
"backfill_start_time",
|
|
850
|
+
f.window_materialization.get("backfill_start_time", None),
|
|
851
|
+
)
|
|
852
|
+
if isinstance(f.window_materialization, dict)
|
|
853
|
+
else None
|
|
854
|
+
),
|
|
855
|
+
continuous_resolver=(
|
|
856
|
+
_try_parse_resolver_fqn(
|
|
857
|
+
"continuous_resolver",
|
|
858
|
+
f.window_materialization.get("continuous_resolver", None),
|
|
859
|
+
)
|
|
860
|
+
if isinstance(f.window_materialization, dict)
|
|
861
|
+
else None
|
|
862
|
+
),
|
|
863
|
+
continuous_buffer_duration_seconds=(
|
|
864
|
+
_try_parse_duration(
|
|
865
|
+
"continuous_buffer_duration",
|
|
866
|
+
f.window_materialization.get("continuous_buffer_duration", None),
|
|
867
|
+
)
|
|
868
|
+
if isinstance(f.window_materialization, dict)
|
|
869
|
+
else None
|
|
870
|
+
),
|
|
836
871
|
)
|
|
837
872
|
|
|
838
873
|
|
|
@@ -1010,6 +1045,33 @@ class _UnderscoreValidationError(ValueError):
|
|
|
1010
1045
|
...
|
|
1011
1046
|
|
|
1012
1047
|
|
|
1048
|
+
def _has_group_by_in_parent_chain(underscore: Underscore) -> bool:
|
|
1049
|
+
"""
|
|
1050
|
+
Traverse parent chain to check if .group_by() exists before .agg().
|
|
1051
|
+
|
|
1052
|
+
For valid group_by_windowed: _.x.group_by(_.y).agg(_.z.sum())
|
|
1053
|
+
- Looks for: UnderscoreCall -> UnderscoreAttr("group_by")
|
|
1054
|
+
|
|
1055
|
+
Returns True if .group_by() found, False otherwise.
|
|
1056
|
+
"""
|
|
1057
|
+
current: Optional[Any] = underscore
|
|
1058
|
+
|
|
1059
|
+
while current is not None:
|
|
1060
|
+
# Check if current is a .group_by() call
|
|
1061
|
+
if isinstance(current, UnderscoreCall):
|
|
1062
|
+
parent = current._chalk__parent
|
|
1063
|
+
if isinstance(parent, UnderscoreAttr) and parent._chalk__attr == "group_by":
|
|
1064
|
+
return True
|
|
1065
|
+
|
|
1066
|
+
# Move to parent
|
|
1067
|
+
if hasattr(current, "_chalk__parent"):
|
|
1068
|
+
current = current._chalk__parent
|
|
1069
|
+
else:
|
|
1070
|
+
break
|
|
1071
|
+
|
|
1072
|
+
return False
|
|
1073
|
+
|
|
1074
|
+
|
|
1013
1075
|
class ChalkImporter:
|
|
1014
1076
|
def __init__(self):
|
|
1015
1077
|
super().__init__()
|
|
@@ -1111,6 +1173,9 @@ class ChalkImporter:
|
|
|
1111
1173
|
for feature_class in FeatureSetBase.registry.values():
|
|
1112
1174
|
# Iterate through every class, to find every underscore definition.
|
|
1113
1175
|
for f in feature_class.features:
|
|
1176
|
+
if f.is_windowed_pseudofeature is True:
|
|
1177
|
+
# need one LSP just for the base
|
|
1178
|
+
continue
|
|
1114
1179
|
if f.underscore_expression is not None:
|
|
1115
1180
|
# Validate that the underscore expression is well-formed.
|
|
1116
1181
|
# If it is not well-formed, then an `_UnderscoreValidationError` will
|
|
@@ -1489,6 +1554,15 @@ def _supplemental_validate_underscore_expression(
|
|
|
1489
1554
|
)
|
|
1490
1555
|
return None
|
|
1491
1556
|
|
|
1557
|
+
# Validate .agg() usage (addressing TODO at line 1522)
|
|
1558
|
+
if op_name == "agg":
|
|
1559
|
+
if not _has_group_by_in_parent_chain(caller):
|
|
1560
|
+
raise _UnderscoreValidationError(
|
|
1561
|
+
"'.agg()' can only be used with '.group_by()' for group_by_windowed features. "
|
|
1562
|
+
+ "For windowed features, use direct aggregation methods instead. "
|
|
1563
|
+
+ "For example, instead of using '.agg(_.field.method())', use '.field.method()' directly on the filtered DataFrame"
|
|
1564
|
+
)
|
|
1565
|
+
|
|
1492
1566
|
return None
|
|
1493
1567
|
|
|
1494
1568
|
# TODO: check that op_name is a supported agg or .agg/.group_by/etc
|
chalk/ml/__init__.py
CHANGED
|
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from chalk.ml.model_file_transfer import FileInfo, HFSourceConfig, LocalSourceConfig, S3SourceConfig, SourceConfig
|
|
4
4
|
from chalk.ml.model_reference import ModelReference
|
|
5
|
+
from chalk.ml.model_version import ModelVersion
|
|
5
6
|
from chalk.ml.utils import ModelClass, ModelEncoding, ModelRunCriterion, ModelType
|
|
6
7
|
|
|
7
8
|
__all__ = (
|
|
@@ -9,6 +10,7 @@ __all__ = (
|
|
|
9
10
|
"ModelClass",
|
|
10
11
|
"ModelEncoding",
|
|
11
12
|
"ModelReference",
|
|
13
|
+
"ModelVersion",
|
|
12
14
|
"SourceConfig",
|
|
13
15
|
"LocalSourceConfig",
|
|
14
16
|
"S3SourceConfig",
|
chalk/ml/model_hooks.py
CHANGED
|
@@ -1,12 +1,15 @@
|
|
|
1
|
-
from typing import Any, Dict, Optional, Protocol, Tuple
|
|
1
|
+
from typing import TYPE_CHECKING, Any, Dict, Optional, Protocol, Tuple
|
|
2
2
|
|
|
3
3
|
from chalk.ml.utils import ModelClass, ModelEncoding, ModelType
|
|
4
4
|
|
|
5
|
+
if TYPE_CHECKING:
|
|
6
|
+
from chalk.features.resolver import ResourceHint
|
|
7
|
+
|
|
5
8
|
|
|
6
9
|
class ModelInference(Protocol):
|
|
7
10
|
"""Abstract base class for model loading and inference."""
|
|
8
11
|
|
|
9
|
-
def load_model(self, path: str) -> Any:
|
|
12
|
+
def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
|
|
10
13
|
"""Load a model from the given path."""
|
|
11
14
|
pass
|
|
12
15
|
|
|
@@ -14,11 +17,27 @@ class ModelInference(Protocol):
|
|
|
14
17
|
"""Run inference on the model with input X."""
|
|
15
18
|
pass
|
|
16
19
|
|
|
20
|
+
def prepare_input(self, feature_table: Any) -> Any:
|
|
21
|
+
"""Convert PyArrow table to model input format.
|
|
22
|
+
|
|
23
|
+
Default implementation converts to numpy array via __array__().
|
|
24
|
+
Override for model-specific input formats (e.g., ONNX struct arrays).
|
|
25
|
+
"""
|
|
26
|
+
return feature_table.__array__()
|
|
27
|
+
|
|
28
|
+
def extract_output(self, result: Any, output_feature_name: str) -> Any:
|
|
29
|
+
"""Extract single output from model result.
|
|
30
|
+
|
|
31
|
+
Default implementation returns result as-is (for single outputs).
|
|
32
|
+
Override for models with structured outputs (e.g., ONNX struct arrays).
|
|
33
|
+
"""
|
|
34
|
+
return result
|
|
35
|
+
|
|
17
36
|
|
|
18
37
|
class XGBoostClassifierInference(ModelInference):
|
|
19
38
|
"""Model inference for XGBoost classifiers."""
|
|
20
39
|
|
|
21
|
-
def load_model(self, path: str) -> Any:
|
|
40
|
+
def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
|
|
22
41
|
import xgboost # pyright: ignore[reportMissingImports]
|
|
23
42
|
|
|
24
43
|
model = xgboost.XGBClassifier()
|
|
@@ -32,7 +51,7 @@ class XGBoostClassifierInference(ModelInference):
|
|
|
32
51
|
class XGBoostRegressorInference(ModelInference):
|
|
33
52
|
"""Model inference for XGBoost regressors."""
|
|
34
53
|
|
|
35
|
-
def load_model(self, path: str) -> Any:
|
|
54
|
+
def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
|
|
36
55
|
import xgboost # pyright: ignore[reportMissingImports]
|
|
37
56
|
|
|
38
57
|
model = xgboost.XGBRegressor()
|
|
@@ -46,17 +65,27 @@ class XGBoostRegressorInference(ModelInference):
|
|
|
46
65
|
class PyTorchInference(ModelInference):
|
|
47
66
|
"""Model inference for PyTorch models."""
|
|
48
67
|
|
|
49
|
-
def load_model(self, path: str) -> Any:
|
|
68
|
+
def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
|
|
50
69
|
import torch # pyright: ignore[reportMissingImports]
|
|
51
70
|
|
|
52
71
|
torch.set_grad_enabled(False)
|
|
72
|
+
|
|
73
|
+
# Load the model
|
|
53
74
|
model = torch.jit.load(path)
|
|
54
|
-
|
|
75
|
+
|
|
76
|
+
# If resource_hint is "gpu", move model to GPU
|
|
77
|
+
if resource_hint == "gpu" and torch.cuda.is_available():
|
|
78
|
+
device = torch.device("cuda")
|
|
79
|
+
model = model.to(device)
|
|
80
|
+
model.input_to_tensor = lambda X: torch.from_numpy(X).float().to(device)
|
|
81
|
+
else:
|
|
82
|
+
model.input_to_tensor = lambda X: torch.from_numpy(X).float()
|
|
83
|
+
|
|
55
84
|
return model
|
|
56
85
|
|
|
57
86
|
def predict(self, model: Any, X: Any) -> Any:
|
|
58
87
|
outputs = model(model.input_to_tensor(X))
|
|
59
|
-
result = outputs.detach().numpy().astype("float64")
|
|
88
|
+
result = outputs.detach().cpu().numpy().astype("float64")
|
|
60
89
|
result = result.squeeze()
|
|
61
90
|
|
|
62
91
|
# Convert 0-dimensional array to scalar, or ensure we have a proper 1D array
|
|
@@ -69,7 +98,7 @@ class PyTorchInference(ModelInference):
|
|
|
69
98
|
class SklearnInference(ModelInference):
|
|
70
99
|
"""Model inference for scikit-learn models."""
|
|
71
100
|
|
|
72
|
-
def load_model(self, path: str) -> Any:
|
|
101
|
+
def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
|
|
73
102
|
import joblib # pyright: ignore[reportMissingImports]
|
|
74
103
|
|
|
75
104
|
return joblib.load(path)
|
|
@@ -81,7 +110,7 @@ class SklearnInference(ModelInference):
|
|
|
81
110
|
class TensorFlowInference(ModelInference):
|
|
82
111
|
"""Model inference for TensorFlow models."""
|
|
83
112
|
|
|
84
|
-
def load_model(self, path: str) -> Any:
|
|
113
|
+
def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
|
|
85
114
|
import tensorflow # pyright: ignore[reportMissingImports]
|
|
86
115
|
|
|
87
116
|
return tensorflow.keras.models.load_model(path)
|
|
@@ -93,7 +122,7 @@ class TensorFlowInference(ModelInference):
|
|
|
93
122
|
class LightGBMInference(ModelInference):
|
|
94
123
|
"""Model inference for LightGBM models."""
|
|
95
124
|
|
|
96
|
-
def load_model(self, path: str) -> Any:
|
|
125
|
+
def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
|
|
97
126
|
import lightgbm # pyright: ignore[reportMissingImports]
|
|
98
127
|
|
|
99
128
|
return lightgbm.Booster(model_file=path)
|
|
@@ -105,7 +134,7 @@ class LightGBMInference(ModelInference):
|
|
|
105
134
|
class CatBoostInference(ModelInference):
|
|
106
135
|
"""Model inference for CatBoost models."""
|
|
107
136
|
|
|
108
|
-
def load_model(self, path: str) -> Any:
|
|
137
|
+
def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
|
|
109
138
|
import catboost # pyright: ignore[reportMissingImports]
|
|
110
139
|
|
|
111
140
|
return catboost.CatBoost().load_model(path)
|
|
@@ -115,31 +144,170 @@ class CatBoostInference(ModelInference):
|
|
|
115
144
|
|
|
116
145
|
|
|
117
146
|
class ONNXInference(ModelInference):
|
|
118
|
-
"""Model inference for ONNX models."""
|
|
147
|
+
"""Model inference for ONNX models with struct input/output support."""
|
|
119
148
|
|
|
120
|
-
def load_model(self, path: str) -> Any:
|
|
149
|
+
def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
|
|
121
150
|
import onnxruntime # pyright: ignore[reportMissingImports]
|
|
122
151
|
|
|
123
|
-
|
|
152
|
+
# Conditionally add CUDAExecutionProvider based on resource_hint
|
|
153
|
+
providers = (
|
|
154
|
+
["CUDAExecutionProvider", "CPUExecutionProvider"] if resource_hint == "gpu" else ["CPUExecutionProvider"]
|
|
155
|
+
)
|
|
156
|
+
return onnxruntime.InferenceSession(path, providers=providers)
|
|
157
|
+
|
|
158
|
+
def prepare_input(self, feature_table: Any) -> Any:
|
|
159
|
+
"""Convert PyArrow table to struct array for ONNX models."""
|
|
160
|
+
import pyarrow as pa
|
|
161
|
+
|
|
162
|
+
# Get arrays for each column, combining chunks if necessary
|
|
163
|
+
arrays = []
|
|
164
|
+
for i in range(feature_table.num_columns):
|
|
165
|
+
col = feature_table.column(i)
|
|
166
|
+
if isinstance(col, pa.ChunkedArray):
|
|
167
|
+
arrays.append(col.combine_chunks())
|
|
168
|
+
else:
|
|
169
|
+
arrays.append(col)
|
|
170
|
+
|
|
171
|
+
# Create fields from schema, preserving original field names
|
|
172
|
+
# Field names should match ONNX input names exactly
|
|
173
|
+
fields = []
|
|
174
|
+
for field in feature_table.schema:
|
|
175
|
+
fields.append(pa.field(field.name, field.type))
|
|
176
|
+
|
|
177
|
+
# Create struct array where each row is a struct with named fields
|
|
178
|
+
return pa.StructArray.from_arrays(arrays, fields=fields)
|
|
179
|
+
|
|
180
|
+
def extract_output(self, result: Any, output_feature_name: str) -> Any:
|
|
181
|
+
"""Extract single field from ONNX struct output."""
|
|
182
|
+
import pyarrow as pa
|
|
183
|
+
|
|
184
|
+
if not isinstance(result, (pa.StructArray, pa.ChunkedArray)):
|
|
185
|
+
return result
|
|
186
|
+
|
|
187
|
+
struct_type = result.type if isinstance(result, pa.StructArray) else result.chunk(0).type
|
|
188
|
+
|
|
189
|
+
# Find matching field by name, or use first field
|
|
190
|
+
field_index = None
|
|
191
|
+
for i, field in enumerate(struct_type):
|
|
192
|
+
if field.name == output_feature_name:
|
|
193
|
+
field_index = i
|
|
194
|
+
break
|
|
195
|
+
|
|
196
|
+
return result.field(field_index if field_index is not None else 0)
|
|
124
197
|
|
|
125
198
|
def predict(self, model: Any, X: Any) -> Any:
|
|
199
|
+
"""Run ONNX inference with struct input/output."""
|
|
200
|
+
# Get ONNX model input/output names
|
|
201
|
+
input_names = [inp.name for inp in model.get_inputs()]
|
|
202
|
+
output_names = [out.name for out in model.get_outputs()]
|
|
203
|
+
|
|
204
|
+
# Convert struct input to ONNX input dict
|
|
205
|
+
input_dict = self._struct_to_inputs(X, input_names)
|
|
206
|
+
|
|
207
|
+
# Run ONNX inference
|
|
208
|
+
outputs = model.run(output_names, input_dict)
|
|
209
|
+
|
|
210
|
+
# Always return outputs as struct array
|
|
211
|
+
return self._outputs_to_struct(output_names, outputs)
|
|
212
|
+
|
|
213
|
+
def _struct_to_inputs(self, struct_array: Any, input_names: list) -> dict:
|
|
214
|
+
"""Extract ONNX inputs from struct array by matching field names.
|
|
215
|
+
|
|
216
|
+
Struct field names must match ONNX input names (supports list/Tensor types).
|
|
217
|
+
If ONNX expects a single input but struct has multiple scalar fields,
|
|
218
|
+
stack them into a 2D array.
|
|
219
|
+
"""
|
|
126
220
|
import numpy as np
|
|
221
|
+
import pyarrow as pa
|
|
222
|
+
|
|
223
|
+
if isinstance(struct_array, pa.ChunkedArray):
|
|
224
|
+
struct_array = struct_array.combine_chunks()
|
|
225
|
+
|
|
226
|
+
input_dict = {}
|
|
227
|
+
struct_fields = {field.name: i for i, field in enumerate(struct_array.type)}
|
|
228
|
+
|
|
229
|
+
# Check if struct field names match ONNX input names
|
|
230
|
+
fields_match = all(input_name in struct_fields for input_name in input_names)
|
|
231
|
+
|
|
232
|
+
if not fields_match:
|
|
233
|
+
# Special case 1: ONNX expects single input and struct has single field
|
|
234
|
+
# Use that field regardless of name mismatch
|
|
235
|
+
if len(input_names) == 1 and len(struct_fields) == 1:
|
|
236
|
+
field_data = struct_array.field(0)
|
|
237
|
+
input_dict[input_names[0]] = self._arrow_to_numpy(field_data)
|
|
238
|
+
return input_dict
|
|
239
|
+
|
|
240
|
+
# Special case 2: ONNX expects single input, but struct has multiple scalar fields
|
|
241
|
+
# Stack them into a 2D array [batch_size, num_fields]
|
|
242
|
+
if len(input_names) == 1 and len(struct_fields) > 1:
|
|
243
|
+
# Check if all fields are scalar (not nested lists)
|
|
244
|
+
all_scalar = all(
|
|
245
|
+
not pa.types.is_list(struct_array.type[i].type)
|
|
246
|
+
and not pa.types.is_large_list(struct_array.type[i].type)
|
|
247
|
+
for i in range(len(struct_array.type))
|
|
248
|
+
)
|
|
249
|
+
|
|
250
|
+
if all_scalar:
|
|
251
|
+
# Stack all fields into a single 2D array
|
|
252
|
+
columns = []
|
|
253
|
+
for i in range(len(struct_array.type)):
|
|
254
|
+
field_data = struct_array.field(i)
|
|
255
|
+
col_array = self._arrow_to_numpy(field_data)
|
|
256
|
+
columns.append(col_array)
|
|
257
|
+
|
|
258
|
+
# Stack columns horizontally to create [batch_size, num_features]
|
|
259
|
+
stacked = np.column_stack(columns)
|
|
260
|
+
input_dict[input_names[0]] = stacked
|
|
261
|
+
return input_dict
|
|
262
|
+
|
|
263
|
+
raise ValueError(
|
|
264
|
+
f"ONNX inputs {input_names} not found in struct fields {list(struct_fields.keys())}. "
|
|
265
|
+
+ "Struct field names must match ONNX input names."
|
|
266
|
+
)
|
|
267
|
+
|
|
268
|
+
# Direct mapping: struct fields match ONNX inputs (for Tensor/list types or named inputs)
|
|
269
|
+
for input_name in input_names:
|
|
270
|
+
field_data = struct_array.field(struct_fields[input_name])
|
|
271
|
+
input_dict[input_name] = self._arrow_to_numpy(field_data)
|
|
272
|
+
|
|
273
|
+
return input_dict
|
|
274
|
+
|
|
275
|
+
def _arrow_to_numpy(self, arrow_array: Any) -> Any:
|
|
276
|
+
"""Convert Arrow array (including nested lists) to dense numpy array."""
|
|
277
|
+
import numpy as np
|
|
278
|
+
import pyarrow as pa
|
|
127
279
|
|
|
128
|
-
|
|
129
|
-
|
|
280
|
+
if isinstance(arrow_array, pa.ChunkedArray):
|
|
281
|
+
arrow_array = arrow_array.combine_chunks()
|
|
130
282
|
|
|
131
|
-
# Convert
|
|
132
|
-
|
|
283
|
+
# Convert to Python list, then numpy - handles all cases (nested lists, flat arrays, etc.)
|
|
284
|
+
return np.array(arrow_array.to_pylist(), dtype=np.float32)
|
|
133
285
|
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
286
|
+
def _outputs_to_struct(self, output_names: list, outputs: list) -> Any:
|
|
287
|
+
"""Convert ONNX outputs to PyArrow struct array."""
|
|
288
|
+
import pyarrow as pa
|
|
289
|
+
|
|
290
|
+
if not outputs:
|
|
291
|
+
raise ValueError("ONNX model returned no outputs")
|
|
292
|
+
|
|
293
|
+
# Convert each output to Arrow array with proper type
|
|
294
|
+
fields = []
|
|
295
|
+
arrays = []
|
|
296
|
+
|
|
297
|
+
for name, output_array in zip(output_names, outputs):
|
|
298
|
+
arrow_array = self._numpy_to_arrow_array(output_array)
|
|
299
|
+
fields.append(pa.field(name, arrow_array.type))
|
|
300
|
+
arrays.append(arrow_array)
|
|
301
|
+
|
|
302
|
+
return pa.StructArray.from_arrays(arrays, fields=fields)
|
|
303
|
+
|
|
304
|
+
def _numpy_to_arrow_array(self, arr: Any) -> Any:
|
|
305
|
+
"""Convert numpy array to PyArrow array (possibly nested list)."""
|
|
306
|
+
import pyarrow as pa
|
|
141
307
|
|
|
142
|
-
|
|
308
|
+
# PyArrow can infer the correct nested list type from Python lists
|
|
309
|
+
# Shape (batch, dim1, dim2, ...) -> list[list[...]]
|
|
310
|
+
return pa.array(arr.tolist())
|
|
143
311
|
|
|
144
312
|
|
|
145
313
|
class ModelInferenceRegistry:
|