chalkpy 2.89.22__py3-none-any.whl → 2.95.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. chalk/__init__.py +2 -1
  2. chalk/_gen/chalk/arrow/v1/arrow_pb2.py +7 -5
  3. chalk/_gen/chalk/arrow/v1/arrow_pb2.pyi +6 -0
  4. chalk/_gen/chalk/artifacts/v1/chart_pb2.py +36 -33
  5. chalk/_gen/chalk/artifacts/v1/chart_pb2.pyi +41 -1
  6. chalk/_gen/chalk/artifacts/v1/cron_query_pb2.py +8 -7
  7. chalk/_gen/chalk/artifacts/v1/cron_query_pb2.pyi +5 -0
  8. chalk/_gen/chalk/common/v1/offline_query_pb2.py +19 -13
  9. chalk/_gen/chalk/common/v1/offline_query_pb2.pyi +37 -0
  10. chalk/_gen/chalk/common/v1/online_query_pb2.py +54 -54
  11. chalk/_gen/chalk/common/v1/online_query_pb2.pyi +13 -1
  12. chalk/_gen/chalk/common/v1/script_task_pb2.py +13 -11
  13. chalk/_gen/chalk/common/v1/script_task_pb2.pyi +19 -1
  14. chalk/_gen/chalk/dataframe/__init__.py +0 -0
  15. chalk/_gen/chalk/dataframe/v1/__init__.py +0 -0
  16. chalk/_gen/chalk/dataframe/v1/dataframe_pb2.py +48 -0
  17. chalk/_gen/chalk/dataframe/v1/dataframe_pb2.pyi +123 -0
  18. chalk/_gen/chalk/dataframe/v1/dataframe_pb2_grpc.py +4 -0
  19. chalk/_gen/chalk/dataframe/v1/dataframe_pb2_grpc.pyi +4 -0
  20. chalk/_gen/chalk/graph/v1/graph_pb2.py +150 -149
  21. chalk/_gen/chalk/graph/v1/graph_pb2.pyi +25 -0
  22. chalk/_gen/chalk/graph/v1/sources_pb2.py +94 -84
  23. chalk/_gen/chalk/graph/v1/sources_pb2.pyi +56 -0
  24. chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2.py +79 -0
  25. chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2.pyi +377 -0
  26. chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2_grpc.py +4 -0
  27. chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2_grpc.pyi +4 -0
  28. chalk/_gen/chalk/kubernetes/v1/scaledobject_pb2.py +43 -7
  29. chalk/_gen/chalk/kubernetes/v1/scaledobject_pb2.pyi +252 -2
  30. chalk/_gen/chalk/protosql/v1/sql_service_pb2.py +54 -27
  31. chalk/_gen/chalk/protosql/v1/sql_service_pb2.pyi +131 -3
  32. chalk/_gen/chalk/protosql/v1/sql_service_pb2_grpc.py +45 -0
  33. chalk/_gen/chalk/protosql/v1/sql_service_pb2_grpc.pyi +14 -0
  34. chalk/_gen/chalk/python/v1/types_pb2.py +14 -14
  35. chalk/_gen/chalk/python/v1/types_pb2.pyi +8 -0
  36. chalk/_gen/chalk/server/v1/benchmark_pb2.py +76 -0
  37. chalk/_gen/chalk/server/v1/benchmark_pb2.pyi +156 -0
  38. chalk/_gen/chalk/server/v1/benchmark_pb2_grpc.py +258 -0
  39. chalk/_gen/chalk/server/v1/benchmark_pb2_grpc.pyi +84 -0
  40. chalk/_gen/chalk/server/v1/billing_pb2.py +40 -38
  41. chalk/_gen/chalk/server/v1/billing_pb2.pyi +17 -1
  42. chalk/_gen/chalk/server/v1/branches_pb2.py +45 -0
  43. chalk/_gen/chalk/server/v1/branches_pb2.pyi +80 -0
  44. chalk/_gen/chalk/server/v1/branches_pb2_grpc.pyi +36 -0
  45. chalk/_gen/chalk/server/v1/builder_pb2.py +372 -272
  46. chalk/_gen/chalk/server/v1/builder_pb2.pyi +479 -12
  47. chalk/_gen/chalk/server/v1/builder_pb2_grpc.py +360 -0
  48. chalk/_gen/chalk/server/v1/builder_pb2_grpc.pyi +96 -0
  49. chalk/_gen/chalk/server/v1/chart_pb2.py +10 -10
  50. chalk/_gen/chalk/server/v1/chart_pb2.pyi +18 -2
  51. chalk/_gen/chalk/server/v1/clickhouse_pb2.py +42 -0
  52. chalk/_gen/chalk/server/v1/clickhouse_pb2.pyi +17 -0
  53. chalk/_gen/chalk/server/v1/clickhouse_pb2_grpc.py +78 -0
  54. chalk/_gen/chalk/server/v1/clickhouse_pb2_grpc.pyi +38 -0
  55. chalk/_gen/chalk/server/v1/cloud_components_pb2.py +153 -107
  56. chalk/_gen/chalk/server/v1/cloud_components_pb2.pyi +146 -4
  57. chalk/_gen/chalk/server/v1/cloud_components_pb2_grpc.py +180 -0
  58. chalk/_gen/chalk/server/v1/cloud_components_pb2_grpc.pyi +48 -0
  59. chalk/_gen/chalk/server/v1/cloud_credentials_pb2.py +11 -3
  60. chalk/_gen/chalk/server/v1/cloud_credentials_pb2.pyi +20 -0
  61. chalk/_gen/chalk/server/v1/cloud_credentials_pb2_grpc.py +45 -0
  62. chalk/_gen/chalk/server/v1/cloud_credentials_pb2_grpc.pyi +12 -0
  63. chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2.py +59 -35
  64. chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2.pyi +127 -1
  65. chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2_grpc.py +135 -0
  66. chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2_grpc.pyi +36 -0
  67. chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2.py +90 -0
  68. chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2.pyi +264 -0
  69. chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2_grpc.py +170 -0
  70. chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2_grpc.pyi +62 -0
  71. chalk/_gen/chalk/server/v1/datasets_pb2.py +36 -24
  72. chalk/_gen/chalk/server/v1/datasets_pb2.pyi +71 -2
  73. chalk/_gen/chalk/server/v1/datasets_pb2_grpc.py +45 -0
  74. chalk/_gen/chalk/server/v1/datasets_pb2_grpc.pyi +12 -0
  75. chalk/_gen/chalk/server/v1/deploy_pb2.py +9 -3
  76. chalk/_gen/chalk/server/v1/deploy_pb2.pyi +12 -0
  77. chalk/_gen/chalk/server/v1/deploy_pb2_grpc.py +45 -0
  78. chalk/_gen/chalk/server/v1/deploy_pb2_grpc.pyi +12 -0
  79. chalk/_gen/chalk/server/v1/deployment_pb2.py +20 -15
  80. chalk/_gen/chalk/server/v1/deployment_pb2.pyi +25 -0
  81. chalk/_gen/chalk/server/v1/environment_pb2.py +25 -15
  82. chalk/_gen/chalk/server/v1/environment_pb2.pyi +93 -1
  83. chalk/_gen/chalk/server/v1/eventbus_pb2.py +44 -0
  84. chalk/_gen/chalk/server/v1/eventbus_pb2.pyi +64 -0
  85. chalk/_gen/chalk/server/v1/eventbus_pb2_grpc.py +4 -0
  86. chalk/_gen/chalk/server/v1/eventbus_pb2_grpc.pyi +4 -0
  87. chalk/_gen/chalk/server/v1/files_pb2.py +65 -0
  88. chalk/_gen/chalk/server/v1/files_pb2.pyi +167 -0
  89. chalk/_gen/chalk/server/v1/files_pb2_grpc.py +4 -0
  90. chalk/_gen/chalk/server/v1/files_pb2_grpc.pyi +4 -0
  91. chalk/_gen/chalk/server/v1/graph_pb2.py +41 -3
  92. chalk/_gen/chalk/server/v1/graph_pb2.pyi +191 -0
  93. chalk/_gen/chalk/server/v1/graph_pb2_grpc.py +92 -0
  94. chalk/_gen/chalk/server/v1/graph_pb2_grpc.pyi +32 -0
  95. chalk/_gen/chalk/server/v1/incident_pb2.py +57 -0
  96. chalk/_gen/chalk/server/v1/incident_pb2.pyi +165 -0
  97. chalk/_gen/chalk/server/v1/incident_pb2_grpc.py +4 -0
  98. chalk/_gen/chalk/server/v1/incident_pb2_grpc.pyi +4 -0
  99. chalk/_gen/chalk/server/v1/indexing_job_pb2.py +44 -0
  100. chalk/_gen/chalk/server/v1/indexing_job_pb2.pyi +38 -0
  101. chalk/_gen/chalk/server/v1/indexing_job_pb2_grpc.py +78 -0
  102. chalk/_gen/chalk/server/v1/indexing_job_pb2_grpc.pyi +38 -0
  103. chalk/_gen/chalk/server/v1/integrations_pb2.py +11 -9
  104. chalk/_gen/chalk/server/v1/integrations_pb2.pyi +34 -2
  105. chalk/_gen/chalk/server/v1/kube_pb2.py +29 -19
  106. chalk/_gen/chalk/server/v1/kube_pb2.pyi +28 -0
  107. chalk/_gen/chalk/server/v1/kube_pb2_grpc.py +45 -0
  108. chalk/_gen/chalk/server/v1/kube_pb2_grpc.pyi +12 -0
  109. chalk/_gen/chalk/server/v1/log_pb2.py +21 -3
  110. chalk/_gen/chalk/server/v1/log_pb2.pyi +68 -0
  111. chalk/_gen/chalk/server/v1/log_pb2_grpc.py +90 -0
  112. chalk/_gen/chalk/server/v1/log_pb2_grpc.pyi +24 -0
  113. chalk/_gen/chalk/server/v1/metadataplanejobqueue_pb2.py +73 -0
  114. chalk/_gen/chalk/server/v1/metadataplanejobqueue_pb2.pyi +212 -0
  115. chalk/_gen/chalk/server/v1/metadataplanejobqueue_pb2_grpc.py +217 -0
  116. chalk/_gen/chalk/server/v1/metadataplanejobqueue_pb2_grpc.pyi +74 -0
  117. chalk/_gen/chalk/server/v1/model_registry_pb2.py +10 -10
  118. chalk/_gen/chalk/server/v1/model_registry_pb2.pyi +4 -1
  119. chalk/_gen/chalk/server/v1/monitoring_pb2.py +84 -75
  120. chalk/_gen/chalk/server/v1/monitoring_pb2.pyi +1 -0
  121. chalk/_gen/chalk/server/v1/monitoring_pb2_grpc.py +136 -0
  122. chalk/_gen/chalk/server/v1/monitoring_pb2_grpc.pyi +38 -0
  123. chalk/_gen/chalk/server/v1/offline_queries_pb2.py +32 -10
  124. chalk/_gen/chalk/server/v1/offline_queries_pb2.pyi +73 -0
  125. chalk/_gen/chalk/server/v1/offline_queries_pb2_grpc.py +90 -0
  126. chalk/_gen/chalk/server/v1/offline_queries_pb2_grpc.pyi +24 -0
  127. chalk/_gen/chalk/server/v1/plandebug_pb2.py +53 -0
  128. chalk/_gen/chalk/server/v1/plandebug_pb2.pyi +86 -0
  129. chalk/_gen/chalk/server/v1/plandebug_pb2_grpc.py +168 -0
  130. chalk/_gen/chalk/server/v1/plandebug_pb2_grpc.pyi +60 -0
  131. chalk/_gen/chalk/server/v1/queries_pb2.py +76 -48
  132. chalk/_gen/chalk/server/v1/queries_pb2.pyi +155 -2
  133. chalk/_gen/chalk/server/v1/queries_pb2_grpc.py +180 -0
  134. chalk/_gen/chalk/server/v1/queries_pb2_grpc.pyi +48 -0
  135. chalk/_gen/chalk/server/v1/scheduled_query_pb2.py +4 -2
  136. chalk/_gen/chalk/server/v1/scheduled_query_pb2_grpc.py +45 -0
  137. chalk/_gen/chalk/server/v1/scheduled_query_pb2_grpc.pyi +12 -0
  138. chalk/_gen/chalk/server/v1/scheduled_query_run_pb2.py +12 -6
  139. chalk/_gen/chalk/server/v1/scheduled_query_run_pb2.pyi +75 -2
  140. chalk/_gen/chalk/server/v1/scheduler_pb2.py +24 -12
  141. chalk/_gen/chalk/server/v1/scheduler_pb2.pyi +61 -1
  142. chalk/_gen/chalk/server/v1/scheduler_pb2_grpc.py +90 -0
  143. chalk/_gen/chalk/server/v1/scheduler_pb2_grpc.pyi +24 -0
  144. chalk/_gen/chalk/server/v1/script_tasks_pb2.py +26 -14
  145. chalk/_gen/chalk/server/v1/script_tasks_pb2.pyi +33 -3
  146. chalk/_gen/chalk/server/v1/script_tasks_pb2_grpc.py +90 -0
  147. chalk/_gen/chalk/server/v1/script_tasks_pb2_grpc.pyi +24 -0
  148. chalk/_gen/chalk/server/v1/sql_interface_pb2.py +75 -0
  149. chalk/_gen/chalk/server/v1/sql_interface_pb2.pyi +142 -0
  150. chalk/_gen/chalk/server/v1/sql_interface_pb2_grpc.py +349 -0
  151. chalk/_gen/chalk/server/v1/sql_interface_pb2_grpc.pyi +114 -0
  152. chalk/_gen/chalk/server/v1/sql_queries_pb2.py +48 -0
  153. chalk/_gen/chalk/server/v1/sql_queries_pb2.pyi +150 -0
  154. chalk/_gen/chalk/server/v1/sql_queries_pb2_grpc.py +123 -0
  155. chalk/_gen/chalk/server/v1/sql_queries_pb2_grpc.pyi +52 -0
  156. chalk/_gen/chalk/server/v1/team_pb2.py +156 -137
  157. chalk/_gen/chalk/server/v1/team_pb2.pyi +56 -10
  158. chalk/_gen/chalk/server/v1/team_pb2_grpc.py +90 -0
  159. chalk/_gen/chalk/server/v1/team_pb2_grpc.pyi +24 -0
  160. chalk/_gen/chalk/server/v1/topic_pb2.py +5 -3
  161. chalk/_gen/chalk/server/v1/topic_pb2.pyi +10 -1
  162. chalk/_gen/chalk/server/v1/trace_pb2.py +50 -28
  163. chalk/_gen/chalk/server/v1/trace_pb2.pyi +121 -0
  164. chalk/_gen/chalk/server/v1/trace_pb2_grpc.py +135 -0
  165. chalk/_gen/chalk/server/v1/trace_pb2_grpc.pyi +42 -0
  166. chalk/_gen/chalk/server/v1/webhook_pb2.py +9 -3
  167. chalk/_gen/chalk/server/v1/webhook_pb2.pyi +18 -0
  168. chalk/_gen/chalk/server/v1/webhook_pb2_grpc.py +45 -0
  169. chalk/_gen/chalk/server/v1/webhook_pb2_grpc.pyi +12 -0
  170. chalk/_gen/chalk/streaming/v1/debug_service_pb2.py +62 -0
  171. chalk/_gen/chalk/streaming/v1/debug_service_pb2.pyi +75 -0
  172. chalk/_gen/chalk/streaming/v1/debug_service_pb2_grpc.py +221 -0
  173. chalk/_gen/chalk/streaming/v1/debug_service_pb2_grpc.pyi +88 -0
  174. chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2.py +19 -7
  175. chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2.pyi +96 -3
  176. chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2_grpc.py +48 -0
  177. chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2_grpc.pyi +20 -0
  178. chalk/_gen/chalk/utils/v1/field_change_pb2.py +32 -0
  179. chalk/_gen/chalk/utils/v1/field_change_pb2.pyi +42 -0
  180. chalk/_gen/chalk/utils/v1/field_change_pb2_grpc.py +4 -0
  181. chalk/_gen/chalk/utils/v1/field_change_pb2_grpc.pyi +4 -0
  182. chalk/_lsp/error_builder.py +11 -0
  183. chalk/_monitoring/Chart.py +1 -3
  184. chalk/_version.py +1 -1
  185. chalk/cli.py +5 -10
  186. chalk/client/client.py +178 -64
  187. chalk/client/client_async.py +154 -0
  188. chalk/client/client_async_impl.py +22 -0
  189. chalk/client/client_grpc.py +738 -112
  190. chalk/client/client_impl.py +541 -136
  191. chalk/client/dataset.py +27 -6
  192. chalk/client/models.py +99 -2
  193. chalk/client/serialization/model_serialization.py +126 -10
  194. chalk/config/project_config.py +1 -1
  195. chalk/df/LazyFramePlaceholder.py +1154 -0
  196. chalk/df/ast_parser.py +2 -10
  197. chalk/features/_class_property.py +7 -0
  198. chalk/features/_embedding/embedding.py +1 -0
  199. chalk/features/_embedding/sentence_transformer.py +1 -1
  200. chalk/features/_encoding/converter.py +83 -2
  201. chalk/features/_encoding/pyarrow.py +20 -4
  202. chalk/features/_encoding/rich.py +1 -3
  203. chalk/features/_tensor.py +1 -2
  204. chalk/features/dataframe/_filters.py +14 -5
  205. chalk/features/dataframe/_impl.py +91 -36
  206. chalk/features/dataframe/_validation.py +11 -7
  207. chalk/features/feature_field.py +40 -30
  208. chalk/features/feature_set.py +1 -2
  209. chalk/features/feature_set_decorator.py +1 -0
  210. chalk/features/feature_wrapper.py +42 -3
  211. chalk/features/hooks.py +81 -12
  212. chalk/features/inference.py +65 -10
  213. chalk/features/resolver.py +338 -56
  214. chalk/features/tag.py +1 -3
  215. chalk/features/underscore_features.py +2 -1
  216. chalk/functions/__init__.py +456 -21
  217. chalk/functions/holidays.py +1 -3
  218. chalk/gitignore/gitignore_parser.py +5 -1
  219. chalk/importer.py +186 -74
  220. chalk/ml/__init__.py +6 -2
  221. chalk/ml/model_hooks.py +368 -51
  222. chalk/ml/model_reference.py +68 -10
  223. chalk/ml/model_version.py +34 -21
  224. chalk/ml/utils.py +143 -40
  225. chalk/operators/_utils.py +14 -3
  226. chalk/parsed/_proto/export.py +22 -0
  227. chalk/parsed/duplicate_input_gql.py +4 -0
  228. chalk/parsed/expressions.py +1 -3
  229. chalk/parsed/json_conversions.py +21 -14
  230. chalk/parsed/to_proto.py +16 -4
  231. chalk/parsed/user_types_to_json.py +31 -10
  232. chalk/parsed/validation_from_registries.py +182 -0
  233. chalk/queries/named_query.py +16 -6
  234. chalk/queries/scheduled_query.py +13 -1
  235. chalk/serialization/parsed_annotation.py +25 -12
  236. chalk/sql/__init__.py +221 -0
  237. chalk/sql/_internal/integrations/athena.py +6 -1
  238. chalk/sql/_internal/integrations/bigquery.py +22 -2
  239. chalk/sql/_internal/integrations/databricks.py +61 -18
  240. chalk/sql/_internal/integrations/mssql.py +281 -0
  241. chalk/sql/_internal/integrations/postgres.py +11 -3
  242. chalk/sql/_internal/integrations/redshift.py +4 -0
  243. chalk/sql/_internal/integrations/snowflake.py +11 -2
  244. chalk/sql/_internal/integrations/util.py +2 -1
  245. chalk/sql/_internal/sql_file_resolver.py +55 -10
  246. chalk/sql/_internal/sql_source.py +36 -2
  247. chalk/streams/__init__.py +1 -3
  248. chalk/streams/_kafka_source.py +5 -1
  249. chalk/streams/_windows.py +16 -4
  250. chalk/streams/types.py +1 -2
  251. chalk/utils/__init__.py +1 -3
  252. chalk/utils/_otel_version.py +13 -0
  253. chalk/utils/async_helpers.py +14 -5
  254. chalk/utils/df_utils.py +2 -2
  255. chalk/utils/duration.py +1 -3
  256. chalk/utils/job_log_display.py +538 -0
  257. chalk/utils/missing_dependency.py +5 -4
  258. chalk/utils/notebook.py +255 -2
  259. chalk/utils/pl_helpers.py +190 -37
  260. chalk/utils/pydanticutil/pydantic_compat.py +1 -2
  261. chalk/utils/storage_client.py +246 -0
  262. chalk/utils/threading.py +1 -3
  263. chalk/utils/tracing.py +194 -86
  264. {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/METADATA +53 -21
  265. {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/RECORD +268 -198
  266. {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/WHEEL +0 -0
  267. {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/entry_points.txt +0 -0
  268. {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/top_level.txt +0 -0
chalk/ml/model_hooks.py CHANGED
@@ -1,68 +1,385 @@
1
- from typing import Any, Callable, Dict, Optional, Tuple
1
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Protocol, Tuple
2
2
 
3
- import pyarrow as pa
3
+ from chalk.ml.utils import ModelClass, ModelEncoding, ModelType
4
4
 
5
- from chalk.ml.utils import ModelEncoding, ModelType
5
+ if TYPE_CHECKING:
6
+ from chalk.features.resolver import ResourceHint
6
7
 
7
8
 
8
- def load_xgb_classifier(f: str):
9
- import xgboost # pyright: ignore[reportMissingImports]
9
class ModelInference(Protocol):
    """Structural interface implemented by every per-framework model backend.

    Concrete implementations supply ``load_model`` and ``predict``; the two
    conversion hooks (``prepare_input`` / ``extract_output``) ship with
    defaults that most backends inherit unchanged.
    """

    def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
        """Deserialize and return the model stored at ``path``."""
        ...

    def predict(self, model: Any, X: Any) -> Any:
        """Evaluate ``model`` on the prepared input ``X`` and return the raw result."""
        ...

    def prepare_input(self, feature_table: Any) -> Any:
        """Turn a PyArrow feature table into the backend's input format.

        The default simply materializes the table as a numpy array via its
        ``__array__()`` hook; backends with richer input formats (e.g. ONNX
        struct arrays) override this.
        """
        return feature_table.__array__()

    def extract_output(self, result: Any, output_feature_name: str) -> Any:
        """Pull a single named output out of a raw prediction result.

        The default passes ``result`` straight through (single-output case);
        backends that emit structured results (e.g. ONNX) override this.
        """
        return result
26
35
 
27
- torch.set_grad_enabled(False)
28
- model = torch.jit.load(f)
29
- model.input_to_tensor = lambda X: torch.from_numpy(X.__array__()).float()
30
- return model
31
36
 
37
class XGBoostClassifierInference(ModelInference):
    """Loader/predictor pair for XGBoost classification models."""

    def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
        """Rehydrate an ``XGBClassifier`` from the checkpoint at ``path``."""
        import xgboost  # pyright: ignore[reportMissingImports]

        classifier = xgboost.XGBClassifier()
        classifier.load_model(path)
        return classifier

    def predict(self, model: Any, X: Any) -> Any:
        """Delegate straight to the classifier's own ``predict``."""
        return model.predict(X)
57
49
 
58
50
 
59
- PREDICT_HOOKS: Dict[Tuple[ModelType, ModelEncoding, Optional[str]], Callable[[Any, pa.Table], Any]] = {
60
- (ModelType.PYTORCH, ModelEncoding.PICKLE, None): pytorch_predict,
61
- (ModelType.SKLEARN, ModelEncoding.PICKLE, None): lambda model, X: model.predict(X),
62
- (ModelType.TENSORFLOW, ModelEncoding.HDF5, None): lambda model, X: model.predict(X),
63
- (ModelType.TENSORFLOW, ModelEncoding.SAFETENSOR, None): lambda model, X: model.predict(X),
64
- (ModelType.XGBOOST, ModelEncoding.JSON, None): lambda model, X: model.predict(X),
65
- (ModelType.LIGHTGBM, ModelEncoding.TEXT, None): lambda model, X: model.predict(X),
66
- (ModelType.CATBOOST, ModelEncoding.CBM, None): lambda model, X: model.predict(X),
67
- (ModelType.ONNX, ModelEncoding.PROTOBUF, None): lambda model, X: model.run(None, {"input": X.astype("float32")})[0],
68
- }
51
class XGBoostRegressorInference(ModelInference):
    """Loader/predictor pair for XGBoost regression models."""

    def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
        """Rehydrate an ``XGBRegressor`` from the checkpoint at ``path``."""
        import xgboost  # pyright: ignore[reportMissingImports]

        regressor = xgboost.XGBRegressor()
        regressor.load_model(path)
        return regressor

    def predict(self, model: Any, X: Any) -> Any:
        """Delegate straight to the regressor's own ``predict``."""
        return model.predict(X)
63
+
64
+
65
class PyTorchInference(ModelInference):
    """Loader/predictor pair for TorchScript (``torch.jit``) models."""

    def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
        """Load a scripted model, optionally placing it on the GPU.

        Gradients are disabled globally since this path is inference-only.
        The returned module carries an ``input_to_tensor`` attribute that
        converts a numpy batch into a float tensor on the chosen device.
        """
        import torch  # pyright: ignore[reportMissingImports]

        torch.set_grad_enabled(False)
        scripted = torch.jit.load(path)

        # Only move to CUDA when the caller asked for GPU *and* one exists.
        if resource_hint == "gpu" and torch.cuda.is_available():
            cuda = torch.device("cuda")
            scripted = scripted.to(cuda)
            scripted.input_to_tensor = lambda X: torch.from_numpy(X).float().to(cuda)
        else:
            scripted.input_to_tensor = lambda X: torch.from_numpy(X).float()

        return scripted

    def predict(self, model: Any, X: Any) -> Any:
        """Run the module and return a squeezed float64 numpy result.

        A 0-dimensional result collapses to a plain Python scalar; anything
        else comes back as a numpy array.
        """
        raw = model(model.input_to_tensor(X))
        squeezed = raw.detach().cpu().numpy().astype("float64").squeeze()
        return squeezed.item() if squeezed.ndim == 0 else squeezed
96
+
97
+
98
class SklearnInference(ModelInference):
    """Loader/predictor pair for scikit-learn models (joblib pickles)."""

    def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
        """Unpickle the estimator at ``path`` with joblib."""
        import joblib  # pyright: ignore[reportMissingImports]

        return joblib.load(path)

    def predict(self, model: Any, X: Any) -> Any:
        """Delegate straight to the estimator's own ``predict``."""
        return model.predict(X)
108
+
109
+
110
class TensorFlowInference(ModelInference):
    """Loader/predictor pair for TensorFlow/Keras models."""

    def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
        """Load a Keras model from ``path`` via ``tensorflow.keras``."""
        import tensorflow  # pyright: ignore[reportMissingImports]

        return tensorflow.keras.models.load_model(path)

    def predict(self, model: Any, X: Any) -> Any:
        """Delegate straight to the model's own ``predict``."""
        return model.predict(X)
120
+
121
+
122
class LightGBMInference(ModelInference):
    """Loader/predictor pair for LightGBM boosters (text model files)."""

    def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
        """Construct a ``Booster`` directly from the model file at ``path``."""
        import lightgbm  # pyright: ignore[reportMissingImports]

        return lightgbm.Booster(model_file=path)

    def predict(self, model: Any, X: Any) -> Any:
        """Delegate straight to the booster's own ``predict``."""
        return model.predict(X)
132
+
133
+
134
class CatBoostInference(ModelInference):
    """Loader/predictor pair for CatBoost models (.cbm files)."""

    def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
        """Load a CatBoost model from ``path``."""
        import catboost  # pyright: ignore[reportMissingImports]

        return catboost.CatBoost().load_model(path)

    def predict(self, model: Any, X: Any) -> Any:
        """Delegate straight to the model's own ``predict``."""
        return model.predict(X)
144
+
145
+
146
class ONNXInference(ModelInference):
    """Model inference for ONNX models with struct input/output support.

    Inputs arrive as a PyArrow struct array (one field per feature); outputs
    are returned as a PyArrow struct array (one field per ONNX output).
    """

    def load_model(self, path: str, resource_hint: Optional["ResourceHint"] = None) -> Any:
        """Create an onnxruntime session for the model at ``path``."""
        import onnxruntime  # pyright: ignore[reportMissingImports]

        # Conditionally add CUDAExecutionProvider based on resource_hint
        providers = (
            ["CUDAExecutionProvider", "CPUExecutionProvider"] if resource_hint == "gpu" else ["CPUExecutionProvider"]
        )
        return onnxruntime.InferenceSession(path, providers=providers)

    def prepare_input(self, feature_table: Any) -> Any:
        """Convert PyArrow table to struct array for ONNX models."""
        import pyarrow as pa

        # Get arrays for each column, combining chunks if necessary
        arrays = []
        for i in range(feature_table.num_columns):
            col = feature_table.column(i)
            if isinstance(col, pa.ChunkedArray):
                arrays.append(col.combine_chunks())
            else:
                arrays.append(col)

        # Create fields from schema, preserving original field names
        # Field names should match ONNX input names exactly
        fields = []
        for field in feature_table.schema:
            fields.append(pa.field(field.name, field.type))

        # Create struct array where each row is a struct with named fields
        return pa.StructArray.from_arrays(arrays, fields=fields)

    def extract_output(self, result: Any, output_feature_name: str) -> Any:
        """Extract single field from ONNX struct output.

        Non-struct results pass through unchanged. For struct results the
        field named ``output_feature_name`` is returned; if no field matches,
        the first field is used as a fallback.
        """
        import pyarrow as pa

        if not isinstance(result, (pa.StructArray, pa.ChunkedArray)):
            return result

        # NOTE(review): chunk(0) assumes a non-empty ChunkedArray, and
        # .field(...) on a ChunkedArray depends on pyarrow version — confirm
        # against the pyarrow version pinned by this package.
        struct_type = result.type if isinstance(result, pa.StructArray) else result.chunk(0).type

        # Find matching field by name, or use first field
        field_index = None
        for i, field in enumerate(struct_type):
            if field.name == output_feature_name:
                field_index = i
                break

        return result.field(field_index if field_index is not None else 0)

    def predict(self, model: Any, X: Any) -> Any:
        """Run ONNX inference with struct input/output.

        ``X`` is the struct array produced by :meth:`prepare_input`; the
        return value is a struct array with one field per ONNX output.
        """
        # Get ONNX model input/output names
        input_names = [inp.name for inp in model.get_inputs()]
        output_names = [out.name for out in model.get_outputs()]

        # Convert struct input to ONNX input dict
        input_dict = self._struct_to_inputs(X, input_names)

        # Run ONNX inference
        outputs = model.run(output_names, input_dict)

        # Always return outputs as struct array
        return self._outputs_to_struct(output_names, outputs)

    def _struct_to_inputs(self, struct_array: Any, input_names: list) -> dict:
        """Extract ONNX inputs from struct array by matching field names.

        Struct field names must match ONNX input names (supports list/Tensor types).
        If ONNX expects a single input but struct has multiple scalar fields,
        stack them into a 2D array. Raises ValueError when names cannot be
        reconciled by either special case.
        """
        import numpy as np
        import pyarrow as pa

        if isinstance(struct_array, pa.ChunkedArray):
            struct_array = struct_array.combine_chunks()

        input_dict = {}
        struct_fields = {field.name: i for i, field in enumerate(struct_array.type)}

        # Check if struct field names match ONNX input names
        fields_match = all(input_name in struct_fields for input_name in input_names)

        if not fields_match:
            # Special case 1: ONNX expects single input and struct has single field
            # Use that field regardless of name mismatch
            if len(input_names) == 1 and len(struct_fields) == 1:
                field_data = struct_array.field(0)
                input_dict[input_names[0]] = self._arrow_to_numpy(field_data)
                return input_dict

            # Special case 2: ONNX expects single input, but struct has multiple scalar fields
            # Stack them into a 2D array [batch_size, num_fields]
            if len(input_names) == 1 and len(struct_fields) > 1:
                # Check if all fields are scalar (not nested lists)
                all_scalar = all(
                    not pa.types.is_list(struct_array.type[i].type)
                    and not pa.types.is_large_list(struct_array.type[i].type)
                    for i in range(len(struct_array.type))
                )

                if all_scalar:
                    # Stack all fields into a single 2D array
                    columns = []
                    for i in range(len(struct_array.type)):
                        field_data = struct_array.field(i)
                        col_array = self._arrow_to_numpy(field_data)
                        columns.append(col_array)

                    # Stack columns horizontally to create [batch_size, num_features]
                    stacked = np.column_stack(columns)
                    input_dict[input_names[0]] = stacked
                    return input_dict

            raise ValueError(
                f"ONNX inputs {input_names} not found in struct fields {list(struct_fields.keys())}. "
                + "Struct field names must match ONNX input names."
            )

        # Direct mapping: struct fields match ONNX inputs (for Tensor/list types or named inputs)
        for input_name in input_names:
            field_data = struct_array.field(struct_fields[input_name])
            input_dict[input_name] = self._arrow_to_numpy(field_data)

        return input_dict

    def _arrow_to_numpy(self, arrow_array: Any) -> Any:
        """Convert Arrow array (including nested lists) to dense numpy array.

        Always produces float32, matching the common ONNX tensor dtype.
        NOTE(review): the to_pylist() round-trip copies the data — acceptable
        here, but confirm it is not on a latency-critical path.
        """
        import numpy as np
        import pyarrow as pa

        if isinstance(arrow_array, pa.ChunkedArray):
            arrow_array = arrow_array.combine_chunks()

        # Convert to Python list, then numpy - handles all cases (nested lists, flat arrays, etc.)
        return np.array(arrow_array.to_pylist(), dtype=np.float32)

    def _outputs_to_struct(self, output_names: list, outputs: list) -> Any:
        """Convert ONNX outputs to PyArrow struct array.

        One struct field is created per output name, in order. Raises
        ValueError when the model produced no outputs at all.
        """
        import pyarrow as pa

        if not outputs:
            raise ValueError("ONNX model returned no outputs")

        # Convert each output to Arrow array with proper type
        fields = []
        arrays = []

        for name, output_array in zip(output_names, outputs):
            arrow_array = self._numpy_to_arrow_array(output_array)
            fields.append(pa.field(name, arrow_array.type))
            arrays.append(arrow_array)

        return pa.StructArray.from_arrays(arrays, fields=fields)

    def _numpy_to_arrow_array(self, arr: Any) -> Any:
        """Convert numpy array to PyArrow array (possibly nested list)."""
        import pyarrow as pa

        # PyArrow can infer the correct nested list type from Python lists
        # Shape (batch, dim1, dim2, ...) -> list[list[...]]
        return pa.array(arr.tolist())
311
+
312
+
313
class ModelInferenceRegistry:
    """Maps (model type, encoding, model class) triples to inference backends."""

    def __init__(self):
        super().__init__()
        self._registry: Dict[Tuple[ModelType, ModelEncoding, Optional[ModelClass]], ModelInference] = {}

    def register(
        self,
        model_type: ModelType,
        encoding: ModelEncoding,
        model_class: Optional[ModelClass],
        inference: ModelInference,
    ) -> None:
        """Associate ``inference`` with one exact (type, encoding, class) key."""
        key = (model_type, encoding, model_class)
        self._registry[key] = inference

    def register_for_all_classes(
        self,
        model_type: ModelType,
        encoding: ModelEncoding,
        inference: ModelInference,
    ) -> None:
        """Bind one backend to the None, CLASSIFICATION, and REGRESSION variants alike."""
        for variant in (None, ModelClass.CLASSIFICATION, ModelClass.REGRESSION):
            self.register(model_type, encoding, variant, inference)

    def get(
        self,
        model_type: ModelType,
        encoding: ModelEncoding,
        model_class: Optional[ModelClass] = None,
    ) -> Optional[ModelInference]:
        """Look up a backend, or None when the combination is unregistered."""
        return self._registry.get((model_type, encoding, model_class))

    def get_loader(
        self,
        model_type: ModelType,
        encoding: ModelEncoding,
        model_class: Optional[ModelClass] = None,
    ):
        """Return the matching backend's ``load_model``, or None if absent."""
        backend = self.get(model_type, encoding, model_class)
        return backend.load_model if backend else None

    def get_predictor(
        self,
        model_type: ModelType,
        encoding: ModelEncoding,
        model_class: Optional[ModelClass] = None,
    ):
        """Return the matching backend's ``predict``, or None if absent."""
        backend = self.get(model_type, encoding, model_class)
        return backend.predict if backend else None
369
+
370
+
371
# Single module-wide registry mapping model specs to inference backends.
MODEL_REGISTRY = ModelInferenceRegistry()

# Backends whose behavior does not depend on the model class: bind each one
# to the None / CLASSIFICATION / REGRESSION variants in a single pass.
for _spec in (
    (ModelType.PYTORCH, ModelEncoding.PICKLE, PyTorchInference()),
    (ModelType.SKLEARN, ModelEncoding.PICKLE, SklearnInference()),
    (ModelType.TENSORFLOW, ModelEncoding.HDF5, TensorFlowInference()),
    (ModelType.LIGHTGBM, ModelEncoding.TEXT, LightGBMInference()),
    (ModelType.CATBOOST, ModelEncoding.CBM, CatBoostInference()),
    (ModelType.ONNX, ModelEncoding.PROTOBUF, ONNXInference()),
):
    MODEL_REGISTRY.register_for_all_classes(*_spec)
del _spec

# XGBoost is the exception: classifiers and regressors load through different
# xgboost classes, so each model class gets its own implementation (an
# unspecified class defaults to the regressor).
MODEL_REGISTRY.register(ModelType.XGBOOST, ModelEncoding.JSON, None, XGBoostRegressorInference())
MODEL_REGISTRY.register(ModelType.XGBOOST, ModelEncoding.JSON, ModelClass.CLASSIFICATION, XGBoostClassifierInference())
MODEL_REGISTRY.register(ModelType.XGBOOST, ModelEncoding.JSON, ModelClass.REGRESSION, XGBoostRegressorInference())
@@ -3,12 +3,22 @@ from __future__ import annotations
3
3
  import inspect
4
4
  import os
5
5
  from datetime import datetime
6
+ from typing import TYPE_CHECKING
6
7
 
7
8
  from chalk.ml.model_version import ModelVersion
8
- from chalk.ml.utils import REGISTRY_METADATA_FILE, get_model_spec, model_encoding_from_proto, model_type_from_proto
9
+ from chalk.ml.utils import (
10
+ ModelClass,
11
+ get_model_spec,
12
+ get_registry_metadata_file,
13
+ model_encoding_from_proto,
14
+ model_type_from_proto,
15
+ )
9
16
  from chalk.utils.object_inspect import get_source_object_starting
10
17
  from chalk.utils.source_parsing import should_skip_source_code_parsing
11
18
 
19
+ if TYPE_CHECKING:
20
+ from chalk.features.resolver import ResourceHint
21
+
12
22
 
13
23
  class ModelReference:
14
24
  def __init__(
@@ -18,6 +28,8 @@ class ModelReference:
18
28
  version: int | None = None,
19
29
  alias: str | None = None,
20
30
  as_of_date: datetime | None = None,
31
+ resource_hint: "ResourceHint | None" = None,
32
+ resource_group: str | None = None,
21
33
  ):
22
34
  """Specifies the model version that should be loaded into the deployment.
23
35
 
@@ -68,6 +80,8 @@ class ModelReference:
68
80
  self.as_of_date = as_of_date
69
81
  self.alias = alias
70
82
  self.identifier = identifier
83
+ self.resource_hint = resource_hint
84
+ self.resource_group = resource_group
71
85
 
72
86
  self.filename = filename
73
87
  self.source_line_start = source_line_start
@@ -89,7 +103,8 @@ class ModelReference:
89
103
  MODEL_REFERENCE_REGISTRY[(name, identifier)] = self
90
104
 
91
105
  # Only load model if the metadata file exists, which only happens in deployed environments
92
- if REGISTRY_METADATA_FILE is not None and os.path.exists(REGISTRY_METADATA_FILE):
106
+ registry_metadata_file = get_registry_metadata_file()
107
+ if registry_metadata_file is not None and os.path.exists(registry_metadata_file):
93
108
  model_artifact_metadata = get_model_spec(model_name=name, identifier=identifier)
94
109
 
95
110
  mv = ModelVersion(
@@ -100,6 +115,11 @@ class ModelReference:
100
115
  identifier=identifier,
101
116
  model_type=model_type_from_proto(model_artifact_metadata.spec.model_type),
102
117
  model_encoding=model_encoding_from_proto(model_artifact_metadata.spec.model_encoding),
118
+ model_class=ModelClass(model_artifact_metadata.spec.model_class)
119
+ if model_artifact_metadata.spec.model_class
120
+ else None,
121
+ resource_hint=resource_hint,
122
+ resource_group=resource_group,
103
123
  )
104
124
 
105
125
  from chalk.features.hooks import before_all
@@ -107,14 +127,22 @@ class ModelReference:
107
127
  def hook():
108
128
  mv.load_model()
109
129
 
110
- before_all(hook)
130
+ before_all(hook, resource_hint=resource_hint, resource_group=resource_group)
111
131
 
112
132
  self.model_version = mv
113
133
  else:
114
- self.model_version = ModelVersion(name=name, identifier=identifier)
134
+ self.model_version = ModelVersion(
135
+ name=name, identifier=identifier, resource_hint=resource_hint, resource_group=resource_group
136
+ )
115
137
 
116
138
  @classmethod
117
- def as_of(cls, name: str, when: datetime) -> ModelVersion:
139
+ def as_of(
140
+ cls,
141
+ name: str,
142
+ when: datetime,
143
+ resource_hint: "ResourceHint | None" = None,
144
+ resource_group: str | None = None,
145
+ ) -> ModelVersion:
118
146
  """Creates a ModelReference for a specific point in time.
119
147
 
120
148
  Parameters
@@ -123,6 +151,11 @@ class ModelReference:
123
151
  The name of the model.
124
152
  when
125
153
  The datetime to use for creating the model version identifier.
154
+ resource_hint
155
+ Whether this model loading is bound by CPU, I/O, or GPU.
156
+ resource_group
157
+ The resource group for the model: this is used to isolate execution
158
+ onto a separate pod (or set of nodes), such as on a GPU-enabled node.
126
159
 
127
160
  Returns
128
161
  -------
@@ -134,13 +167,20 @@ class ModelReference:
134
167
  >>> import datetime
135
168
  >>> timestamp = datetime.datetime(2023, 10, 15, 14, 30, 0)
136
169
  >>> model = ModelReference.as_of("fraud_model", timestamp)
170
+ >>> model = ModelReference.as_of("fraud_model", timestamp, resource_hint="gpu", resource_group="gpu-group")
137
171
  """
138
172
 
139
- mr = ModelReference(name=name, as_of_date=when)
173
+ mr = ModelReference(name=name, as_of_date=when, resource_hint=resource_hint, resource_group=resource_group)
140
174
  return mr.model_version
141
175
 
142
176
  @classmethod
143
- def from_version(cls, name: str, version: int) -> ModelVersion:
177
+ def from_version(
178
+ cls,
179
+ name: str,
180
+ version: int,
181
+ resource_hint: "ResourceHint | None" = None,
182
+ resource_group: str | None = None,
183
+ ) -> ModelVersion:
144
184
  """Creates a ModelReference using a numeric version identifier.
145
185
 
146
186
  Parameters
@@ -149,6 +189,11 @@ class ModelReference:
149
189
  The name of the model.
150
190
  version
151
191
  The version number. Must be a non-negative integer.
192
+ resource_hint
193
+ Whether this model loading is bound by CPU, I/O, or GPU.
194
+ resource_group
195
+ The resource group for the model: this is used to isolate execution
196
+ onto a separate pod (or set of nodes), such as on a GPU-enabled node.
152
197
 
153
198
  Returns
154
199
  -------
@@ -163,15 +208,22 @@ class ModelReference:
163
208
  Examples
164
209
  --------
165
210
  >>> model = ModelReference.from_version("fraud_model", 1)
211
+ >>> model = ModelReference.from_version("fraud_model", 1, resource_hint="gpu", resource_group="gpu-group")
166
212
  """
167
213
  if version < 0:
168
214
  raise ValueError("Version number must be a non-negative integer.")
169
215
 
170
- mr = ModelReference(name=name, version=version)
216
+ mr = ModelReference(name=name, version=version, resource_hint=resource_hint, resource_group=resource_group)
171
217
  return mr.model_version
172
218
 
173
219
  @classmethod
174
- def from_alias(cls, name: str, alias: str) -> ModelVersion:
220
+ def from_alias(
221
+ cls,
222
+ name: str,
223
+ alias: str,
224
+ resource_hint: "ResourceHint | None" = None,
225
+ resource_group: str | None = None,
226
+ ) -> ModelVersion:
175
227
  """Creates a ModelReference using an alias identifier.
176
228
 
177
229
  Parameters
@@ -180,6 +232,11 @@ class ModelReference:
180
232
  The name of the model.
181
233
  alias
182
234
  The alias string. Must be non-empty.
235
+ resource_hint
236
+ Whether this model loading is bound by CPU, I/O, or GPU.
237
+ resource_group
238
+ The resource group for the model: this is used to isolate execution
239
+ onto a separate pod (or set of nodes), such as on a GPU-enabled node.
183
240
 
184
241
  Returns
185
242
  -------
@@ -194,11 +251,12 @@ class ModelReference:
194
251
  Examples
195
252
  --------
196
253
  >>> model = ModelReference.from_alias("fraud_model", "latest")
254
+ >>> model = ModelReference.from_alias("fraud_model", "latest", resource_hint="gpu", resource_group="gpu-group")
197
255
  """
198
256
  if not alias:
199
257
  raise ValueError("Alias must be a non-empty string.")
200
258
 
201
- mr = ModelReference(name=name, alias=alias)
259
+ mr = ModelReference(name=name, alias=alias, resource_hint=resource_hint, resource_group=resource_group)
202
260
  return mr.model_version
203
261
 
204
262