chalkpy 2.90.1__py3-none-any.whl → 2.95.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (202)
  1. chalk/__init__.py +2 -1
  2. chalk/_gen/chalk/arrow/v1/arrow_pb2.py +7 -5
  3. chalk/_gen/chalk/arrow/v1/arrow_pb2.pyi +6 -0
  4. chalk/_gen/chalk/artifacts/v1/chart_pb2.py +16 -16
  5. chalk/_gen/chalk/artifacts/v1/chart_pb2.pyi +4 -0
  6. chalk/_gen/chalk/artifacts/v1/cron_query_pb2.py +8 -7
  7. chalk/_gen/chalk/artifacts/v1/cron_query_pb2.pyi +5 -0
  8. chalk/_gen/chalk/common/v1/offline_query_pb2.py +17 -15
  9. chalk/_gen/chalk/common/v1/offline_query_pb2.pyi +25 -0
  10. chalk/_gen/chalk/common/v1/script_task_pb2.py +3 -3
  11. chalk/_gen/chalk/common/v1/script_task_pb2.pyi +2 -0
  12. chalk/_gen/chalk/dataframe/__init__.py +0 -0
  13. chalk/_gen/chalk/dataframe/v1/__init__.py +0 -0
  14. chalk/_gen/chalk/dataframe/v1/dataframe_pb2.py +48 -0
  15. chalk/_gen/chalk/dataframe/v1/dataframe_pb2.pyi +123 -0
  16. chalk/_gen/chalk/dataframe/v1/dataframe_pb2_grpc.py +4 -0
  17. chalk/_gen/chalk/dataframe/v1/dataframe_pb2_grpc.pyi +4 -0
  18. chalk/_gen/chalk/graph/v1/graph_pb2.py +150 -149
  19. chalk/_gen/chalk/graph/v1/graph_pb2.pyi +25 -0
  20. chalk/_gen/chalk/graph/v1/sources_pb2.py +94 -84
  21. chalk/_gen/chalk/graph/v1/sources_pb2.pyi +56 -0
  22. chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2.py +79 -0
  23. chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2.pyi +377 -0
  24. chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2_grpc.py +4 -0
  25. chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2_grpc.pyi +4 -0
  26. chalk/_gen/chalk/kubernetes/v1/scaledobject_pb2.py +43 -7
  27. chalk/_gen/chalk/kubernetes/v1/scaledobject_pb2.pyi +252 -2
  28. chalk/_gen/chalk/protosql/v1/sql_service_pb2.py +54 -27
  29. chalk/_gen/chalk/protosql/v1/sql_service_pb2.pyi +131 -3
  30. chalk/_gen/chalk/protosql/v1/sql_service_pb2_grpc.py +45 -0
  31. chalk/_gen/chalk/protosql/v1/sql_service_pb2_grpc.pyi +14 -0
  32. chalk/_gen/chalk/python/v1/types_pb2.py +14 -14
  33. chalk/_gen/chalk/python/v1/types_pb2.pyi +8 -0
  34. chalk/_gen/chalk/server/v1/benchmark_pb2.py +76 -0
  35. chalk/_gen/chalk/server/v1/benchmark_pb2.pyi +156 -0
  36. chalk/_gen/chalk/server/v1/benchmark_pb2_grpc.py +258 -0
  37. chalk/_gen/chalk/server/v1/benchmark_pb2_grpc.pyi +84 -0
  38. chalk/_gen/chalk/server/v1/billing_pb2.py +40 -38
  39. chalk/_gen/chalk/server/v1/billing_pb2.pyi +17 -1
  40. chalk/_gen/chalk/server/v1/branches_pb2.py +45 -0
  41. chalk/_gen/chalk/server/v1/branches_pb2.pyi +80 -0
  42. chalk/_gen/chalk/server/v1/branches_pb2_grpc.pyi +36 -0
  43. chalk/_gen/chalk/server/v1/builder_pb2.py +358 -288
  44. chalk/_gen/chalk/server/v1/builder_pb2.pyi +360 -10
  45. chalk/_gen/chalk/server/v1/builder_pb2_grpc.py +225 -0
  46. chalk/_gen/chalk/server/v1/builder_pb2_grpc.pyi +60 -0
  47. chalk/_gen/chalk/server/v1/chart_pb2.py +10 -10
  48. chalk/_gen/chalk/server/v1/chart_pb2.pyi +18 -2
  49. chalk/_gen/chalk/server/v1/clickhouse_pb2.py +42 -0
  50. chalk/_gen/chalk/server/v1/clickhouse_pb2.pyi +17 -0
  51. chalk/_gen/chalk/server/v1/clickhouse_pb2_grpc.py +78 -0
  52. chalk/_gen/chalk/server/v1/clickhouse_pb2_grpc.pyi +38 -0
  53. chalk/_gen/chalk/server/v1/cloud_components_pb2.py +141 -119
  54. chalk/_gen/chalk/server/v1/cloud_components_pb2.pyi +106 -4
  55. chalk/_gen/chalk/server/v1/cloud_components_pb2_grpc.py +45 -0
  56. chalk/_gen/chalk/server/v1/cloud_components_pb2_grpc.pyi +12 -0
  57. chalk/_gen/chalk/server/v1/cloud_credentials_pb2.py +11 -3
  58. chalk/_gen/chalk/server/v1/cloud_credentials_pb2.pyi +20 -0
  59. chalk/_gen/chalk/server/v1/cloud_credentials_pb2_grpc.py +45 -0
  60. chalk/_gen/chalk/server/v1/cloud_credentials_pb2_grpc.pyi +12 -0
  61. chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2.py +52 -38
  62. chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2.pyi +62 -1
  63. chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2_grpc.py +90 -0
  64. chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2_grpc.pyi +24 -0
  65. chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2.py +90 -0
  66. chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2.pyi +264 -0
  67. chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2_grpc.py +170 -0
  68. chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2_grpc.pyi +62 -0
  69. chalk/_gen/chalk/server/v1/deploy_pb2.py +9 -3
  70. chalk/_gen/chalk/server/v1/deploy_pb2.pyi +12 -0
  71. chalk/_gen/chalk/server/v1/deploy_pb2_grpc.py +45 -0
  72. chalk/_gen/chalk/server/v1/deploy_pb2_grpc.pyi +12 -0
  73. chalk/_gen/chalk/server/v1/deployment_pb2.py +6 -6
  74. chalk/_gen/chalk/server/v1/deployment_pb2.pyi +20 -0
  75. chalk/_gen/chalk/server/v1/environment_pb2.py +14 -12
  76. chalk/_gen/chalk/server/v1/environment_pb2.pyi +19 -0
  77. chalk/_gen/chalk/server/v1/eventbus_pb2.py +4 -2
  78. chalk/_gen/chalk/server/v1/files_pb2.py +65 -0
  79. chalk/_gen/chalk/server/v1/files_pb2.pyi +167 -0
  80. chalk/_gen/chalk/server/v1/files_pb2_grpc.py +4 -0
  81. chalk/_gen/chalk/server/v1/files_pb2_grpc.pyi +4 -0
  82. chalk/_gen/chalk/server/v1/graph_pb2.py +38 -26
  83. chalk/_gen/chalk/server/v1/graph_pb2.pyi +58 -0
  84. chalk/_gen/chalk/server/v1/graph_pb2_grpc.py +47 -0
  85. chalk/_gen/chalk/server/v1/graph_pb2_grpc.pyi +18 -0
  86. chalk/_gen/chalk/server/v1/incident_pb2.py +23 -21
  87. chalk/_gen/chalk/server/v1/incident_pb2.pyi +15 -1
  88. chalk/_gen/chalk/server/v1/indexing_job_pb2.py +44 -0
  89. chalk/_gen/chalk/server/v1/indexing_job_pb2.pyi +38 -0
  90. chalk/_gen/chalk/server/v1/indexing_job_pb2_grpc.py +78 -0
  91. chalk/_gen/chalk/server/v1/indexing_job_pb2_grpc.pyi +38 -0
  92. chalk/_gen/chalk/server/v1/integrations_pb2.py +11 -9
  93. chalk/_gen/chalk/server/v1/integrations_pb2.pyi +34 -2
  94. chalk/_gen/chalk/server/v1/kube_pb2.py +29 -19
  95. chalk/_gen/chalk/server/v1/kube_pb2.pyi +28 -0
  96. chalk/_gen/chalk/server/v1/kube_pb2_grpc.py +45 -0
  97. chalk/_gen/chalk/server/v1/kube_pb2_grpc.pyi +12 -0
  98. chalk/_gen/chalk/server/v1/log_pb2.py +21 -3
  99. chalk/_gen/chalk/server/v1/log_pb2.pyi +68 -0
  100. chalk/_gen/chalk/server/v1/log_pb2_grpc.py +90 -0
  101. chalk/_gen/chalk/server/v1/log_pb2_grpc.pyi +24 -0
  102. chalk/_gen/chalk/server/v1/model_registry_pb2.py +10 -10
  103. chalk/_gen/chalk/server/v1/model_registry_pb2.pyi +4 -1
  104. chalk/_gen/chalk/server/v1/plandebug_pb2.py +53 -0
  105. chalk/_gen/chalk/server/v1/plandebug_pb2.pyi +86 -0
  106. chalk/_gen/chalk/server/v1/plandebug_pb2_grpc.py +168 -0
  107. chalk/_gen/chalk/server/v1/plandebug_pb2_grpc.pyi +60 -0
  108. chalk/_gen/chalk/server/v1/queries_pb2.py +66 -66
  109. chalk/_gen/chalk/server/v1/queries_pb2.pyi +32 -2
  110. chalk/_gen/chalk/server/v1/scheduled_query_run_pb2.py +12 -12
  111. chalk/_gen/chalk/server/v1/scheduled_query_run_pb2.pyi +16 -3
  112. chalk/_gen/chalk/server/v1/scheduler_pb2.py +24 -12
  113. chalk/_gen/chalk/server/v1/scheduler_pb2.pyi +61 -1
  114. chalk/_gen/chalk/server/v1/scheduler_pb2_grpc.py +90 -0
  115. chalk/_gen/chalk/server/v1/scheduler_pb2_grpc.pyi +24 -0
  116. chalk/_gen/chalk/server/v1/script_tasks_pb2.py +15 -3
  117. chalk/_gen/chalk/server/v1/script_tasks_pb2.pyi +22 -0
  118. chalk/_gen/chalk/server/v1/script_tasks_pb2_grpc.py +90 -0
  119. chalk/_gen/chalk/server/v1/script_tasks_pb2_grpc.pyi +24 -0
  120. chalk/_gen/chalk/server/v1/sql_interface_pb2.py +75 -0
  121. chalk/_gen/chalk/server/v1/sql_interface_pb2.pyi +142 -0
  122. chalk/_gen/chalk/server/v1/sql_interface_pb2_grpc.py +349 -0
  123. chalk/_gen/chalk/server/v1/sql_interface_pb2_grpc.pyi +114 -0
  124. chalk/_gen/chalk/server/v1/sql_queries_pb2.py +48 -0
  125. chalk/_gen/chalk/server/v1/sql_queries_pb2.pyi +150 -0
  126. chalk/_gen/chalk/server/v1/sql_queries_pb2_grpc.py +123 -0
  127. chalk/_gen/chalk/server/v1/sql_queries_pb2_grpc.pyi +52 -0
  128. chalk/_gen/chalk/server/v1/team_pb2.py +154 -141
  129. chalk/_gen/chalk/server/v1/team_pb2.pyi +30 -2
  130. chalk/_gen/chalk/server/v1/team_pb2_grpc.py +45 -0
  131. chalk/_gen/chalk/server/v1/team_pb2_grpc.pyi +12 -0
  132. chalk/_gen/chalk/server/v1/topic_pb2.py +5 -3
  133. chalk/_gen/chalk/server/v1/topic_pb2.pyi +10 -1
  134. chalk/_gen/chalk/server/v1/trace_pb2.py +44 -40
  135. chalk/_gen/chalk/server/v1/trace_pb2.pyi +20 -0
  136. chalk/_gen/chalk/streaming/v1/debug_service_pb2.py +62 -0
  137. chalk/_gen/chalk/streaming/v1/debug_service_pb2.pyi +75 -0
  138. chalk/_gen/chalk/streaming/v1/debug_service_pb2_grpc.py +221 -0
  139. chalk/_gen/chalk/streaming/v1/debug_service_pb2_grpc.pyi +88 -0
  140. chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2.py +16 -10
  141. chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2.pyi +52 -1
  142. chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2_grpc.py +48 -0
  143. chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2_grpc.pyi +20 -0
  144. chalk/_gen/chalk/utils/v1/field_change_pb2.py +32 -0
  145. chalk/_gen/chalk/utils/v1/field_change_pb2.pyi +42 -0
  146. chalk/_gen/chalk/utils/v1/field_change_pb2_grpc.py +4 -0
  147. chalk/_gen/chalk/utils/v1/field_change_pb2_grpc.pyi +4 -0
  148. chalk/_lsp/error_builder.py +11 -0
  149. chalk/_version.py +1 -1
  150. chalk/client/client.py +128 -43
  151. chalk/client/client_async.py +149 -0
  152. chalk/client/client_async_impl.py +22 -0
  153. chalk/client/client_grpc.py +539 -104
  154. chalk/client/client_impl.py +449 -122
  155. chalk/client/dataset.py +7 -1
  156. chalk/client/models.py +98 -0
  157. chalk/client/serialization/model_serialization.py +92 -9
  158. chalk/df/LazyFramePlaceholder.py +1154 -0
  159. chalk/features/_class_property.py +7 -0
  160. chalk/features/_embedding/embedding.py +1 -0
  161. chalk/features/_encoding/converter.py +83 -2
  162. chalk/features/feature_field.py +40 -30
  163. chalk/features/feature_set_decorator.py +1 -0
  164. chalk/features/feature_wrapper.py +42 -3
  165. chalk/features/hooks.py +81 -10
  166. chalk/features/inference.py +33 -31
  167. chalk/features/resolver.py +224 -24
  168. chalk/functions/__init__.py +65 -3
  169. chalk/gitignore/gitignore_parser.py +5 -1
  170. chalk/importer.py +142 -68
  171. chalk/ml/__init__.py +2 -0
  172. chalk/ml/model_hooks.py +194 -26
  173. chalk/ml/model_reference.py +56 -8
  174. chalk/ml/model_version.py +24 -15
  175. chalk/ml/utils.py +20 -17
  176. chalk/operators/_utils.py +10 -3
  177. chalk/parsed/_proto/export.py +22 -0
  178. chalk/parsed/duplicate_input_gql.py +3 -0
  179. chalk/parsed/json_conversions.py +20 -14
  180. chalk/parsed/to_proto.py +16 -4
  181. chalk/parsed/user_types_to_json.py +31 -10
  182. chalk/parsed/validation_from_registries.py +182 -0
  183. chalk/queries/named_query.py +16 -6
  184. chalk/queries/scheduled_query.py +9 -1
  185. chalk/serialization/parsed_annotation.py +24 -11
  186. chalk/sql/__init__.py +18 -0
  187. chalk/sql/_internal/integrations/databricks.py +55 -17
  188. chalk/sql/_internal/integrations/mssql.py +127 -62
  189. chalk/sql/_internal/integrations/redshift.py +4 -0
  190. chalk/sql/_internal/sql_file_resolver.py +53 -9
  191. chalk/sql/_internal/sql_source.py +35 -2
  192. chalk/streams/_kafka_source.py +5 -1
  193. chalk/streams/_windows.py +15 -2
  194. chalk/utils/_otel_version.py +13 -0
  195. chalk/utils/async_helpers.py +2 -2
  196. chalk/utils/missing_dependency.py +5 -4
  197. chalk/utils/tracing.py +185 -95
  198. {chalkpy-2.90.1.dist-info → chalkpy-2.95.3.dist-info}/METADATA +4 -6
  199. {chalkpy-2.90.1.dist-info → chalkpy-2.95.3.dist-info}/RECORD +202 -146
  200. {chalkpy-2.90.1.dist-info → chalkpy-2.95.3.dist-info}/WHEEL +0 -0
  201. {chalkpy-2.90.1.dist-info → chalkpy-2.95.3.dist-info}/entry_points.txt +0 -0
  202. {chalkpy-2.90.1.dist-info → chalkpy-2.95.3.dist-info}/top_level.txt +0 -0
chalk/importer.py CHANGED
@@ -304,6 +304,19 @@ def _parse_agg_function_call(expr: Underscore | None) -> Tuple[str, Underscore,
                 f"expecting 'int' type argument for 'k', but received arg of type '{type(call_expr._chalk__kwargs.get('k'))}'"
             )
         opts = FrozenOrderedSet(call_expr._chalk__kwargs.items())
+    elif aggregation == "approx_percentile":
+        if len(call_expr._chalk__args) > 0:
+            raise ChalkParseError("should not have any positional arguments")
+        elif {"quantile"} != call_expr._chalk__kwargs.keys():
+            raise ChalkParseError("expecting exactly one required keyword argument 'quantile'")
+        elif not isinstance(call_expr._chalk__kwargs.get("quantile"), float):
+            raise ChalkParseError(
+                f"expecting 'float' type argument for 'quantile', but received arg of type '{type(call_expr._chalk__kwargs.get('quantile'))}'"
+            )
+        # TODO: expand proto definition to accept kwargs that are not necessarily `k`
+        quantile = call_expr._chalk__kwargs["quantile"]
+        nano_quantile = int(round(quantile * 1_000_000_000))
+        opts = FrozenOrderedSet([("k", nano_quantile)])
     elif aggregation in ("min_by_n", "max_by_n"):
         if len(call_expr._chalk__kwargs) > 0:
             raise ChalkParseError("should not have any keyword arguments")
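
Reviewer note on the `approx_percentile` branch above: the float `quantile` is packed into the proto's existing integer `k` slot as parts-per-billion (hence the TODO about generalizing the proto). A minimal sketch of the encoding in plain Python; the decode step is our assumption, not shown in this diff:

    quantile = 0.99
    nano_quantile = int(round(quantile * 1_000_000_000))  # -> 990_000_000, fits the integer `k` slot
    assert nano_quantile == 990_000_000
    assert nano_quantile / 1_000_000_000 == 0.99  # assumed consumer-side decode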
@@ -433,8 +446,6 @@ def run_post_import_fixups():
             # "1m", "2m", materialization={...},
             # expression=_.transactions[_.amount].sum(),
             # )
-            assert f.underscore_expression is not None
-            assert f.window_materialization is not None
 
             try:
                 f.window_materialization_parsed = parse_windowed_materialization(f=f)
@@ -572,39 +583,51 @@ def parse_grouped_window(f: Feature) -> WindowConfigResolved:
         aggregation_kwargs=aggregation_kwargs,
         pyarrow_dtype=pyarrow_dtype,
         filters=parsed_filters,
-        backfill_resolver=_try_parse_resolver_fqn(
-            "backfill_resolver",
-            f.window_materialization.get("backfill_resolver", None),
-        )
-        if isinstance(f.window_materialization, dict)
-        else None,
-        backfill_schedule=f.window_materialization.get("backfill_schedule", None)
-        if isinstance(f.window_materialization, dict)
-        else None,
-        backfill_lookback_duration_seconds=_try_parse_duration(
-            "backfill_lookback_duration",
-            f.window_materialization.get("backfill_lookback_duration", None),
-        )
-        if isinstance(f.window_materialization, dict)
-        else None,
-        backfill_start_time=_try_parse_datetime(
-            "backfill_start_time",
-            f.window_materialization.get("backfill_start_time", None),
-        )
-        if isinstance(f.window_materialization, dict)
-        else None,
-        continuous_resolver=_try_parse_resolver_fqn(
-            "continuous_resolver",
-            f.window_materialization.get("continuous_resolver", None),
-        )
-        if isinstance(f.window_materialization, dict)
-        else None,
-        continuous_buffer_duration_seconds=_try_parse_duration(
-            "continuous_buffer_duration",
-            f.window_materialization.get("continuous_buffer_duration", None),
-        )
-        if isinstance(f.window_materialization, dict)
-        else None,
+        backfill_resolver=(
+            _try_parse_resolver_fqn(
+                "backfill_resolver",
+                f.window_materialization.get("backfill_resolver", None),
+            )
+            if isinstance(f.window_materialization, dict)
+            else None
+        ),
+        backfill_schedule=(
+            f.window_materialization.get("backfill_schedule", None)
+            if isinstance(f.window_materialization, dict)
+            else None
+        ),
+        backfill_lookback_duration_seconds=(
+            _try_parse_duration(
+                "backfill_lookback_duration",
+                f.window_materialization.get("backfill_lookback_duration", None),
+            )
+            if isinstance(f.window_materialization, dict)
+            else None
+        ),
+        backfill_start_time=(
+            _try_parse_datetime(
+                "backfill_start_time",
+                f.window_materialization.get("backfill_start_time", None),
+            )
+            if isinstance(f.window_materialization, dict)
+            else None
+        ),
+        continuous_resolver=(
+            _try_parse_resolver_fqn(
+                "continuous_resolver",
+                f.window_materialization.get("continuous_resolver", None),
+            )
+            if isinstance(f.window_materialization, dict)
+            else None
+        ),
+        continuous_buffer_duration_seconds=(
+            _try_parse_duration(
+                "continuous_buffer_duration",
+                f.window_materialization.get("continuous_buffer_duration", None),
+            )
+            if isinstance(f.window_materialization, dict)
+            else None
+        ),
     )
 
     return cfg
@@ -800,39 +823,51 @@ def parse_windowed_materialization(f: Feature) -> WindowConfigResolved | None:
         aggregation_kwargs=aggregation_kwargs,
         pyarrow_dtype=f.converter.pyarrow_dtype,
         filters=parsed_filters,
-        backfill_resolver=_try_parse_resolver_fqn(
-            "backfill_resolver",
-            f.window_materialization.get("backfill_resolver", None),
-        )
-        if isinstance(f.window_materialization, dict)
-        else None,
-        backfill_schedule=f.window_materialization.get("backfill_schedule", None)
-        if isinstance(f.window_materialization, dict)
-        else None,
-        backfill_lookback_duration_seconds=_try_parse_duration(
-            "backfill_lookback_duration",
-            f.window_materialization.get("backfill_lookback_duration", None),
-        )
-        if isinstance(f.window_materialization, dict)
-        else None,
-        backfill_start_time=_try_parse_datetime(
-            "backfill_start_time",
-            f.window_materialization.get("backfill_start_time", None),
-        )
-        if isinstance(f.window_materialization, dict)
-        else None,
-        continuous_resolver=_try_parse_resolver_fqn(
-            "continuous_resolver",
-            f.window_materialization.get("continuous_resolver", None),
-        )
-        if isinstance(f.window_materialization, dict)
-        else None,
-        continuous_buffer_duration_seconds=_try_parse_duration(
-            "continuous_buffer_duration",
-            f.window_materialization.get("continuous_buffer_duration", None),
-        )
-        if isinstance(f.window_materialization, dict)
-        else None,
+        backfill_resolver=(
+            _try_parse_resolver_fqn(
+                "backfill_resolver",
+                f.window_materialization.get("backfill_resolver", None),
+            )
+            if isinstance(f.window_materialization, dict)
+            else None
+        ),
+        backfill_schedule=(
+            f.window_materialization.get("backfill_schedule", None)
+            if isinstance(f.window_materialization, dict)
+            else None
+        ),
+        backfill_lookback_duration_seconds=(
+            _try_parse_duration(
+                "backfill_lookback_duration",
+                f.window_materialization.get("backfill_lookback_duration", None),
+            )
+            if isinstance(f.window_materialization, dict)
+            else None
+        ),
+        backfill_start_time=(
+            _try_parse_datetime(
+                "backfill_start_time",
+                f.window_materialization.get("backfill_start_time", None),
+            )
+            if isinstance(f.window_materialization, dict)
+            else None
+        ),
+        continuous_resolver=(
+            _try_parse_resolver_fqn(
+                "continuous_resolver",
+                f.window_materialization.get("continuous_resolver", None),
+            )
+            if isinstance(f.window_materialization, dict)
+            else None
+        ),
+        continuous_buffer_duration_seconds=(
+            _try_parse_duration(
+                "continuous_buffer_duration",
+                f.window_materialization.get("continuous_buffer_duration", None),
+            )
+            if isinstance(f.window_materialization, dict)
+            else None
+        ),
     )
 
 
@@ -1010,6 +1045,33 @@ class _UnderscoreValidationError(ValueError):
     ...
 
 
+def _has_group_by_in_parent_chain(underscore: Underscore) -> bool:
+    """
+    Traverse parent chain to check if .group_by() exists before .agg().
+
+    For valid group_by_windowed: _.x.group_by(_.y).agg(_.z.sum())
+    - Looks for: UnderscoreCall -> UnderscoreAttr("group_by")
+
+    Returns True if .group_by() found, False otherwise.
+    """
+    current: Optional[Any] = underscore
+
+    while current is not None:
+        # Check if current is a .group_by() call
+        if isinstance(current, UnderscoreCall):
+            parent = current._chalk__parent
+            if isinstance(parent, UnderscoreAttr) and parent._chalk__attr == "group_by":
+                return True
+
+        # Move to parent
+        if hasattr(current, "_chalk__parent"):
+            current = current._chalk__parent
+        else:
+            break
+
+    return False
+
+
 class ChalkImporter:
     def __init__(self):
         super().__init__()
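
For orientation, the parent chain this helper walks looks roughly as follows; the node shapes are inferred from the attributes used above (`_chalk__parent`, `_chalk__attr`) and the docstring's example, not from the full Underscore AST definition:

    # _.x.group_by(_.y).agg(_.z.sum()) builds (roughly) this parent chain:
    #
    #   UnderscoreCall                             # the .agg(...) call
    #     parent: UnderscoreAttr("agg")
    #       parent: UnderscoreCall                 # the .group_by(...) call: the loop returns
    #         parent: UnderscoreAttr("group_by")   #   True here, since its parent's attr is "group_by"
    #           parent: UnderscoreAttr("x")
    #             parent: _                        # root underscore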
@@ -1111,6 +1173,9 @@ class ChalkImporter:
         for feature_class in FeatureSetBase.registry.values():
             # Iterate through every class, to find every underscore definition.
             for f in feature_class.features:
+                if f.is_windowed_pseudofeature is True:
+                    # need one LSP just for the base
+                    continue
                 if f.underscore_expression is not None:
                     # Validate that the underscore expression is well-formed.
                     # If it is not well-formed, then an `_UnderscoreValidationError` will
@@ -1489,6 +1554,15 @@ def _supplemental_validate_underscore_expression(
             )
             return None
 
+    # Validate .agg() usage (addressing TODO at line 1522)
+    if op_name == "agg":
+        if not _has_group_by_in_parent_chain(caller):
+            raise _UnderscoreValidationError(
+                "'.agg()' can only be used with '.group_by()' for group_by_windowed features. "
+                + "For windowed features, use direct aggregation methods instead. "
+                + "For example, instead of using '.agg(_.field.method())', use '.field.method()' directly on the filtered DataFrame"
+            )
+
     return None
 
     # TODO: check that op_name is a supported agg or .agg/.group_by/etc
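
Illustration of what the new check accepts and rejects; the feature names mirror the comments earlier in this file, and the exact surface syntax is as suggested by the error message:

    # OK (group_by_windowed): .agg() chained off .group_by()
    _.transactions.group_by(_.merchant).agg(_.amount.sum())

    # Now rejected: .agg() with no .group_by() in the parent chain
    _.transactions.agg(_.amount.sum())  # raises _UnderscoreValidationError

    # Plain windowed features aggregate directly instead:
    _.transactions[_.amount].sum()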
chalk/ml/__init__.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 from chalk.ml.model_file_transfer import FileInfo, HFSourceConfig, LocalSourceConfig, S3SourceConfig, SourceConfig
 from chalk.ml.model_reference import ModelReference
+from chalk.ml.model_version import ModelVersion
 from chalk.ml.utils import ModelClass, ModelEncoding, ModelRunCriterion, ModelType
 
 __all__ = (
@@ -9,6 +10,7 @@ __all__ = (
     "ModelClass",
     "ModelEncoding",
     "ModelReference",
+    "ModelVersion",
     "SourceConfig",
     "LocalSourceConfig",
     "S3SourceConfig",
chalk/ml/model_hooks.py CHANGED
@@ -1,12 +1,15 @@
-from typing import Any, Dict, Optional, Protocol, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Optional, Protocol, Tuple
 
 from chalk.ml.utils import ModelClass, ModelEncoding, ModelType
 
+if TYPE_CHECKING:
+    from chalk.features.resolver import ResourceHint
+
 
 class ModelInference(Protocol):
     """Abstract base class for model loading and inference."""
 
-    def load_model(self, path: str) -> Any:
+    def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
         """Load a model from the given path."""
         pass
 
@@ -14,11 +17,27 @@ class ModelInference(Protocol):
         """Run inference on the model with input X."""
         pass
 
+    def prepare_input(self, feature_table: Any) -> Any:
+        """Convert PyArrow table to model input format.
+
+        Default implementation converts to numpy array via __array__().
+        Override for model-specific input formats (e.g., ONNX struct arrays).
+        """
+        return feature_table.__array__()
+
+    def extract_output(self, result: Any, output_feature_name: str) -> Any:
+        """Extract single output from model result.
+
+        Default implementation returns result as-is (for single outputs).
+        Override for models with structured outputs (e.g., ONNX struct arrays).
+        """
+        return result
+
 
 class XGBoostClassifierInference(ModelInference):
     """Model inference for XGBoost classifiers."""
 
-    def load_model(self, path: str) -> Any:
+    def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
        import xgboost  # pyright: ignore[reportMissingImports]
 
        model = xgboost.XGBClassifier()
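
`load_model`, `predict`, and the new `prepare_input`/`extract_output` hooks make up the whole protocol surface. A minimal conforming sketch (the class and its mean-baseline behavior are invented; the `chalk.ml.model_hooks` import path is assumed from this diff):

    from typing import Any, Optional

    import numpy as np

    from chalk.ml.model_hooks import ModelInference  # path assumed from this file

    class MeanBaselineInference(ModelInference):
        """Toy model: inherits the default prepare_input()/extract_output()."""

        def load_model(self, path: str, resource_hint: Optional[str] = None) -> Any:
            # A real implementation would deserialize an artifact at `path`.
            return "mean-baseline"

        def predict(self, model: Any, X: Any) -> Any:
            # X arrives as a numpy array via the inherited prepare_input().
            return np.asarray(X, dtype="float64").mean(axis=1)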
@@ -32,7 +51,7 @@ class XGBoostClassifierInference(ModelInference):
 class XGBoostRegressorInference(ModelInference):
     """Model inference for XGBoost regressors."""
 
-    def load_model(self, path: str) -> Any:
+    def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
         import xgboost  # pyright: ignore[reportMissingImports]
 
         model = xgboost.XGBRegressor()
@@ -46,17 +65,27 @@ class XGBoostRegressorInference(ModelInference):
 class PyTorchInference(ModelInference):
     """Model inference for PyTorch models."""
 
-    def load_model(self, path: str) -> Any:
+    def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
         import torch  # pyright: ignore[reportMissingImports]
 
         torch.set_grad_enabled(False)
+
+        # Load the model
         model = torch.jit.load(path)
-        model.input_to_tensor = lambda X: torch.from_numpy(X).float()
+
+        # If resource_hint is "gpu", move model to GPU
+        if resource_hint == "gpu" and torch.cuda.is_available():
+            device = torch.device("cuda")
+            model = model.to(device)
+            model.input_to_tensor = lambda X: torch.from_numpy(X).float().to(device)
+        else:
+            model.input_to_tensor = lambda X: torch.from_numpy(X).float()
+
         return model
 
     def predict(self, model: Any, X: Any) -> Any:
         outputs = model(model.input_to_tensor(X))
-        result = outputs.detach().numpy().astype("float64")
+        result = outputs.detach().cpu().numpy().astype("float64")
         result = result.squeeze()
 
         # Convert 0-dimensional array to scalar, or ensure we have a proper 1D array
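
The `resource_hint` plumbing lets a GPU-scheduled resolver request a CUDA-resident model while degrading gracefully to CPU; a usage sketch ("model.pt" and `X` are placeholders):

    inference = PyTorchInference()
    model = inference.load_model("model.pt", resource_hint="gpu")  # CPU fallback when CUDA is absent
    preds = inference.predict(model, X)  # .detach().cpu() above keeps this device-safe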
@@ -69,7 +98,7 @@ class PyTorchInference(ModelInference):
 class SklearnInference(ModelInference):
     """Model inference for scikit-learn models."""
 
-    def load_model(self, path: str) -> Any:
+    def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
         import joblib  # pyright: ignore[reportMissingImports]
 
         return joblib.load(path)
@@ -81,7 +110,7 @@ class SklearnInference(ModelInference):
 class TensorFlowInference(ModelInference):
     """Model inference for TensorFlow models."""
 
-    def load_model(self, path: str) -> Any:
+    def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
         import tensorflow  # pyright: ignore[reportMissingImports]
 
         return tensorflow.keras.models.load_model(path)
@@ -93,7 +122,7 @@ class TensorFlowInference(ModelInference):
 class LightGBMInference(ModelInference):
     """Model inference for LightGBM models."""
 
-    def load_model(self, path: str) -> Any:
+    def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
         import lightgbm  # pyright: ignore[reportMissingImports]
 
         return lightgbm.Booster(model_file=path)
@@ -105,7 +134,7 @@ class LightGBMInference(ModelInference):
 class CatBoostInference(ModelInference):
     """Model inference for CatBoost models."""
 
-    def load_model(self, path: str) -> Any:
+    def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
         import catboost  # pyright: ignore[reportMissingImports]
 
         return catboost.CatBoost().load_model(path)
@@ -115,31 +144,170 @@ class CatBoostInference(ModelInference):
 
 
 class ONNXInference(ModelInference):
-    """Model inference for ONNX models."""
+    """Model inference for ONNX models with struct input/output support."""
 
-    def load_model(self, path: str) -> Any:
+    def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
         import onnxruntime  # pyright: ignore[reportMissingImports]
 
-        return onnxruntime.InferenceSession(path)
+        # Conditionally add CUDAExecutionProvider based on resource_hint
+        providers = (
+            ["CUDAExecutionProvider", "CPUExecutionProvider"] if resource_hint == "gpu" else ["CPUExecutionProvider"]
+        )
+        return onnxruntime.InferenceSession(path, providers=providers)
+
+    def prepare_input(self, feature_table: Any) -> Any:
+        """Convert PyArrow table to struct array for ONNX models."""
+        import pyarrow as pa
+
+        # Get arrays for each column, combining chunks if necessary
+        arrays = []
+        for i in range(feature_table.num_columns):
+            col = feature_table.column(i)
+            if isinstance(col, pa.ChunkedArray):
+                arrays.append(col.combine_chunks())
+            else:
+                arrays.append(col)
+
+        # Create fields from schema, preserving original field names
+        # Field names should match ONNX input names exactly
+        fields = []
+        for field in feature_table.schema:
+            fields.append(pa.field(field.name, field.type))
+
+        # Create struct array where each row is a struct with named fields
+        return pa.StructArray.from_arrays(arrays, fields=fields)
+
+    def extract_output(self, result: Any, output_feature_name: str) -> Any:
+        """Extract single field from ONNX struct output."""
+        import pyarrow as pa
+
+        if not isinstance(result, (pa.StructArray, pa.ChunkedArray)):
+            return result
+
+        struct_type = result.type if isinstance(result, pa.StructArray) else result.chunk(0).type
+
+        # Find matching field by name, or use first field
+        field_index = None
+        for i, field in enumerate(struct_type):
+            if field.name == output_feature_name:
+                field_index = i
+                break
+
+        return result.field(field_index if field_index is not None else 0)
 
     def predict(self, model: Any, X: Any) -> Any:
+        """Run ONNX inference with struct input/output."""
+        # Get ONNX model input/output names
+        input_names = [inp.name for inp in model.get_inputs()]
+        output_names = [out.name for out in model.get_outputs()]
+
+        # Convert struct input to ONNX input dict
+        input_dict = self._struct_to_inputs(X, input_names)
+
+        # Run ONNX inference
+        outputs = model.run(output_names, input_dict)
+
+        # Always return outputs as struct array
+        return self._outputs_to_struct(output_names, outputs)
+
+    def _struct_to_inputs(self, struct_array: Any, input_names: list) -> dict:
+        """Extract ONNX inputs from struct array by matching field names.
+
+        Struct field names must match ONNX input names (supports list/Tensor types).
+        If ONNX expects a single input but struct has multiple scalar fields,
+        stack them into a 2D array.
+        """
         import numpy as np
+        import pyarrow as pa
+
+        if isinstance(struct_array, pa.ChunkedArray):
+            struct_array = struct_array.combine_chunks()
+
+        input_dict = {}
+        struct_fields = {field.name: i for i, field in enumerate(struct_array.type)}
+
+        # Check if struct field names match ONNX input names
+        fields_match = all(input_name in struct_fields for input_name in input_names)
+
+        if not fields_match:
+            # Special case 1: ONNX expects single input and struct has single field
+            # Use that field regardless of name mismatch
+            if len(input_names) == 1 and len(struct_fields) == 1:
+                field_data = struct_array.field(0)
+                input_dict[input_names[0]] = self._arrow_to_numpy(field_data)
+                return input_dict
+
+            # Special case 2: ONNX expects single input, but struct has multiple scalar fields
+            # Stack them into a 2D array [batch_size, num_fields]
+            if len(input_names) == 1 and len(struct_fields) > 1:
+                # Check if all fields are scalar (not nested lists)
+                all_scalar = all(
+                    not pa.types.is_list(struct_array.type[i].type)
+                    and not pa.types.is_large_list(struct_array.type[i].type)
+                    for i in range(len(struct_array.type))
+                )
+
+                if all_scalar:
+                    # Stack all fields into a single 2D array
+                    columns = []
+                    for i in range(len(struct_array.type)):
+                        field_data = struct_array.field(i)
+                        col_array = self._arrow_to_numpy(field_data)
+                        columns.append(col_array)
+
+                    # Stack columns horizontally to create [batch_size, num_features]
+                    stacked = np.column_stack(columns)
+                    input_dict[input_names[0]] = stacked
+                    return input_dict
+
+            raise ValueError(
+                f"ONNX inputs {input_names} not found in struct fields {list(struct_fields.keys())}. "
+                + "Struct field names must match ONNX input names."
+            )
+
+        # Direct mapping: struct fields match ONNX inputs (for Tensor/list types or named inputs)
+        for input_name in input_names:
+            field_data = struct_array.field(struct_fields[input_name])
+            input_dict[input_name] = self._arrow_to_numpy(field_data)
+
+        return input_dict
+
+    def _arrow_to_numpy(self, arrow_array: Any) -> Any:
+        """Convert Arrow array (including nested lists) to dense numpy array."""
+        import numpy as np
+        import pyarrow as pa
 
-        # Get input names from the model metadata
-        input_names = [inp.name for inp in model.get_inputs()]
+        if isinstance(arrow_array, pa.ChunkedArray):
+            arrow_array = arrow_array.combine_chunks()
 
-        # Convert X to float32 if needed
-        X_float32 = X.astype("float32") if hasattr(X, "astype") else np.array(X, dtype="float32")
+        # Convert to Python list, then numpy - handles all cases (nested lists, flat arrays, etc.)
+        return np.array(arrow_array.to_pylist(), dtype=np.float32)
 
-        # If there's only one input, use it directly
-        if len(input_names) == 1:
-            input_dict = {input_names[0]: X_float32}
-        else:
-            # For multiple inputs, we'd need additional logic
-            # For now, assume the first input is the main one
-            input_dict = {input_names[0]: X_float32}
+    def _outputs_to_struct(self, output_names: list, outputs: list) -> Any:
+        """Convert ONNX outputs to PyArrow struct array."""
+        import pyarrow as pa
+
+        if not outputs:
+            raise ValueError("ONNX model returned no outputs")
+
+        # Convert each output to Arrow array with proper type
+        fields = []
+        arrays = []
+
+        for name, output_array in zip(output_names, outputs):
+            arrow_array = self._numpy_to_arrow_array(output_array)
+            fields.append(pa.field(name, arrow_array.type))
+            arrays.append(arrow_array)
+
+        return pa.StructArray.from_arrays(arrays, fields=fields)
+
+    def _numpy_to_arrow_array(self, arr: Any) -> Any:
+        """Convert numpy array to PyArrow array (possibly nested list)."""
+        import pyarrow as pa
 
-        return model.run(None, input_dict)[0]
+        # PyArrow can infer the correct nested list type from Python lists
+        # Shape (batch, dim1, dim2, ...) -> list[list[...]]
+        return pa.array(arr.tolist())
 
 
 class ModelInferenceRegistry:
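
End to end, the new ONNX path is: PyArrow table -> prepare_input() struct array -> predict() -> struct of outputs -> extract_output(). A sketch assuming a hypothetical "model.onnx" whose single tensor input accepts the two stacked scalar columns and whose output is named "score":

    import pyarrow as pa

    inference = ONNXInference()
    session = inference.load_model("model.onnx")  # CPU provider by default

    table = pa.table({"amount": [1.0, 2.0], "age": [30.0, 41.0]})
    struct_input = inference.prepare_input(table)       # StructArray of named fields

    # predict() maps struct fields to ONNX inputs; with one tensor input and
    # several scalar fields it column-stacks them into [batch, num_features].
    result = inference.predict(session, struct_input)   # StructArray of outputs
    scores = inference.extract_output(result, "score")  # Arrow array for the "score" field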