chalkpy 2.90.1__py3-none-any.whl → 2.95.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chalk/__init__.py +2 -1
- chalk/_gen/chalk/arrow/v1/arrow_pb2.py +7 -5
- chalk/_gen/chalk/arrow/v1/arrow_pb2.pyi +6 -0
- chalk/_gen/chalk/artifacts/v1/chart_pb2.py +16 -16
- chalk/_gen/chalk/artifacts/v1/chart_pb2.pyi +4 -0
- chalk/_gen/chalk/artifacts/v1/cron_query_pb2.py +8 -7
- chalk/_gen/chalk/artifacts/v1/cron_query_pb2.pyi +5 -0
- chalk/_gen/chalk/common/v1/offline_query_pb2.py +17 -15
- chalk/_gen/chalk/common/v1/offline_query_pb2.pyi +25 -0
- chalk/_gen/chalk/common/v1/script_task_pb2.py +3 -3
- chalk/_gen/chalk/common/v1/script_task_pb2.pyi +2 -0
- chalk/_gen/chalk/dataframe/__init__.py +0 -0
- chalk/_gen/chalk/dataframe/v1/__init__.py +0 -0
- chalk/_gen/chalk/dataframe/v1/dataframe_pb2.py +48 -0
- chalk/_gen/chalk/dataframe/v1/dataframe_pb2.pyi +123 -0
- chalk/_gen/chalk/dataframe/v1/dataframe_pb2_grpc.py +4 -0
- chalk/_gen/chalk/dataframe/v1/dataframe_pb2_grpc.pyi +4 -0
- chalk/_gen/chalk/graph/v1/graph_pb2.py +150 -149
- chalk/_gen/chalk/graph/v1/graph_pb2.pyi +25 -0
- chalk/_gen/chalk/graph/v1/sources_pb2.py +94 -84
- chalk/_gen/chalk/graph/v1/sources_pb2.pyi +56 -0
- chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2.py +79 -0
- chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2.pyi +377 -0
- chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2_grpc.py +4 -0
- chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2_grpc.pyi +4 -0
- chalk/_gen/chalk/kubernetes/v1/scaledobject_pb2.py +43 -7
- chalk/_gen/chalk/kubernetes/v1/scaledobject_pb2.pyi +252 -2
- chalk/_gen/chalk/protosql/v1/sql_service_pb2.py +54 -27
- chalk/_gen/chalk/protosql/v1/sql_service_pb2.pyi +131 -3
- chalk/_gen/chalk/protosql/v1/sql_service_pb2_grpc.py +45 -0
- chalk/_gen/chalk/protosql/v1/sql_service_pb2_grpc.pyi +14 -0
- chalk/_gen/chalk/python/v1/types_pb2.py +14 -14
- chalk/_gen/chalk/python/v1/types_pb2.pyi +8 -0
- chalk/_gen/chalk/server/v1/benchmark_pb2.py +76 -0
- chalk/_gen/chalk/server/v1/benchmark_pb2.pyi +156 -0
- chalk/_gen/chalk/server/v1/benchmark_pb2_grpc.py +258 -0
- chalk/_gen/chalk/server/v1/benchmark_pb2_grpc.pyi +84 -0
- chalk/_gen/chalk/server/v1/billing_pb2.py +40 -38
- chalk/_gen/chalk/server/v1/billing_pb2.pyi +17 -1
- chalk/_gen/chalk/server/v1/branches_pb2.py +45 -0
- chalk/_gen/chalk/server/v1/branches_pb2.pyi +80 -0
- chalk/_gen/chalk/server/v1/branches_pb2_grpc.pyi +36 -0
- chalk/_gen/chalk/server/v1/builder_pb2.py +358 -288
- chalk/_gen/chalk/server/v1/builder_pb2.pyi +360 -10
- chalk/_gen/chalk/server/v1/builder_pb2_grpc.py +225 -0
- chalk/_gen/chalk/server/v1/builder_pb2_grpc.pyi +60 -0
- chalk/_gen/chalk/server/v1/chart_pb2.py +10 -10
- chalk/_gen/chalk/server/v1/chart_pb2.pyi +18 -2
- chalk/_gen/chalk/server/v1/clickhouse_pb2.py +42 -0
- chalk/_gen/chalk/server/v1/clickhouse_pb2.pyi +17 -0
- chalk/_gen/chalk/server/v1/clickhouse_pb2_grpc.py +78 -0
- chalk/_gen/chalk/server/v1/clickhouse_pb2_grpc.pyi +38 -0
- chalk/_gen/chalk/server/v1/cloud_components_pb2.py +141 -119
- chalk/_gen/chalk/server/v1/cloud_components_pb2.pyi +106 -4
- chalk/_gen/chalk/server/v1/cloud_components_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/cloud_components_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/cloud_credentials_pb2.py +11 -3
- chalk/_gen/chalk/server/v1/cloud_credentials_pb2.pyi +20 -0
- chalk/_gen/chalk/server/v1/cloud_credentials_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/cloud_credentials_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2.py +52 -38
- chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2.pyi +62 -1
- chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2_grpc.py +90 -0
- chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2_grpc.pyi +24 -0
- chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2.py +90 -0
- chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2.pyi +264 -0
- chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2_grpc.py +170 -0
- chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2_grpc.pyi +62 -0
- chalk/_gen/chalk/server/v1/deploy_pb2.py +9 -3
- chalk/_gen/chalk/server/v1/deploy_pb2.pyi +12 -0
- chalk/_gen/chalk/server/v1/deploy_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/deploy_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/deployment_pb2.py +6 -6
- chalk/_gen/chalk/server/v1/deployment_pb2.pyi +20 -0
- chalk/_gen/chalk/server/v1/environment_pb2.py +14 -12
- chalk/_gen/chalk/server/v1/environment_pb2.pyi +19 -0
- chalk/_gen/chalk/server/v1/eventbus_pb2.py +4 -2
- chalk/_gen/chalk/server/v1/files_pb2.py +65 -0
- chalk/_gen/chalk/server/v1/files_pb2.pyi +167 -0
- chalk/_gen/chalk/server/v1/files_pb2_grpc.py +4 -0
- chalk/_gen/chalk/server/v1/files_pb2_grpc.pyi +4 -0
- chalk/_gen/chalk/server/v1/graph_pb2.py +38 -26
- chalk/_gen/chalk/server/v1/graph_pb2.pyi +58 -0
- chalk/_gen/chalk/server/v1/graph_pb2_grpc.py +47 -0
- chalk/_gen/chalk/server/v1/graph_pb2_grpc.pyi +18 -0
- chalk/_gen/chalk/server/v1/incident_pb2.py +23 -21
- chalk/_gen/chalk/server/v1/incident_pb2.pyi +15 -1
- chalk/_gen/chalk/server/v1/indexing_job_pb2.py +44 -0
- chalk/_gen/chalk/server/v1/indexing_job_pb2.pyi +38 -0
- chalk/_gen/chalk/server/v1/indexing_job_pb2_grpc.py +78 -0
- chalk/_gen/chalk/server/v1/indexing_job_pb2_grpc.pyi +38 -0
- chalk/_gen/chalk/server/v1/integrations_pb2.py +11 -9
- chalk/_gen/chalk/server/v1/integrations_pb2.pyi +34 -2
- chalk/_gen/chalk/server/v1/kube_pb2.py +29 -19
- chalk/_gen/chalk/server/v1/kube_pb2.pyi +28 -0
- chalk/_gen/chalk/server/v1/kube_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/kube_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/log_pb2.py +21 -3
- chalk/_gen/chalk/server/v1/log_pb2.pyi +68 -0
- chalk/_gen/chalk/server/v1/log_pb2_grpc.py +90 -0
- chalk/_gen/chalk/server/v1/log_pb2_grpc.pyi +24 -0
- chalk/_gen/chalk/server/v1/model_registry_pb2.py +10 -10
- chalk/_gen/chalk/server/v1/model_registry_pb2.pyi +4 -1
- chalk/_gen/chalk/server/v1/plandebug_pb2.py +53 -0
- chalk/_gen/chalk/server/v1/plandebug_pb2.pyi +86 -0
- chalk/_gen/chalk/server/v1/plandebug_pb2_grpc.py +168 -0
- chalk/_gen/chalk/server/v1/plandebug_pb2_grpc.pyi +60 -0
- chalk/_gen/chalk/server/v1/queries_pb2.py +66 -66
- chalk/_gen/chalk/server/v1/queries_pb2.pyi +32 -2
- chalk/_gen/chalk/server/v1/scheduled_query_run_pb2.py +12 -12
- chalk/_gen/chalk/server/v1/scheduled_query_run_pb2.pyi +16 -3
- chalk/_gen/chalk/server/v1/scheduler_pb2.py +24 -12
- chalk/_gen/chalk/server/v1/scheduler_pb2.pyi +61 -1
- chalk/_gen/chalk/server/v1/scheduler_pb2_grpc.py +90 -0
- chalk/_gen/chalk/server/v1/scheduler_pb2_grpc.pyi +24 -0
- chalk/_gen/chalk/server/v1/script_tasks_pb2.py +15 -3
- chalk/_gen/chalk/server/v1/script_tasks_pb2.pyi +22 -0
- chalk/_gen/chalk/server/v1/script_tasks_pb2_grpc.py +90 -0
- chalk/_gen/chalk/server/v1/script_tasks_pb2_grpc.pyi +24 -0
- chalk/_gen/chalk/server/v1/sql_interface_pb2.py +75 -0
- chalk/_gen/chalk/server/v1/sql_interface_pb2.pyi +142 -0
- chalk/_gen/chalk/server/v1/sql_interface_pb2_grpc.py +349 -0
- chalk/_gen/chalk/server/v1/sql_interface_pb2_grpc.pyi +114 -0
- chalk/_gen/chalk/server/v1/sql_queries_pb2.py +48 -0
- chalk/_gen/chalk/server/v1/sql_queries_pb2.pyi +150 -0
- chalk/_gen/chalk/server/v1/sql_queries_pb2_grpc.py +123 -0
- chalk/_gen/chalk/server/v1/sql_queries_pb2_grpc.pyi +52 -0
- chalk/_gen/chalk/server/v1/team_pb2.py +154 -141
- chalk/_gen/chalk/server/v1/team_pb2.pyi +30 -2
- chalk/_gen/chalk/server/v1/team_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/team_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/topic_pb2.py +5 -3
- chalk/_gen/chalk/server/v1/topic_pb2.pyi +10 -1
- chalk/_gen/chalk/server/v1/trace_pb2.py +44 -40
- chalk/_gen/chalk/server/v1/trace_pb2.pyi +20 -0
- chalk/_gen/chalk/streaming/v1/debug_service_pb2.py +62 -0
- chalk/_gen/chalk/streaming/v1/debug_service_pb2.pyi +75 -0
- chalk/_gen/chalk/streaming/v1/debug_service_pb2_grpc.py +221 -0
- chalk/_gen/chalk/streaming/v1/debug_service_pb2_grpc.pyi +88 -0
- chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2.py +16 -10
- chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2.pyi +52 -1
- chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2_grpc.py +48 -0
- chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2_grpc.pyi +20 -0
- chalk/_gen/chalk/utils/v1/field_change_pb2.py +32 -0
- chalk/_gen/chalk/utils/v1/field_change_pb2.pyi +42 -0
- chalk/_gen/chalk/utils/v1/field_change_pb2_grpc.py +4 -0
- chalk/_gen/chalk/utils/v1/field_change_pb2_grpc.pyi +4 -0
- chalk/_lsp/error_builder.py +11 -0
- chalk/_version.py +1 -1
- chalk/client/client.py +128 -43
- chalk/client/client_async.py +149 -0
- chalk/client/client_async_impl.py +22 -0
- chalk/client/client_grpc.py +539 -104
- chalk/client/client_impl.py +449 -122
- chalk/client/dataset.py +7 -1
- chalk/client/models.py +98 -0
- chalk/client/serialization/model_serialization.py +92 -9
- chalk/df/LazyFramePlaceholder.py +1154 -0
- chalk/features/_class_property.py +7 -0
- chalk/features/_embedding/embedding.py +1 -0
- chalk/features/_encoding/converter.py +83 -2
- chalk/features/feature_field.py +40 -30
- chalk/features/feature_set_decorator.py +1 -0
- chalk/features/feature_wrapper.py +42 -3
- chalk/features/hooks.py +81 -10
- chalk/features/inference.py +33 -31
- chalk/features/resolver.py +224 -24
- chalk/functions/__init__.py +65 -3
- chalk/gitignore/gitignore_parser.py +5 -1
- chalk/importer.py +142 -68
- chalk/ml/__init__.py +2 -0
- chalk/ml/model_hooks.py +194 -26
- chalk/ml/model_reference.py +56 -8
- chalk/ml/model_version.py +24 -15
- chalk/ml/utils.py +20 -17
- chalk/operators/_utils.py +10 -3
- chalk/parsed/_proto/export.py +22 -0
- chalk/parsed/duplicate_input_gql.py +3 -0
- chalk/parsed/json_conversions.py +20 -14
- chalk/parsed/to_proto.py +16 -4
- chalk/parsed/user_types_to_json.py +31 -10
- chalk/parsed/validation_from_registries.py +182 -0
- chalk/queries/named_query.py +16 -6
- chalk/queries/scheduled_query.py +9 -1
- chalk/serialization/parsed_annotation.py +24 -11
- chalk/sql/__init__.py +18 -0
- chalk/sql/_internal/integrations/databricks.py +55 -17
- chalk/sql/_internal/integrations/mssql.py +127 -62
- chalk/sql/_internal/integrations/redshift.py +4 -0
- chalk/sql/_internal/sql_file_resolver.py +53 -9
- chalk/sql/_internal/sql_source.py +35 -2
- chalk/streams/_kafka_source.py +5 -1
- chalk/streams/_windows.py +15 -2
- chalk/utils/_otel_version.py +13 -0
- chalk/utils/async_helpers.py +2 -2
- chalk/utils/missing_dependency.py +5 -4
- chalk/utils/tracing.py +185 -95
- {chalkpy-2.90.1.dist-info → chalkpy-2.95.3.dist-info}/METADATA +4 -6
- {chalkpy-2.90.1.dist-info → chalkpy-2.95.3.dist-info}/RECORD +202 -146
- {chalkpy-2.90.1.dist-info → chalkpy-2.95.3.dist-info}/WHEEL +0 -0
- {chalkpy-2.90.1.dist-info → chalkpy-2.95.3.dist-info}/entry_points.txt +0 -0
- {chalkpy-2.90.1.dist-info → chalkpy-2.95.3.dist-info}/top_level.txt +0 -0
chalk/client/dataset.py
CHANGED
|
@@ -566,7 +566,13 @@ def _extract_df_columns(
|
|
|
566
566
|
|
|
567
567
|
decoded_stmts: List[pl.Expr] = []
|
|
568
568
|
feature_name_to_metadata = None if column_metadata is None else {x.feature_fqn: x for x in column_metadata}
|
|
569
|
-
for
|
|
569
|
+
# Use collect_schema().dtypes() for newer Polars versions to avoid performance warning
|
|
570
|
+
# Fall back to df.dtypes for older versions
|
|
571
|
+
try:
|
|
572
|
+
dtypes = df.collect_schema().dtypes()
|
|
573
|
+
except AttributeError:
|
|
574
|
+
dtypes = df.dtypes
|
|
575
|
+
for col, dtype in zip(df.columns, dtypes):
|
|
570
576
|
if version in (
|
|
571
577
|
DatasetVersion.BIGQUERY_JOB_WITH_B32_ENCODED_COLNAMES,
|
|
572
578
|
DatasetVersion.BIGQUERY_JOB_WITH_B32_ENCODED_COLNAMES_V2,
|
chalk/client/models.py
CHANGED
|
@@ -460,6 +460,15 @@ class OfflineQueryInput(BaseModel):
|
|
|
460
460
|
values: List[List[Any]] # Values should be of type TJSON
|
|
461
461
|
|
|
462
462
|
|
|
463
|
+
class OfflineQueryInputSql(BaseModel):
|
|
464
|
+
"""Input to an offline query specified as a ChalkSQL query instead
|
|
465
|
+
of literal data.
|
|
466
|
+
|
|
467
|
+
Alternative to OfflineQueryInput or OfflineQueryInputUri."""
|
|
468
|
+
|
|
469
|
+
input_sql: str
|
|
470
|
+
|
|
471
|
+
|
|
463
472
|
class OnlineQueryRequest(BaseModel):
|
|
464
473
|
inputs: Mapping[str, Any] # Values should be of type TJSON
|
|
465
474
|
outputs: List[str]
|
|
@@ -838,6 +847,7 @@ class CreateOfflineQueryJobRequest(BaseModel):
|
|
|
838
847
|
None,
|
|
839
848
|
UploadedParquetShardedOfflineQueryInput,
|
|
840
849
|
OfflineQueryInputUri,
|
|
850
|
+
OfflineQueryInputSql,
|
|
841
851
|
] = None
|
|
842
852
|
"""Any givens"""
|
|
843
853
|
|
|
@@ -1658,6 +1668,7 @@ class PlanQueryResponse(BaseModel):
|
|
|
1658
1668
|
output_schema: List[FeatureSchema]
|
|
1659
1669
|
errors: List[ChalkError]
|
|
1660
1670
|
structured_plan: Optional[str] = None
|
|
1671
|
+
serialized_plan_proto_bytes: Optional[str] = None
|
|
1661
1672
|
|
|
1662
1673
|
|
|
1663
1674
|
class IngestDatasetRequest(BaseModel):
|
|
@@ -1782,3 +1793,90 @@ class GetRegisteredModelVersionResponse(BaseModel):
|
|
|
1782
1793
|
|
|
1783
1794
|
class CreateModelTrainingJobResponse(BaseModel):
|
|
1784
1795
|
success: bool
|
|
1796
|
+
|
|
1797
|
+
|
|
1798
|
+
class ScheduledQueryRunStatus(str, Enum):
|
|
1799
|
+
"""Status of a scheduled query run."""
|
|
1800
|
+
|
|
1801
|
+
UNSPECIFIED = "UNSPECIFIED"
|
|
1802
|
+
INITIALIZING = "INITIALIZING"
|
|
1803
|
+
INIT_FAILED = "INIT_FAILED"
|
|
1804
|
+
SKIPPED = "SKIPPED"
|
|
1805
|
+
QUEUED = "QUEUED"
|
|
1806
|
+
WORKING = "WORKING"
|
|
1807
|
+
COMPLETED = "COMPLETED"
|
|
1808
|
+
FAILED = "FAILED"
|
|
1809
|
+
CANCELED = "CANCELED"
|
|
1810
|
+
|
|
1811
|
+
|
|
1812
|
+
@dataclasses.dataclass
|
|
1813
|
+
class ScheduledQueryRun:
|
|
1814
|
+
"""A single scheduled query run."""
|
|
1815
|
+
|
|
1816
|
+
id: int
|
|
1817
|
+
environment_id: str
|
|
1818
|
+
deployment_id: str
|
|
1819
|
+
run_id: str
|
|
1820
|
+
cron_query_id: int
|
|
1821
|
+
cron_query_schedule_id: int
|
|
1822
|
+
cron_name: str
|
|
1823
|
+
gcr_execution_id: str
|
|
1824
|
+
gcr_job_name: str
|
|
1825
|
+
offline_query_id: str
|
|
1826
|
+
created_at: datetime
|
|
1827
|
+
updated_at: datetime
|
|
1828
|
+
status: ScheduledQueryRunStatus
|
|
1829
|
+
blocker_operation_id: str
|
|
1830
|
+
|
|
1831
|
+
@staticmethod
|
|
1832
|
+
def from_proto(proto_run: Any) -> "ScheduledQueryRun":
|
|
1833
|
+
"""Convert a proto ScheduledQueryRun to the dataclass version."""
|
|
1834
|
+
from datetime import timezone
|
|
1835
|
+
|
|
1836
|
+
# Map proto status enum to our enum
|
|
1837
|
+
status_map = {
|
|
1838
|
+
0: ScheduledQueryRunStatus.UNSPECIFIED,
|
|
1839
|
+
1: ScheduledQueryRunStatus.INITIALIZING,
|
|
1840
|
+
2: ScheduledQueryRunStatus.INIT_FAILED,
|
|
1841
|
+
3: ScheduledQueryRunStatus.SKIPPED,
|
|
1842
|
+
4: ScheduledQueryRunStatus.QUEUED,
|
|
1843
|
+
5: ScheduledQueryRunStatus.WORKING,
|
|
1844
|
+
6: ScheduledQueryRunStatus.COMPLETED,
|
|
1845
|
+
7: ScheduledQueryRunStatus.FAILED,
|
|
1846
|
+
8: ScheduledQueryRunStatus.CANCELED,
|
|
1847
|
+
}
|
|
1848
|
+
|
|
1849
|
+
# Helper to convert proto Timestamp to datetime
|
|
1850
|
+
def _timestamp_to_datetime(ts: Any) -> datetime:
|
|
1851
|
+
return datetime.fromtimestamp(ts.seconds + ts.nanos / 1e9, tz=timezone.utc)
|
|
1852
|
+
|
|
1853
|
+
return ScheduledQueryRun(
|
|
1854
|
+
id=proto_run.id,
|
|
1855
|
+
environment_id=proto_run.environment_id,
|
|
1856
|
+
deployment_id=proto_run.deployment_id,
|
|
1857
|
+
run_id=proto_run.run_id,
|
|
1858
|
+
cron_query_id=proto_run.cron_query_id,
|
|
1859
|
+
cron_query_schedule_id=proto_run.cron_query_schedule_id,
|
|
1860
|
+
cron_name=proto_run.cron_name,
|
|
1861
|
+
gcr_execution_id=proto_run.gcr_execution_id,
|
|
1862
|
+
gcr_job_name=proto_run.gcr_job_name,
|
|
1863
|
+
offline_query_id=proto_run.offline_query_id,
|
|
1864
|
+
created_at=_timestamp_to_datetime(proto_run.created_at),
|
|
1865
|
+
updated_at=_timestamp_to_datetime(proto_run.updated_at),
|
|
1866
|
+
status=status_map.get(proto_run.status, ScheduledQueryRunStatus.UNSPECIFIED),
|
|
1867
|
+
blocker_operation_id=proto_run.blocker_operation_id,
|
|
1868
|
+
)
|
|
1869
|
+
|
|
1870
|
+
|
|
1871
|
+
@dataclasses.dataclass
|
|
1872
|
+
class ManualTriggerScheduledQueryResponse:
|
|
1873
|
+
"""Response from manually triggering a scheduled query."""
|
|
1874
|
+
|
|
1875
|
+
scheduled_query_run: ScheduledQueryRun
|
|
1876
|
+
|
|
1877
|
+
@staticmethod
|
|
1878
|
+
def from_proto(proto_response: Any) -> "ManualTriggerScheduledQueryResponse":
|
|
1879
|
+
"""Convert a proto ManualTriggerScheduledQueryResponse to the dataclass version."""
|
|
1880
|
+
return ManualTriggerScheduledQueryResponse(
|
|
1881
|
+
scheduled_query_run=ScheduledQueryRun.from_proto(proto_response.scheduled_query_run),
|
|
1882
|
+
)
|
|
@@ -77,7 +77,15 @@ MODEL_SERIALIZERS = {
|
|
|
77
77
|
ModelType.ONNX: ModelSerializationConfig(
|
|
78
78
|
filename="model.onnx",
|
|
79
79
|
encoding=ModelEncoding.PROTOBUF,
|
|
80
|
-
serialize_fn=lambda model, path:
|
|
80
|
+
serialize_fn=lambda model, path: ModelSerializer.with_import(
|
|
81
|
+
"onnx",
|
|
82
|
+
lambda onnx: onnx.save(
|
|
83
|
+
# Unwrap model if it has a _model attribute (e.g., wrapped ONNX models)
|
|
84
|
+
model._model if hasattr(model, "_model") else model,
|
|
85
|
+
path,
|
|
86
|
+
),
|
|
87
|
+
"Please install onnx to save ONNX models.",
|
|
88
|
+
),
|
|
81
89
|
),
|
|
82
90
|
}
|
|
83
91
|
|
|
@@ -281,7 +289,15 @@ class ModelSerializer:
|
|
|
281
289
|
tensor_schema = _model_artifact_pb2.TensorSchema()
|
|
282
290
|
|
|
283
291
|
for shape, dtype in tensor_specs:
|
|
284
|
-
|
|
292
|
+
# Handle Chalk Tensor types
|
|
293
|
+
if hasattr(dtype, "__mro__") and any("Tensor" in base.__name__ for base in dtype.__mro__):
|
|
294
|
+
# Extract shape and dtype from Tensor type
|
|
295
|
+
if hasattr(dtype, "shape") and hasattr(dtype, "dtype"):
|
|
296
|
+
shape = dtype.shape
|
|
297
|
+
pa_dtype = dtype.dtype
|
|
298
|
+
else:
|
|
299
|
+
raise ValueError(f"Tensor type is missing shape or dtype attributes")
|
|
300
|
+
elif not isinstance(dtype, pa.DataType):
|
|
285
301
|
if dtype == str:
|
|
286
302
|
pa_dtype = pa.string()
|
|
287
303
|
elif dtype == int:
|
|
@@ -305,12 +321,73 @@ class ModelSerializer:
|
|
|
305
321
|
|
|
306
322
|
return tensor_schema
|
|
307
323
|
|
|
324
|
+
@staticmethod
|
|
325
|
+
def convert_onnx_list_schema_to_dict(schema: Any, model: Any, is_input: bool = True) -> Any:
|
|
326
|
+
"""Convert list-based schema to dict-based schema for ONNX models.
|
|
327
|
+
|
|
328
|
+
Args:
|
|
329
|
+
schema: The schema (list or dict)
|
|
330
|
+
model: The ONNX model (ModelProto or wrapped)
|
|
331
|
+
is_input: True for input schema, False for output schema
|
|
332
|
+
|
|
333
|
+
Returns:
|
|
334
|
+
Dict-based schema with field names from ONNX model
|
|
335
|
+
"""
|
|
336
|
+
if not isinstance(schema, list):
|
|
337
|
+
return schema
|
|
338
|
+
|
|
339
|
+
try:
|
|
340
|
+
import onnx # type: ignore[reportMissingImports]
|
|
341
|
+
except ImportError:
|
|
342
|
+
raise ValueError("onnx package is required to convert list schemas for ONNX models")
|
|
343
|
+
|
|
344
|
+
# Unwrap model if needed
|
|
345
|
+
onnx_model = model._model if hasattr(model, "_model") else model
|
|
346
|
+
|
|
347
|
+
if not isinstance(onnx_model, onnx.ModelProto):
|
|
348
|
+
raise ValueError(
|
|
349
|
+
f"ONNX models must be registered with tabular schema (dict format). "
|
|
350
|
+
+ f"Use dict format like {{'input': Tensor[...]}} instead of list format."
|
|
351
|
+
)
|
|
352
|
+
|
|
353
|
+
# Get input/output names from ONNX model
|
|
354
|
+
if is_input:
|
|
355
|
+
names = [inp.name for inp in onnx_model.graph.input]
|
|
356
|
+
schema_type = "input"
|
|
357
|
+
else:
|
|
358
|
+
names = [out.name for out in onnx_model.graph.output]
|
|
359
|
+
schema_type = "output"
|
|
360
|
+
|
|
361
|
+
if len(names) != len(schema):
|
|
362
|
+
raise ValueError(f"ONNX model has {len(names)} {schema_type}s but schema has {len(schema)} entries")
|
|
363
|
+
|
|
364
|
+
# Convert to dict format
|
|
365
|
+
return {name: spec for name, spec in zip(names, schema)}
|
|
366
|
+
|
|
308
367
|
@staticmethod
|
|
309
368
|
def convert_schema(schema: Any) -> Optional[_model_artifact_pb2.ModelSchema]:
|
|
310
369
|
model_schema = _model_artifact_pb2.ModelSchema()
|
|
311
370
|
if schema is not None:
|
|
312
371
|
if isinstance(schema, dict):
|
|
313
|
-
|
|
372
|
+
# Convert Tensor/Vector types to their PyArrow types for tabular schema
|
|
373
|
+
converted_schema = {}
|
|
374
|
+
for col_name, dtype in schema.items():
|
|
375
|
+
if hasattr(dtype, "__mro__") and any("Tensor" in base.__name__ for base in dtype.__mro__):
|
|
376
|
+
# Use Tensor's to_pyarrow_dtype() method to convert to Arrow type
|
|
377
|
+
if hasattr(dtype, "to_pyarrow_dtype"):
|
|
378
|
+
converted_schema[col_name] = dtype.to_pyarrow_dtype()
|
|
379
|
+
else:
|
|
380
|
+
raise ValueError(f"Tensor type for '{col_name}' is missing to_pyarrow_dtype method")
|
|
381
|
+
elif hasattr(dtype, "__mro__") and any("Vector" in base.__name__ for base in dtype.__mro__):
|
|
382
|
+
# Vector already has a .dtype attribute that's a PyArrow type
|
|
383
|
+
if hasattr(dtype, "dtype"):
|
|
384
|
+
converted_schema[col_name] = dtype.dtype
|
|
385
|
+
else:
|
|
386
|
+
raise ValueError(f"Vector type for '{col_name}' is missing dtype attribute")
|
|
387
|
+
else:
|
|
388
|
+
converted_schema[col_name] = dtype
|
|
389
|
+
|
|
390
|
+
model_schema.tabular.CopyFrom(ModelSerializer.build_tabular_schema(converted_schema))
|
|
314
391
|
elif isinstance(schema, list):
|
|
315
392
|
model_schema.tensor.CopyFrom(ModelSerializer.build_tensor_schema(schema))
|
|
316
393
|
else:
|
|
@@ -322,21 +399,27 @@ class ModelSerializer:
|
|
|
322
399
|
|
|
323
400
|
@staticmethod
|
|
324
401
|
def convert_run_criterion_to_proto(
|
|
325
|
-
run_name: Optional[str], criterion: Optional[ModelRunCriterion]
|
|
402
|
+
run_id: Optional[str] = None, run_name: Optional[str] = None, criterion: Optional[ModelRunCriterion] = None
|
|
326
403
|
) -> Optional[RunCriterion]:
|
|
327
|
-
if run_name is None:
|
|
328
|
-
|
|
404
|
+
if run_id is None and run_name is None:
|
|
405
|
+
raise ValueError("Please specify either run_id or run_name.")
|
|
329
406
|
|
|
330
407
|
if criterion is None:
|
|
331
|
-
return RunCriterion(run_id=run_name)
|
|
408
|
+
return RunCriterion(run_id=run_id, run_name=run_name)
|
|
332
409
|
|
|
333
410
|
if criterion.direction == "max":
|
|
334
411
|
return RunCriterion(
|
|
335
|
-
run_id=
|
|
412
|
+
run_id=run_id,
|
|
413
|
+
run_name=run_name,
|
|
414
|
+
metric=criterion.metric,
|
|
415
|
+
direction=RunCriterionDirection.RUN_CRITERION_DIRECTION_MAX,
|
|
336
416
|
)
|
|
337
417
|
elif criterion.direction == "min":
|
|
338
418
|
return RunCriterion(
|
|
339
|
-
run_id=
|
|
419
|
+
run_id=run_id,
|
|
420
|
+
run_name=run_name,
|
|
421
|
+
metric=criterion.metric,
|
|
422
|
+
direction=RunCriterionDirection.RUN_CRITERION_DIRECTION_MIN,
|
|
340
423
|
)
|
|
341
424
|
else:
|
|
342
425
|
raise ValueError(
|