chalkpy 2.90.1__py3-none-any.whl → 2.95.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (202)
  1. chalk/__init__.py +2 -1
  2. chalk/_gen/chalk/arrow/v1/arrow_pb2.py +7 -5
  3. chalk/_gen/chalk/arrow/v1/arrow_pb2.pyi +6 -0
  4. chalk/_gen/chalk/artifacts/v1/chart_pb2.py +16 -16
  5. chalk/_gen/chalk/artifacts/v1/chart_pb2.pyi +4 -0
  6. chalk/_gen/chalk/artifacts/v1/cron_query_pb2.py +8 -7
  7. chalk/_gen/chalk/artifacts/v1/cron_query_pb2.pyi +5 -0
  8. chalk/_gen/chalk/common/v1/offline_query_pb2.py +17 -15
  9. chalk/_gen/chalk/common/v1/offline_query_pb2.pyi +25 -0
  10. chalk/_gen/chalk/common/v1/script_task_pb2.py +3 -3
  11. chalk/_gen/chalk/common/v1/script_task_pb2.pyi +2 -0
  12. chalk/_gen/chalk/dataframe/__init__.py +0 -0
  13. chalk/_gen/chalk/dataframe/v1/__init__.py +0 -0
  14. chalk/_gen/chalk/dataframe/v1/dataframe_pb2.py +48 -0
  15. chalk/_gen/chalk/dataframe/v1/dataframe_pb2.pyi +123 -0
  16. chalk/_gen/chalk/dataframe/v1/dataframe_pb2_grpc.py +4 -0
  17. chalk/_gen/chalk/dataframe/v1/dataframe_pb2_grpc.pyi +4 -0
  18. chalk/_gen/chalk/graph/v1/graph_pb2.py +150 -149
  19. chalk/_gen/chalk/graph/v1/graph_pb2.pyi +25 -0
  20. chalk/_gen/chalk/graph/v1/sources_pb2.py +94 -84
  21. chalk/_gen/chalk/graph/v1/sources_pb2.pyi +56 -0
  22. chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2.py +79 -0
  23. chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2.pyi +377 -0
  24. chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2_grpc.py +4 -0
  25. chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2_grpc.pyi +4 -0
  26. chalk/_gen/chalk/kubernetes/v1/scaledobject_pb2.py +43 -7
  27. chalk/_gen/chalk/kubernetes/v1/scaledobject_pb2.pyi +252 -2
  28. chalk/_gen/chalk/protosql/v1/sql_service_pb2.py +54 -27
  29. chalk/_gen/chalk/protosql/v1/sql_service_pb2.pyi +131 -3
  30. chalk/_gen/chalk/protosql/v1/sql_service_pb2_grpc.py +45 -0
  31. chalk/_gen/chalk/protosql/v1/sql_service_pb2_grpc.pyi +14 -0
  32. chalk/_gen/chalk/python/v1/types_pb2.py +14 -14
  33. chalk/_gen/chalk/python/v1/types_pb2.pyi +8 -0
  34. chalk/_gen/chalk/server/v1/benchmark_pb2.py +76 -0
  35. chalk/_gen/chalk/server/v1/benchmark_pb2.pyi +156 -0
  36. chalk/_gen/chalk/server/v1/benchmark_pb2_grpc.py +258 -0
  37. chalk/_gen/chalk/server/v1/benchmark_pb2_grpc.pyi +84 -0
  38. chalk/_gen/chalk/server/v1/billing_pb2.py +40 -38
  39. chalk/_gen/chalk/server/v1/billing_pb2.pyi +17 -1
  40. chalk/_gen/chalk/server/v1/branches_pb2.py +45 -0
  41. chalk/_gen/chalk/server/v1/branches_pb2.pyi +80 -0
  42. chalk/_gen/chalk/server/v1/branches_pb2_grpc.pyi +36 -0
  43. chalk/_gen/chalk/server/v1/builder_pb2.py +358 -288
  44. chalk/_gen/chalk/server/v1/builder_pb2.pyi +360 -10
  45. chalk/_gen/chalk/server/v1/builder_pb2_grpc.py +225 -0
  46. chalk/_gen/chalk/server/v1/builder_pb2_grpc.pyi +60 -0
  47. chalk/_gen/chalk/server/v1/chart_pb2.py +10 -10
  48. chalk/_gen/chalk/server/v1/chart_pb2.pyi +18 -2
  49. chalk/_gen/chalk/server/v1/clickhouse_pb2.py +42 -0
  50. chalk/_gen/chalk/server/v1/clickhouse_pb2.pyi +17 -0
  51. chalk/_gen/chalk/server/v1/clickhouse_pb2_grpc.py +78 -0
  52. chalk/_gen/chalk/server/v1/clickhouse_pb2_grpc.pyi +38 -0
  53. chalk/_gen/chalk/server/v1/cloud_components_pb2.py +141 -119
  54. chalk/_gen/chalk/server/v1/cloud_components_pb2.pyi +106 -4
  55. chalk/_gen/chalk/server/v1/cloud_components_pb2_grpc.py +45 -0
  56. chalk/_gen/chalk/server/v1/cloud_components_pb2_grpc.pyi +12 -0
  57. chalk/_gen/chalk/server/v1/cloud_credentials_pb2.py +11 -3
  58. chalk/_gen/chalk/server/v1/cloud_credentials_pb2.pyi +20 -0
  59. chalk/_gen/chalk/server/v1/cloud_credentials_pb2_grpc.py +45 -0
  60. chalk/_gen/chalk/server/v1/cloud_credentials_pb2_grpc.pyi +12 -0
  61. chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2.py +52 -38
  62. chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2.pyi +62 -1
  63. chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2_grpc.py +90 -0
  64. chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2_grpc.pyi +24 -0
  65. chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2.py +90 -0
  66. chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2.pyi +264 -0
  67. chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2_grpc.py +170 -0
  68. chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2_grpc.pyi +62 -0
  69. chalk/_gen/chalk/server/v1/deploy_pb2.py +9 -3
  70. chalk/_gen/chalk/server/v1/deploy_pb2.pyi +12 -0
  71. chalk/_gen/chalk/server/v1/deploy_pb2_grpc.py +45 -0
  72. chalk/_gen/chalk/server/v1/deploy_pb2_grpc.pyi +12 -0
  73. chalk/_gen/chalk/server/v1/deployment_pb2.py +6 -6
  74. chalk/_gen/chalk/server/v1/deployment_pb2.pyi +20 -0
  75. chalk/_gen/chalk/server/v1/environment_pb2.py +14 -12
  76. chalk/_gen/chalk/server/v1/environment_pb2.pyi +19 -0
  77. chalk/_gen/chalk/server/v1/eventbus_pb2.py +4 -2
  78. chalk/_gen/chalk/server/v1/files_pb2.py +65 -0
  79. chalk/_gen/chalk/server/v1/files_pb2.pyi +167 -0
  80. chalk/_gen/chalk/server/v1/files_pb2_grpc.py +4 -0
  81. chalk/_gen/chalk/server/v1/files_pb2_grpc.pyi +4 -0
  82. chalk/_gen/chalk/server/v1/graph_pb2.py +38 -26
  83. chalk/_gen/chalk/server/v1/graph_pb2.pyi +58 -0
  84. chalk/_gen/chalk/server/v1/graph_pb2_grpc.py +47 -0
  85. chalk/_gen/chalk/server/v1/graph_pb2_grpc.pyi +18 -0
  86. chalk/_gen/chalk/server/v1/incident_pb2.py +23 -21
  87. chalk/_gen/chalk/server/v1/incident_pb2.pyi +15 -1
  88. chalk/_gen/chalk/server/v1/indexing_job_pb2.py +44 -0
  89. chalk/_gen/chalk/server/v1/indexing_job_pb2.pyi +38 -0
  90. chalk/_gen/chalk/server/v1/indexing_job_pb2_grpc.py +78 -0
  91. chalk/_gen/chalk/server/v1/indexing_job_pb2_grpc.pyi +38 -0
  92. chalk/_gen/chalk/server/v1/integrations_pb2.py +11 -9
  93. chalk/_gen/chalk/server/v1/integrations_pb2.pyi +34 -2
  94. chalk/_gen/chalk/server/v1/kube_pb2.py +29 -19
  95. chalk/_gen/chalk/server/v1/kube_pb2.pyi +28 -0
  96. chalk/_gen/chalk/server/v1/kube_pb2_grpc.py +45 -0
  97. chalk/_gen/chalk/server/v1/kube_pb2_grpc.pyi +12 -0
  98. chalk/_gen/chalk/server/v1/log_pb2.py +21 -3
  99. chalk/_gen/chalk/server/v1/log_pb2.pyi +68 -0
  100. chalk/_gen/chalk/server/v1/log_pb2_grpc.py +90 -0
  101. chalk/_gen/chalk/server/v1/log_pb2_grpc.pyi +24 -0
  102. chalk/_gen/chalk/server/v1/model_registry_pb2.py +10 -10
  103. chalk/_gen/chalk/server/v1/model_registry_pb2.pyi +4 -1
  104. chalk/_gen/chalk/server/v1/plandebug_pb2.py +53 -0
  105. chalk/_gen/chalk/server/v1/plandebug_pb2.pyi +86 -0
  106. chalk/_gen/chalk/server/v1/plandebug_pb2_grpc.py +168 -0
  107. chalk/_gen/chalk/server/v1/plandebug_pb2_grpc.pyi +60 -0
  108. chalk/_gen/chalk/server/v1/queries_pb2.py +66 -66
  109. chalk/_gen/chalk/server/v1/queries_pb2.pyi +32 -2
  110. chalk/_gen/chalk/server/v1/scheduled_query_run_pb2.py +12 -12
  111. chalk/_gen/chalk/server/v1/scheduled_query_run_pb2.pyi +16 -3
  112. chalk/_gen/chalk/server/v1/scheduler_pb2.py +24 -12
  113. chalk/_gen/chalk/server/v1/scheduler_pb2.pyi +61 -1
  114. chalk/_gen/chalk/server/v1/scheduler_pb2_grpc.py +90 -0
  115. chalk/_gen/chalk/server/v1/scheduler_pb2_grpc.pyi +24 -0
  116. chalk/_gen/chalk/server/v1/script_tasks_pb2.py +15 -3
  117. chalk/_gen/chalk/server/v1/script_tasks_pb2.pyi +22 -0
  118. chalk/_gen/chalk/server/v1/script_tasks_pb2_grpc.py +90 -0
  119. chalk/_gen/chalk/server/v1/script_tasks_pb2_grpc.pyi +24 -0
  120. chalk/_gen/chalk/server/v1/sql_interface_pb2.py +75 -0
  121. chalk/_gen/chalk/server/v1/sql_interface_pb2.pyi +142 -0
  122. chalk/_gen/chalk/server/v1/sql_interface_pb2_grpc.py +349 -0
  123. chalk/_gen/chalk/server/v1/sql_interface_pb2_grpc.pyi +114 -0
  124. chalk/_gen/chalk/server/v1/sql_queries_pb2.py +48 -0
  125. chalk/_gen/chalk/server/v1/sql_queries_pb2.pyi +150 -0
  126. chalk/_gen/chalk/server/v1/sql_queries_pb2_grpc.py +123 -0
  127. chalk/_gen/chalk/server/v1/sql_queries_pb2_grpc.pyi +52 -0
  128. chalk/_gen/chalk/server/v1/team_pb2.py +154 -141
  129. chalk/_gen/chalk/server/v1/team_pb2.pyi +30 -2
  130. chalk/_gen/chalk/server/v1/team_pb2_grpc.py +45 -0
  131. chalk/_gen/chalk/server/v1/team_pb2_grpc.pyi +12 -0
  132. chalk/_gen/chalk/server/v1/topic_pb2.py +5 -3
  133. chalk/_gen/chalk/server/v1/topic_pb2.pyi +10 -1
  134. chalk/_gen/chalk/server/v1/trace_pb2.py +44 -40
  135. chalk/_gen/chalk/server/v1/trace_pb2.pyi +20 -0
  136. chalk/_gen/chalk/streaming/v1/debug_service_pb2.py +62 -0
  137. chalk/_gen/chalk/streaming/v1/debug_service_pb2.pyi +75 -0
  138. chalk/_gen/chalk/streaming/v1/debug_service_pb2_grpc.py +221 -0
  139. chalk/_gen/chalk/streaming/v1/debug_service_pb2_grpc.pyi +88 -0
  140. chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2.py +16 -10
  141. chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2.pyi +52 -1
  142. chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2_grpc.py +48 -0
  143. chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2_grpc.pyi +20 -0
  144. chalk/_gen/chalk/utils/v1/field_change_pb2.py +32 -0
  145. chalk/_gen/chalk/utils/v1/field_change_pb2.pyi +42 -0
  146. chalk/_gen/chalk/utils/v1/field_change_pb2_grpc.py +4 -0
  147. chalk/_gen/chalk/utils/v1/field_change_pb2_grpc.pyi +4 -0
  148. chalk/_lsp/error_builder.py +11 -0
  149. chalk/_version.py +1 -1
  150. chalk/client/client.py +128 -43
  151. chalk/client/client_async.py +149 -0
  152. chalk/client/client_async_impl.py +22 -0
  153. chalk/client/client_grpc.py +539 -104
  154. chalk/client/client_impl.py +449 -122
  155. chalk/client/dataset.py +7 -1
  156. chalk/client/models.py +98 -0
  157. chalk/client/serialization/model_serialization.py +92 -9
  158. chalk/df/LazyFramePlaceholder.py +1154 -0
  159. chalk/features/_class_property.py +7 -0
  160. chalk/features/_embedding/embedding.py +1 -0
  161. chalk/features/_encoding/converter.py +83 -2
  162. chalk/features/feature_field.py +40 -30
  163. chalk/features/feature_set_decorator.py +1 -0
  164. chalk/features/feature_wrapper.py +42 -3
  165. chalk/features/hooks.py +81 -10
  166. chalk/features/inference.py +33 -31
  167. chalk/features/resolver.py +224 -24
  168. chalk/functions/__init__.py +65 -3
  169. chalk/gitignore/gitignore_parser.py +5 -1
  170. chalk/importer.py +142 -68
  171. chalk/ml/__init__.py +2 -0
  172. chalk/ml/model_hooks.py +194 -26
  173. chalk/ml/model_reference.py +56 -8
  174. chalk/ml/model_version.py +24 -15
  175. chalk/ml/utils.py +20 -17
  176. chalk/operators/_utils.py +10 -3
  177. chalk/parsed/_proto/export.py +22 -0
  178. chalk/parsed/duplicate_input_gql.py +3 -0
  179. chalk/parsed/json_conversions.py +20 -14
  180. chalk/parsed/to_proto.py +16 -4
  181. chalk/parsed/user_types_to_json.py +31 -10
  182. chalk/parsed/validation_from_registries.py +182 -0
  183. chalk/queries/named_query.py +16 -6
  184. chalk/queries/scheduled_query.py +9 -1
  185. chalk/serialization/parsed_annotation.py +24 -11
  186. chalk/sql/__init__.py +18 -0
  187. chalk/sql/_internal/integrations/databricks.py +55 -17
  188. chalk/sql/_internal/integrations/mssql.py +127 -62
  189. chalk/sql/_internal/integrations/redshift.py +4 -0
  190. chalk/sql/_internal/sql_file_resolver.py +53 -9
  191. chalk/sql/_internal/sql_source.py +35 -2
  192. chalk/streams/_kafka_source.py +5 -1
  193. chalk/streams/_windows.py +15 -2
  194. chalk/utils/_otel_version.py +13 -0
  195. chalk/utils/async_helpers.py +2 -2
  196. chalk/utils/missing_dependency.py +5 -4
  197. chalk/utils/tracing.py +185 -95
  198. {chalkpy-2.90.1.dist-info → chalkpy-2.95.3.dist-info}/METADATA +4 -6
  199. {chalkpy-2.90.1.dist-info → chalkpy-2.95.3.dist-info}/RECORD +202 -146
  200. {chalkpy-2.90.1.dist-info → chalkpy-2.95.3.dist-info}/WHEEL +0 -0
  201. {chalkpy-2.90.1.dist-info → chalkpy-2.95.3.dist-info}/entry_points.txt +0 -0
  202. {chalkpy-2.90.1.dist-info → chalkpy-2.95.3.dist-info}/top_level.txt +0 -0
chalk/client/dataset.py CHANGED
@@ -566,7 +566,13 @@ def _extract_df_columns(
566
566
 
567
567
  decoded_stmts: List[pl.Expr] = []
568
568
  feature_name_to_metadata = None if column_metadata is None else {x.feature_fqn: x for x in column_metadata}
569
- for col, dtype in zip(df.columns, df.dtypes):
569
+ # Use collect_schema().dtypes() for newer Polars versions to avoid performance warning
570
+ # Fall back to df.dtypes for older versions
571
+ try:
572
+ dtypes = df.collect_schema().dtypes()
573
+ except AttributeError:
574
+ dtypes = df.dtypes
575
+ for col, dtype in zip(df.columns, dtypes):
570
576
  if version in (
571
577
  DatasetVersion.BIGQUERY_JOB_WITH_B32_ENCODED_COLNAMES,
572
578
  DatasetVersion.BIGQUERY_JOB_WITH_B32_ENCODED_COLNAMES_V2,
chalk/client/models.py CHANGED
@@ -460,6 +460,15 @@ class OfflineQueryInput(BaseModel):
460
460
  values: List[List[Any]] # Values should be of type TJSON
461
461
 
462
462
 
463
+ class OfflineQueryInputSql(BaseModel):
464
+ """Input to an offline query specified as a ChalkSQL query instead
465
+ of literal data.
466
+
467
+ Alternative to OfflineQueryInput or OfflineQueryInputUri."""
468
+
469
+ input_sql: str
470
+
471
+
463
472
  class OnlineQueryRequest(BaseModel):
464
473
  inputs: Mapping[str, Any] # Values should be of type TJSON
465
474
  outputs: List[str]
@@ -838,6 +847,7 @@ class CreateOfflineQueryJobRequest(BaseModel):
838
847
  None,
839
848
  UploadedParquetShardedOfflineQueryInput,
840
849
  OfflineQueryInputUri,
850
+ OfflineQueryInputSql,
841
851
  ] = None
842
852
  """Any givens"""
843
853
 
@@ -1658,6 +1668,7 @@ class PlanQueryResponse(BaseModel):
1658
1668
  output_schema: List[FeatureSchema]
1659
1669
  errors: List[ChalkError]
1660
1670
  structured_plan: Optional[str] = None
1671
+ serialized_plan_proto_bytes: Optional[str] = None
1661
1672
 
1662
1673
 
1663
1674
  class IngestDatasetRequest(BaseModel):
@@ -1782,3 +1793,90 @@ class GetRegisteredModelVersionResponse(BaseModel):
1782
1793
 
1783
1794
  class CreateModelTrainingJobResponse(BaseModel):
1784
1795
  success: bool
1796
+
1797
+
1798
+ class ScheduledQueryRunStatus(str, Enum):
1799
+ """Status of a scheduled query run."""
1800
+
1801
+ UNSPECIFIED = "UNSPECIFIED"
1802
+ INITIALIZING = "INITIALIZING"
1803
+ INIT_FAILED = "INIT_FAILED"
1804
+ SKIPPED = "SKIPPED"
1805
+ QUEUED = "QUEUED"
1806
+ WORKING = "WORKING"
1807
+ COMPLETED = "COMPLETED"
1808
+ FAILED = "FAILED"
1809
+ CANCELED = "CANCELED"
1810
+
1811
+
1812
+ @dataclasses.dataclass
1813
+ class ScheduledQueryRun:
1814
+ """A single scheduled query run."""
1815
+
1816
+ id: int
1817
+ environment_id: str
1818
+ deployment_id: str
1819
+ run_id: str
1820
+ cron_query_id: int
1821
+ cron_query_schedule_id: int
1822
+ cron_name: str
1823
+ gcr_execution_id: str
1824
+ gcr_job_name: str
1825
+ offline_query_id: str
1826
+ created_at: datetime
1827
+ updated_at: datetime
1828
+ status: ScheduledQueryRunStatus
1829
+ blocker_operation_id: str
1830
+
1831
+ @staticmethod
1832
+ def from_proto(proto_run: Any) -> "ScheduledQueryRun":
1833
+ """Convert a proto ScheduledQueryRun to the dataclass version."""
1834
+ from datetime import timezone
1835
+
1836
+ # Map proto status enum to our enum
1837
+ status_map = {
1838
+ 0: ScheduledQueryRunStatus.UNSPECIFIED,
1839
+ 1: ScheduledQueryRunStatus.INITIALIZING,
1840
+ 2: ScheduledQueryRunStatus.INIT_FAILED,
1841
+ 3: ScheduledQueryRunStatus.SKIPPED,
1842
+ 4: ScheduledQueryRunStatus.QUEUED,
1843
+ 5: ScheduledQueryRunStatus.WORKING,
1844
+ 6: ScheduledQueryRunStatus.COMPLETED,
1845
+ 7: ScheduledQueryRunStatus.FAILED,
1846
+ 8: ScheduledQueryRunStatus.CANCELED,
1847
+ }
1848
+
1849
+ # Helper to convert proto Timestamp to datetime
1850
+ def _timestamp_to_datetime(ts: Any) -> datetime:
1851
+ return datetime.fromtimestamp(ts.seconds + ts.nanos / 1e9, tz=timezone.utc)
1852
+
1853
+ return ScheduledQueryRun(
1854
+ id=proto_run.id,
1855
+ environment_id=proto_run.environment_id,
1856
+ deployment_id=proto_run.deployment_id,
1857
+ run_id=proto_run.run_id,
1858
+ cron_query_id=proto_run.cron_query_id,
1859
+ cron_query_schedule_id=proto_run.cron_query_schedule_id,
1860
+ cron_name=proto_run.cron_name,
1861
+ gcr_execution_id=proto_run.gcr_execution_id,
1862
+ gcr_job_name=proto_run.gcr_job_name,
1863
+ offline_query_id=proto_run.offline_query_id,
1864
+ created_at=_timestamp_to_datetime(proto_run.created_at),
1865
+ updated_at=_timestamp_to_datetime(proto_run.updated_at),
1866
+ status=status_map.get(proto_run.status, ScheduledQueryRunStatus.UNSPECIFIED),
1867
+ blocker_operation_id=proto_run.blocker_operation_id,
1868
+ )
1869
+
1870
+
1871
+ @dataclasses.dataclass
1872
+ class ManualTriggerScheduledQueryResponse:
1873
+ """Response from manually triggering a scheduled query."""
1874
+
1875
+ scheduled_query_run: ScheduledQueryRun
1876
+
1877
+ @staticmethod
1878
+ def from_proto(proto_response: Any) -> "ManualTriggerScheduledQueryResponse":
1879
+ """Convert a proto ManualTriggerScheduledQueryResponse to the dataclass version."""
1880
+ return ManualTriggerScheduledQueryResponse(
1881
+ scheduled_query_run=ScheduledQueryRun.from_proto(proto_response.scheduled_query_run),
1882
+ )
chalk/client/serialization/model_serialization.py CHANGED
@@ -77,7 +77,15 @@ MODEL_SERIALIZERS = {
77
77
  ModelType.ONNX: ModelSerializationConfig(
78
78
  filename="model.onnx",
79
79
  encoding=ModelEncoding.PROTOBUF,
80
- serialize_fn=lambda model, path: model.save_model(path),
80
+ serialize_fn=lambda model, path: ModelSerializer.with_import(
81
+ "onnx",
82
+ lambda onnx: onnx.save(
83
+ # Unwrap model if it has a _model attribute (e.g., wrapped ONNX models)
84
+ model._model if hasattr(model, "_model") else model,
85
+ path,
86
+ ),
87
+ "Please install onnx to save ONNX models.",
88
+ ),
81
89
  ),
82
90
  }
83
91
 
@@ -281,7 +289,15 @@ class ModelSerializer:
281
289
  tensor_schema = _model_artifact_pb2.TensorSchema()
282
290
 
283
291
  for shape, dtype in tensor_specs:
284
- if not isinstance(dtype, pa.DataType):
292
+ # Handle Chalk Tensor types
293
+ if hasattr(dtype, "__mro__") and any("Tensor" in base.__name__ for base in dtype.__mro__):
294
+ # Extract shape and dtype from Tensor type
295
+ if hasattr(dtype, "shape") and hasattr(dtype, "dtype"):
296
+ shape = dtype.shape
297
+ pa_dtype = dtype.dtype
298
+ else:
299
+ raise ValueError(f"Tensor type is missing shape or dtype attributes")
300
+ elif not isinstance(dtype, pa.DataType):
285
301
  if dtype == str:
286
302
  pa_dtype = pa.string()
287
303
  elif dtype == int:
@@ -305,12 +321,73 @@ class ModelSerializer:
305
321
 
306
322
  return tensor_schema
307
323
 
324
+ @staticmethod
325
+ def convert_onnx_list_schema_to_dict(schema: Any, model: Any, is_input: bool = True) -> Any:
326
+ """Convert list-based schema to dict-based schema for ONNX models.
327
+
328
+ Args:
329
+ schema: The schema (list or dict)
330
+ model: The ONNX model (ModelProto or wrapped)
331
+ is_input: True for input schema, False for output schema
332
+
333
+ Returns:
334
+ Dict-based schema with field names from ONNX model
335
+ """
336
+ if not isinstance(schema, list):
337
+ return schema
338
+
339
+ try:
340
+ import onnx # type: ignore[reportMissingImports]
341
+ except ImportError:
342
+ raise ValueError("onnx package is required to convert list schemas for ONNX models")
343
+
344
+ # Unwrap model if needed
345
+ onnx_model = model._model if hasattr(model, "_model") else model
346
+
347
+ if not isinstance(onnx_model, onnx.ModelProto):
348
+ raise ValueError(
349
+ f"ONNX models must be registered with tabular schema (dict format). "
350
+ + f"Use dict format like {{'input': Tensor[...]}} instead of list format."
351
+ )
352
+
353
+ # Get input/output names from ONNX model
354
+ if is_input:
355
+ names = [inp.name for inp in onnx_model.graph.input]
356
+ schema_type = "input"
357
+ else:
358
+ names = [out.name for out in onnx_model.graph.output]
359
+ schema_type = "output"
360
+
361
+ if len(names) != len(schema):
362
+ raise ValueError(f"ONNX model has {len(names)} {schema_type}s but schema has {len(schema)} entries")
363
+
364
+ # Convert to dict format
365
+ return {name: spec for name, spec in zip(names, schema)}
366
+
308
367
  @staticmethod
309
368
  def convert_schema(schema: Any) -> Optional[_model_artifact_pb2.ModelSchema]:
310
369
  model_schema = _model_artifact_pb2.ModelSchema()
311
370
  if schema is not None:
312
371
  if isinstance(schema, dict):
313
- model_schema.tabular.CopyFrom(ModelSerializer.build_tabular_schema(schema))
372
+ # Convert Tensor/Vector types to their PyArrow types for tabular schema
373
+ converted_schema = {}
374
+ for col_name, dtype in schema.items():
375
+ if hasattr(dtype, "__mro__") and any("Tensor" in base.__name__ for base in dtype.__mro__):
376
+ # Use Tensor's to_pyarrow_dtype() method to convert to Arrow type
377
+ if hasattr(dtype, "to_pyarrow_dtype"):
378
+ converted_schema[col_name] = dtype.to_pyarrow_dtype()
379
+ else:
380
+ raise ValueError(f"Tensor type for '{col_name}' is missing to_pyarrow_dtype method")
381
+ elif hasattr(dtype, "__mro__") and any("Vector" in base.__name__ for base in dtype.__mro__):
382
+ # Vector already has a .dtype attribute that's a PyArrow type
383
+ if hasattr(dtype, "dtype"):
384
+ converted_schema[col_name] = dtype.dtype
385
+ else:
386
+ raise ValueError(f"Vector type for '{col_name}' is missing dtype attribute")
387
+ else:
388
+ converted_schema[col_name] = dtype
389
+
390
+ model_schema.tabular.CopyFrom(ModelSerializer.build_tabular_schema(converted_schema))
314
391
  elif isinstance(schema, list):
315
392
  model_schema.tensor.CopyFrom(ModelSerializer.build_tensor_schema(schema))
316
393
  else:
@@ -322,21 +399,27 @@ class ModelSerializer:
322
399
 
323
400
  @staticmethod
324
401
  def convert_run_criterion_to_proto(
325
- run_name: Optional[str], criterion: Optional[ModelRunCriterion]
402
+ run_id: Optional[str] = None, run_name: Optional[str] = None, criterion: Optional[ModelRunCriterion] = None
326
403
  ) -> Optional[RunCriterion]:
327
- if run_name is None:
328
- return None
404
+ if run_id is None and run_name is None:
405
+ raise ValueError("Please specify either run_id or run_name.")
329
406
 
330
407
  if criterion is None:
331
- return RunCriterion(run_id=run_name)
408
+ return RunCriterion(run_id=run_id, run_name=run_name)
332
409
 
333
410
  if criterion.direction == "max":
334
411
  return RunCriterion(
335
- run_id=run_name, metric=criterion.metric, direction=RunCriterionDirection.RUN_CRITERION_DIRECTION_MAX
412
+ run_id=run_id,
413
+ run_name=run_name,
414
+ metric=criterion.metric,
415
+ direction=RunCriterionDirection.RUN_CRITERION_DIRECTION_MAX,
336
416
  )
337
417
  elif criterion.direction == "min":
338
418
  return RunCriterion(
339
- run_id=run_name, metric=criterion.metric, direction=RunCriterionDirection.RUN_CRITERION_DIRECTION_MIN
419
+ run_id=run_id,
420
+ run_name=run_name,
421
+ metric=criterion.metric,
422
+ direction=RunCriterionDirection.RUN_CRITERION_DIRECTION_MIN,
340
423
  )
341
424
  else:
342
425
  raise ValueError(