kumoai 2.13.0.dev202511131731__cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kumoai might be problematic.

Files changed (98)
  1. kumoai/__init__.py +294 -0
  2. kumoai/_logging.py +29 -0
  3. kumoai/_singleton.py +25 -0
  4. kumoai/_version.py +1 -0
  5. kumoai/artifact_export/__init__.py +9 -0
  6. kumoai/artifact_export/config.py +209 -0
  7. kumoai/artifact_export/job.py +108 -0
  8. kumoai/client/__init__.py +5 -0
  9. kumoai/client/client.py +221 -0
  10. kumoai/client/connector.py +110 -0
  11. kumoai/client/endpoints.py +150 -0
  12. kumoai/client/graph.py +120 -0
  13. kumoai/client/jobs.py +447 -0
  14. kumoai/client/online.py +78 -0
  15. kumoai/client/pquery.py +203 -0
  16. kumoai/client/rfm.py +112 -0
  17. kumoai/client/source_table.py +53 -0
  18. kumoai/client/table.py +101 -0
  19. kumoai/client/utils.py +130 -0
  20. kumoai/codegen/__init__.py +19 -0
  21. kumoai/codegen/cli.py +100 -0
  22. kumoai/codegen/context.py +16 -0
  23. kumoai/codegen/edits.py +473 -0
  24. kumoai/codegen/exceptions.py +10 -0
  25. kumoai/codegen/generate.py +222 -0
  26. kumoai/codegen/handlers/__init__.py +4 -0
  27. kumoai/codegen/handlers/connector.py +118 -0
  28. kumoai/codegen/handlers/graph.py +71 -0
  29. kumoai/codegen/handlers/pquery.py +62 -0
  30. kumoai/codegen/handlers/table.py +109 -0
  31. kumoai/codegen/handlers/utils.py +42 -0
  32. kumoai/codegen/identity.py +114 -0
  33. kumoai/codegen/loader.py +93 -0
  34. kumoai/codegen/naming.py +94 -0
  35. kumoai/codegen/registry.py +121 -0
  36. kumoai/connector/__init__.py +31 -0
  37. kumoai/connector/base.py +153 -0
  38. kumoai/connector/bigquery_connector.py +200 -0
  39. kumoai/connector/databricks_connector.py +213 -0
  40. kumoai/connector/file_upload_connector.py +189 -0
  41. kumoai/connector/glue_connector.py +150 -0
  42. kumoai/connector/s3_connector.py +278 -0
  43. kumoai/connector/snowflake_connector.py +252 -0
  44. kumoai/connector/source_table.py +471 -0
  45. kumoai/connector/utils.py +1775 -0
  46. kumoai/databricks.py +14 -0
  47. kumoai/encoder/__init__.py +4 -0
  48. kumoai/exceptions.py +26 -0
  49. kumoai/experimental/__init__.py +0 -0
  50. kumoai/experimental/rfm/__init__.py +67 -0
  51. kumoai/experimental/rfm/authenticate.py +433 -0
  52. kumoai/experimental/rfm/infer/__init__.py +11 -0
  53. kumoai/experimental/rfm/infer/categorical.py +40 -0
  54. kumoai/experimental/rfm/infer/id.py +46 -0
  55. kumoai/experimental/rfm/infer/multicategorical.py +48 -0
  56. kumoai/experimental/rfm/infer/timestamp.py +41 -0
  57. kumoai/experimental/rfm/local_graph.py +810 -0
  58. kumoai/experimental/rfm/local_graph_sampler.py +184 -0
  59. kumoai/experimental/rfm/local_graph_store.py +359 -0
  60. kumoai/experimental/rfm/local_pquery_driver.py +689 -0
  61. kumoai/experimental/rfm/local_table.py +545 -0
  62. kumoai/experimental/rfm/pquery/__init__.py +7 -0
  63. kumoai/experimental/rfm/pquery/executor.py +102 -0
  64. kumoai/experimental/rfm/pquery/pandas_executor.py +532 -0
  65. kumoai/experimental/rfm/rfm.py +1130 -0
  66. kumoai/experimental/rfm/utils.py +344 -0
  67. kumoai/formatting.py +30 -0
  68. kumoai/futures.py +99 -0
  69. kumoai/graph/__init__.py +12 -0
  70. kumoai/graph/column.py +106 -0
  71. kumoai/graph/graph.py +948 -0
  72. kumoai/graph/table.py +838 -0
  73. kumoai/jobs.py +80 -0
  74. kumoai/kumolib.cpython-313-x86_64-linux-gnu.so +0 -0
  75. kumoai/mixin.py +28 -0
  76. kumoai/pquery/__init__.py +25 -0
  77. kumoai/pquery/prediction_table.py +287 -0
  78. kumoai/pquery/predictive_query.py +637 -0
  79. kumoai/pquery/training_table.py +424 -0
  80. kumoai/spcs.py +123 -0
  81. kumoai/testing/__init__.py +8 -0
  82. kumoai/testing/decorators.py +57 -0
  83. kumoai/trainer/__init__.py +42 -0
  84. kumoai/trainer/baseline_trainer.py +93 -0
  85. kumoai/trainer/config.py +2 -0
  86. kumoai/trainer/job.py +1192 -0
  87. kumoai/trainer/online_serving.py +258 -0
  88. kumoai/trainer/trainer.py +475 -0
  89. kumoai/trainer/util.py +103 -0
  90. kumoai/utils/__init__.py +10 -0
  91. kumoai/utils/datasets.py +83 -0
  92. kumoai/utils/forecasting.py +209 -0
  93. kumoai/utils/progress_logger.py +177 -0
  94. kumoai-2.13.0.dev202511131731.dist-info/METADATA +60 -0
  95. kumoai-2.13.0.dev202511131731.dist-info/RECORD +98 -0
  96. kumoai-2.13.0.dev202511131731.dist-info/WHEEL +6 -0
  97. kumoai-2.13.0.dev202511131731.dist-info/licenses/LICENSE +9 -0
  98. kumoai-2.13.0.dev202511131731.dist-info/top_level.txt +1 -0
kumoai/__init__.py ADDED
@@ -0,0 +1,294 @@
+import os
+import sys
+import threading
+from dataclasses import dataclass
+from typing import Any, Dict, Optional
+import logging
+
+from kumoapi.typing import Dtype, Stype
+
+from kumoai.client.client import KumoClient
+from kumoai._logging import initialize_logging, _ENV_KUMO_LOG
+from kumoai._singleton import Singleton
+from kumoai.futures import create_future, initialize_event_loop
+from kumoai.spcs import (
+    _get_active_session,
+    _get_spcs_token,
+    _run_refresh_spcs_token,
+)
+
+initialize_logging()
+initialize_event_loop()
+
+
+@dataclass
+class GlobalState(metaclass=Singleton):
+    r"""Global storage of the state needed to create a Kumo client object. A
+    singleton so its initialized state can be referenced elsewhere for free.
+    """
+
+    # NOTE fork semantics: CoW on Linux, and re-execed on Windows. So this
+    # will likely not work on Windows unless we have special handling for
+    # the shared state:
+    _url: Optional[str] = None
+    _api_key: Optional[str] = None
+    _snowflake_credentials: Optional[Dict[str, Any]] = None
+    _spcs_token: Optional[str] = None
+    _snowpark_session: Optional[Any] = None
+
+    thread_local: threading.local = threading.local()
+
+    def clear(self) -> None:
+        if hasattr(self.thread_local, '_client'):
+            del self.thread_local._client
+        self._url = None
+        self._api_key = None
+        self._snowflake_credentials = None
+        self._spcs_token = None
+
+    def set_spcs_token(self, spcs_token: str) -> None:
+        # Set the SPCS token in the global state. This will be picked up the
+        # next time `client` is accessed.
+        self._spcs_token = spcs_token
+
+    @property
+    def initialized(self) -> bool:
+        return self._url is not None and (
+            self._api_key is not None
+            or self._snowflake_credentials is not None
+            or self._snowpark_session is not None)
+
+    @property
+    def client(self) -> KumoClient:
+        r"""Accessor for the Kumo client. Note that clients are stored as
+        thread-local variables, as the `requests` Session is not guaranteed
+        to be thread-safe.
+
+        For more information, see https://github.com/psf/requests/issues/1871.
+        """
+        if self._url is None or (self._api_key is None
+                                 and self._spcs_token is None
+                                 and self._snowpark_session is None):
+            raise ValueError(
+                "Client creation or authentication failed; please re-create "
+                "your client before proceeding.")
+
+        if hasattr(self.thread_local, '_client'):
+            # Set the SPCS token in the client to ensure it has the latest:
+            self.thread_local._client.set_spcs_token(self._spcs_token)
+            return self.thread_local._client
+
+        client = KumoClient(self._url, self._api_key, self._spcs_token)
+        self.thread_local._client = client
+        return client
+
+    @property
+    def is_spcs(self) -> bool:
+        return (self._api_key is None
+                and (self._snowflake_credentials is not None
+                     or self._snowpark_session is not None))
+
+
+global_state: GlobalState = GlobalState()
+
+
+def init(
+    url: Optional[str] = None,
+    api_key: Optional[str] = None,
+    snowflake_credentials: Optional[Dict[str, str]] = None,
+    snowflake_application: Optional[str] = None,
+    log_level: str = "INFO",
+) -> None:
+    r"""Initializes and authenticates the API key against the Kumo service.
+    Successful authentication is required to use the SDK.
+
+    Example:
+        >>> import kumoai
+        >>> kumoai.init(url="<api_url>", api_key="<api_key>")  # doctest: +SKIP
+
+    Args:
+        url: The Kumo API endpoint. Can also be provided via the
+            ``KUMO_API_ENDPOINT`` environment variable. Will be inferred from
+            the provided API key, if not provided.
+        api_key: The Kumo API key. Can also be provided via the
+            ``KUMO_API_KEY`` environment variable.
+        snowflake_credentials: The Snowflake credentials to authenticate
+            against the Kumo service. The dictionary should contain the keys
+            ``"user"``, ``"password"``, and ``"account"``. This should only be
+            provided for SPCS.
+        snowflake_application: The Snowflake application.
+        log_level: The logging level that Kumo operates under. Defaults to
+            INFO; for more information, please see
+            :class:`~kumoai.set_log_level`. Can also be set with the
+            environment variable ``KUMO_LOG``.
+    """  # noqa
+    # Avoid mutations to the global state after it is set:
+    if global_state.initialized:
+        print(
+            "Client has already been created. To re-initialize Kumo, please "
+            "start a new interpreter. No changes will be made to the current "
+            "session.")
+        return
+
+    set_log_level(os.getenv(_ENV_KUMO_LOG, log_level))
+
+    # Get API key:
+    api_key = api_key or os.getenv("KUMO_API_KEY")
+
+    snowpark_session = None
+    if snowflake_application:
+        if url is not None:
+            raise ValueError(
+                "Client creation failed: both snowflake_application and url "
+                "are specified. If running from a snowflake notebook, "
+                "specify only snowflake_application.")
+        snowpark_session = _get_active_session()
+        if not snowpark_session:
+            raise ValueError(
+                "Client creation failed: snowflake_application is specified "
+                "without an active snowpark session. If running outside "
+                "a snowflake notebook, specify a URL and credentials.")
+        description = snowpark_session.sql(
+            f"DESCRIBE SERVICE {snowflake_application}."
+            "USER_SCHEMA.KUMO_SERVICE").collect()[0]
+        url = f"http://{description.dns_name}:8888/public_api"
+
+    if api_key is None and not snowflake_application:
+        if snowflake_credentials is None:
+            raise ValueError(
+                "Client creation failed: Neither API key nor snowflake "
+                "credentials provided. Please either set the 'KUMO_API_KEY' "
+                "environment variable or explicitly call "
+                "`kumoai.init(...)`.")
+        if (set(snowflake_credentials.keys())
+                != {'user', 'password', 'account'}):
+            raise ValueError(
+                f"Provided credentials should be a dictionary with keys "
+                f"'user', 'password', and 'account'. Only "
+                f"{set(snowflake_credentials.keys())} were provided.")
+
+    # Get or infer URL:
+    url = url or os.getenv("KUMO_API_ENDPOINT")
+    try:
+        if api_key:
+            url = url or f"http://{api_key.split(':')[0]}.kumoai.cloud/api"
+    except KeyError:
+        pass
+    if url is None:
+        raise ValueError(
+            "Client creation failed: endpoint URL not provided. Please "
+            "either set the 'KUMO_API_ENDPOINT' environment variable or "
+            "explicitly call `kumoai.init(...)`.")
+
+    # Assign global state after verification that client can be created and
+    # authenticated successfully:
+    spcs_token = (_get_spcs_token(snowflake_credentials)
+                  if not api_key and snowflake_credentials else None)
+    client = KumoClient(url=url, api_key=api_key, spcs_token=spcs_token)
+    if client.authenticate():
+        global_state._url = client._url
+        global_state._api_key = client._api_key
+        global_state._snowflake_credentials = snowflake_credentials
+        global_state._spcs_token = client._spcs_token
+        global_state._snowpark_session = snowpark_session
+    else:
+        raise ValueError("Client authentication failed. Please check if you "
+                         "have a valid API key.")
+
+    if not api_key and snowflake_credentials:
+        # Refresh token every 10 minutes (expires in 1 hour):
+        create_future(_run_refresh_spcs_token(minutes=10))
+
+    logger = logging.getLogger('kumoai')
+    log_level = logging.getLevelName(logger.getEffectiveLevel())
+
+    logger.info(
+        f"Successfully initialized the Kumo SDK (version {__version__}) "
+        f"against deployment {url}, with log level {log_level}.")
+
+
+def set_log_level(level: str) -> None:
+    r"""Sets the Kumo logging level, which defines the amount of output that
+    methods produce.
+
+    Example:
+        >>> import kumoai
+        >>> kumoai.set_log_level("INFO")  # doctest: +SKIP
+
+    Args:
+        level: The logging level. Can be one of (in order of lowest to
+            highest log output) :obj:`DEBUG`, :obj:`INFO`, :obj:`WARNING`,
+            :obj:`ERROR`, :obj:`FATAL`, :obj:`CRITICAL`.
+    """
+    # The logging library will ensure `level` is a valid string, and raise
+    # a warning if not:
+    logging.getLogger('kumoai').setLevel(level)
+
+
+# Try to initialize purely with environment variables:
+if ("pytest" not in sys.modules and "KUMO_API_KEY" in os.environ
+        and "KUMO_API_ENDPOINT" in os.environ):
+    init()
+
+import kumoai.connector  # noqa
+import kumoai.encoder  # noqa
+import kumoai.graph  # noqa
+import kumoai.pquery  # noqa
+import kumoai.trainer  # noqa
+import kumoai.utils  # noqa
+import kumoai.databricks  # noqa
+
+from kumoai.connector import (  # noqa
+    SourceTable, SourceTableFuture, S3Connector, SnowflakeConnector,
+    DatabricksConnector, BigQueryConnector, FileUploadConnector, GlueConnector)
+from kumoai.graph import Column, Edge, Graph, Table  # noqa
+from kumoai.pquery import (  # noqa
+    PredictionTableGenerationPlan, PredictiveQuery,
+    TrainingTableGenerationPlan, TrainingTable, TrainingTableJob,
+    PredictionTable, PredictionTableJob)
+from kumoai.trainer import (  # noqa
+    ModelPlan, Trainer, TrainingJobResult, TrainingJob,
+    BatchPredictionJobResult, BatchPredictionJob)
+from kumoai._version import __version__  # noqa
+
+__all__ = [
+    'Dtype',
+    'Stype',
+    'SourceTable',
+    'SourceTableFuture',
+    'S3Connector',
+    'SnowflakeConnector',
+    'DatabricksConnector',
+    'BigQueryConnector',
+    'FileUploadConnector',
+    'GlueConnector',
+    'Column',
+    'Table',
+    'Graph',
+    'Edge',
+    'PredictiveQuery',
+    'TrainingTable',
+    'TrainingTableJob',
+    'TrainingTableGenerationPlan',
+    'PredictionTable',
+    'PredictionTableJob',
+    'PredictionTableGenerationPlan',
+    'Trainer',
+    'TrainingJobResult',
+    'TrainingJob',
+    'BatchPredictionJobResult',
+    'BatchPredictionJob',
+    'ModelPlan',
+    '__version__',
+]
+
+
+def in_notebook() -> bool:
+    try:
+        from IPython import get_ipython
+        shell = get_ipython()
+        if 'google.colab' in str(shell.__class__):
+            return True
+        return shell.__class__.__name__ == 'ZMQInteractiveShell'
+    except Exception:
+        return False
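
Taken together, `init` and `set_log_level` form the SDK's entry point. A minimal usage sketch based on the docstrings above; the URL and API key are placeholders, and both may instead come from the `KUMO_API_ENDPOINT` and `KUMO_API_KEY` environment variables (in which case `init()` runs automatically at import time, per the module-level check above):

    import kumoai

    # Placeholder credentials; authentication happens inside init() and
    # raises ValueError on failure:
    kumoai.init(url="<api_url>", api_key="<api_key>")

    # Optional: raise verbosity for debugging (see set_log_level above):
    kumoai.set_log_level("DEBUG")
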
kumoai/_logging.py ADDED
@@ -0,0 +1,29 @@
+import logging
+import os
+
+_ENV_KUMO_LOG = "KUMO_LOG"
+
+
+def initialize_logging() -> None:
+    r"""Initializes Kumo logging."""
+    logger: logging.Logger = logging.getLogger('kumoai')
+
+    # From openai-python/blob/main/src/openai/_utils/_logs.py#L4
+    logging.basicConfig(
+        format=(
+            "[%(asctime)s - %(name)s:%(lineno)d - %(levelname)s] %(message)s"),
+        datefmt="%Y-%m-%d %H:%M:%S",
+    )
+
+    default_level = os.getenv(_ENV_KUMO_LOG, "INFO")
+    try:
+        logger.setLevel(default_level)
+    except (TypeError, ValueError):
+        logger.setLevel(logging.INFO)
+        logger.warning(
+            "Logging level %s could not be properly parsed. "
+            "Defaulting to INFO log level.", default_level)
+
+    for name in ["matplotlib", "urllib3", "snowflake"]:
+        # TODO(dm) required for spcs
+        logging.getLogger(name).setLevel(logging.ERROR)
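
Since `initialize_logging` reads `KUMO_LOG` once at import time, the default level can be set in the environment before the process starts, or overridden on the `kumoai` logger afterwards. A small sketch of both options (the script name is hypothetical):

    # Option 1: via the environment, before `import kumoai`
    # (e.g. `KUMO_LOG=DEBUG python run.py`):
    import os
    os.environ["KUMO_LOG"] = "DEBUG"  # must happen before the import below

    import kumoai  # initialize_logging() picks up KUMO_LOG here

    # Option 2: at runtime, on the 'kumoai' logger directly:
    import logging
    logging.getLogger("kumoai").setLevel(logging.WARNING)
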
kumoai/_singleton.py ADDED
@@ -0,0 +1,25 @@
+from abc import ABCMeta
+from typing import Any, Dict
+
+
+class Singleton(ABCMeta):
+    r"""A per-process singleton definition."""
+    _instances: Dict[type, Any] = {}
+
+    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
+        if cls not in cls._instances:
+            # Calls the `__init__` method of the subclass and returns a
+            # reference, which is stored to prevent multiple instantiations.
+            instance = super(Singleton, cls).__call__(*args, **kwargs)
+            cls._instances[cls] = instance
+            return instance
+        return cls._instances[cls]
+
+    def clear(cls) -> None:
+        r"""Clears the singleton class instance, so the next construction
+        will re-initialize the class.
+        """
+        try:
+            del Singleton._instances[cls]
+        except KeyError:
+            pass
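
A short sketch of the metaclass semantics, using a hypothetical `Config` class and the private module path shown above: the first call constructs and caches the instance, later calls return it, and `clear()` forces re-initialization.

    from kumoai._singleton import Singleton

    class Config(metaclass=Singleton):
        def __init__(self) -> None:
            self.value = 0

    a = Config()
    b = Config()
    assert a is b  # the second call returns the cached instance

    Config.clear()  # drop the cached instance...
    assert Config() is not a  # ...so the next construction re-initializes
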
kumoai/_version.py ADDED
@@ -0,0 +1 @@
+__version__ = '2.13.0.dev202511131731'
kumoai/artifact_export/__init__.py ADDED
@@ -0,0 +1,9 @@
+from .job import ArtifactExportJob, ArtifactExportResult
+from .config import OutputConfig, TrainingTableExportConfig
+
+__all__ = [
+    "ArtifactExportJob",
+    "ArtifactExportResult",
+    "OutputConfig",
+    "TrainingTableExportConfig",
+]
kumoai/artifact_export/config.py ADDED
@@ -0,0 +1,209 @@
+import functools
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Set, Tuple, Union
+
+from kumoapi.jobs import MetadataField, WriteMode
+from kumoapi.typing import WITH_PYDANTIC_V2
+from pydantic.dataclasses import dataclass as pydantic_dataclass
+
+if WITH_PYDANTIC_V2:
+    from pydantic import field_validator, model_validator  # type: ignore
+else:
+    from pydantic import root_validator, validator
+
+from kumoai.connector.base import Connector
+
+
+def compatible_field_validator(field_name: str):  # type: ignore
+    """Decorator factory that creates a field validator compatible with both
+    Pydantic v1 and v2.
+
+    Usage:
+        @compatible_field_validator('field_name')
+        def validate_field(cls, v, values_or_info):
+            # Your validation logic here
+            return v
+    """
+    def decorator(func):  # type: ignore
+        if WITH_PYDANTIC_V2:
+
+            @field_validator(field_name)
+            @classmethod
+            @functools.wraps(func)
+            def wrapper(cls, v, info):  # type: ignore
+                # Convert `info` to a values dict for compatibility:
+                values = info.data if hasattr(info, 'data') else {}
+                return func(cls, v, values)
+
+            return wrapper
+        else:
+
+            @validator(field_name)
+            @functools.wraps(func)
+            def wrapper(cls, v, values):  # type: ignore
+                return func(cls, v, values)
+
+            return wrapper
+
+    return decorator
+
+
+# TODO: probably will need to be removed b/c using __post_init__ instead
+def compatible_model_validator(mode='before'):  # type: ignore
+    """Decorator factory that creates a model validator compatible with both
+    Pydantic v1 and v2.
+
+    Usage:
+        @compatible_model_validator()
+        def validate_model(cls, values):
+            # Your validation logic here
+            return values
+    """
+    def decorator(func):  # type: ignore
+        if WITH_PYDANTIC_V2:
+
+            @model_validator(mode=mode)
+            @classmethod
+            @functools.wraps(func)
+            def wrapper(cls, values):  # type: ignore
+                return func(cls, values)
+
+            return wrapper
+        else:
+
+            @root_validator
+            @functools.wraps(func)
+            def wrapper(cls, values):  # type: ignore
+                return func(cls, values)
+
+            return wrapper
+
+    return decorator
+
+
+@dataclass(frozen=True)
+class QueryConnectorConfig:
+    # If using OVERWRITE, the BigQuery connector will first write to a
+    # staging table and then overwrite the destination table. When using
+    # APPEND, it is strongly recommended to use MetadataField.JOB_TIMESTAMP
+    # to indicate the timestamp of the job.
+    write_mode: WriteMode = WriteMode.APPEND
+
+
+@dataclass(frozen=True)
+class BigQueryOutputConfig(QueryConnectorConfig):
+    pass
+
+
+@dataclass(frozen=True)
+class SnowflakeConnectorConfig(QueryConnectorConfig):
+    pass
+
+
+CONNECTOR_CONFIG_MAPPING = {
+    'BigQueryConnector': BigQueryOutputConfig,
+    'SnowflakeConnector': SnowflakeConnectorConfig,
+    # 'DatabricksConnector': DatabricksOutputConfig,
+    # 'S3Connector': S3OutputConfig,
+}
+
+
+@pydantic_dataclass(frozen=True, config={'arbitrary_types_allowed': True})
+class OutputConfig:
+    """Output configuration associated with a Batch Prediction Job.
+    Specifies the output types and, optionally, the output data source
+    configuration.
+
+    Args:
+        output_types (`Set[str]`): The types of outputs that should be
+            produced by the prediction job. Can include ``'predictions'``,
+            ``'embeddings'``, or both.
+        output_connector (`Connector` or None): The output data source that
+            Kumo should write batch predictions to. If None, only a local
+            download output is produced.
+        output_table_name (`str` or `Tuple[str, str]` or None): The name of
+            the table in the output data source that Kumo should write batch
+            predictions to. In the case of a Databricks connector, this
+            should be a tuple of two strings: the schema name and the output
+            prediction table name.
+        output_metadata_fields (`List[MetadataField]` or None): Any
+            additional metadata fields to include as new columns in the
+            produced ``'predictions'`` output. Currently, allowed options
+            are ``JOB_TIMESTAMP`` and ``ANCHOR_TIMESTAMP``.
+        connector_specific_config (`QueryConnectorConfig` or None): The
+            custom connector-specific output config for predictions, for
+            example whether to append to or overwrite an existing table.
+    """
+    output_types: Set[str]
+    output_connector: Optional[Connector] = None
+    output_table_name: Optional[Union[str, Tuple]] = None
+    output_metadata_fields: Optional[List[MetadataField]] = None
+    connector_specific_config: Optional[Union[
+        BigQueryOutputConfig,
+        SnowflakeConnectorConfig,
+    ]] = None
+
+    @compatible_field_validator('connector_specific_config')
+    def validate_connector_config(cls, v: Any, values: Dict) -> Any:
+        """Validates the connector-specific output config. Raises ValueError
+        if there is a mismatch between the connector type and the config
+        type.
+        """
+        # Skip validation if no connector or no specific config:
+        if values.get('output_connector') is None or v is None:
+            return v
+
+        connector_type = type(values['output_connector']).__name__
+        expected_config_type = CONNECTOR_CONFIG_MAPPING.get(connector_type)
+
+        # If we don't have a mapping for this connector type, it doesn't
+        # support specific configs yet:
+        if expected_config_type is None:
+            raise ValueError(
+                f"Connector type '{connector_type}' does not support "
+                f"specific output configurations")
+
+        # Check if the provided config is of the correct type:
+        if not isinstance(v, expected_config_type):
+            raise ValueError(
+                f"Connector type '{connector_type}' requires output "
+                f"config of type '{expected_config_type.__name__}', but "
+                f"got '{type(v).__name__}'")
+
+        return v
+
+
+@pydantic_dataclass(frozen=True, config={'arbitrary_types_allowed': True})
+class TrainingTableExportConfig(OutputConfig):
+    """Export configuration associated with a Training Table.
+
+    Args:
+        output_types (`Set[str]`): The artifact to export from the training
+            table job. Currently, only ``'training_table'`` is supported,
+            which exports the full training table to the output connector.
+        output_connector (`Connector`): The output data source that Kumo
+            should write training table artifacts to.
+        output_table_name (`str`): The name of the table in the output data
+            source that Kumo should write the training table to. In the case
+            of a Databricks connector, this should be a tuple of two strings:
+            the schema name and the output table name.
+        connector_specific_config (`QueryConnectorConfig` or None): Defines
+            the custom connector-specific output config, for example whether
+            to append to or overwrite an existing table. This is currently
+            only supported for BigQuery and Snowflake.
+    """
+    output_connector: Connector
+    output_table_name: str
+
+    def __post_init__(self) -> None:
+        if self.output_types != {'training_table'}:
+            raise ValueError("output_types must be set(['training_table'])"
+                             f" (got {self.output_types})")
+        if self.output_connector is None:
+            raise ValueError("output_connector is required")
+        if self.output_table_name is None:
+            raise ValueError("output_table_name is required")
+        if self.output_metadata_fields is not None:
+            raise ValueError(
+                "output_metadata_fields is not supported for training "
+                "table export")
kumoai/artifact_export/job.py ADDED
@@ -0,0 +1,108 @@
+import asyncio
+import concurrent
+import concurrent.futures
+import time
+
+from kumoapi.common import JobStatus
+from typing_extensions import override
+
+from kumoai import global_state
+from kumoai.futures import KumoProgressFuture, create_future
+
+
+class ArtifactExportResult:
+    r"""Represents a completed artifact export job."""
+    def __init__(self, job_id: str) -> None:
+        self.job_id = job_id
+
+    def tracking_url(self) -> str:
+        r"""Returns a tracking URL pointing to the UI display of this
+        prediction export job.
+        """
+        raise NotImplementedError
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(job_id={self.job_id})"
+
+
+class ArtifactExportJob(KumoProgressFuture[ArtifactExportResult]):
+    """Represents an in-progress artifact export job."""
+    def __init__(self, job_id: str) -> None:
+        self.job_id = job_id
+        self._fut: concurrent.futures.Future = create_future(
+            _poll_export(job_id))
+
+    @property
+    def id(self) -> str:
+        """The unique ID of this export job."""
+        return self.job_id
+
+    @override
+    def result(self) -> ArtifactExportResult:
+        return self._fut.result()
+
+    @override
+    def future(self) -> 'concurrent.futures.Future[ArtifactExportResult]':
+        return self._fut
+
+    @override
+    def _attach_internal(
+        self,
+        interval_s: float = 20.0,
+    ) -> ArtifactExportResult:
+        """Allows a user to attach to a running export job and view its
+        progress.
+
+        Args:
+            interval_s (float): Time interval (seconds) between polls; the
+                minimum value allowed is 4 seconds.
+        """
+        assert interval_s >= 4.0
+        print(f"Attaching to export job {self.job_id}. To detach from "
+              f"this job, please enter Ctrl+C (the job will continue to "
+              f"run, and you can re-attach anytime).")
+
+        # TODO: improve print statements. Will require changes to `status`
+        # to return JobStatusReport instead of JobStatus.
+        while not self.done():
+            status = self.status()
+            print(f"Export job {self.job_id} status: {status}")
+            time.sleep(interval_s)
+
+        return self.result()
+
+    def status(self) -> JobStatus:
+        """Returns the status of a running export job."""
+        return get_export_status(self.job_id)
+
+    def cancel(self) -> bool:
+        """Cancels a running export job.
+
+        Returns:
+            bool: True if the job was cancelled successfully.
+        """
+        api = global_state.client.artifact_export_api
+        status = api.cancel(self.job_id)
+        if status == JobStatus.CANCELLED:
+            return True
+        return False
+
+
+def get_export_status(job_id: str) -> JobStatus:
+    api = global_state.client.artifact_export_api
+    resource = api.get(job_id)
+    return resource
+
+
+async def _poll_export(job_id: str) -> ArtifactExportResult:
+    status = get_export_status(job_id)
+    while not status.is_terminal:
+        await asyncio.sleep(10)
+        status = get_export_status(job_id)
+
+    if status != JobStatus.DONE:
+        raise RuntimeError(f"Export job {job_id} failed "
+                           f"with job status {status}.")
+
+    return ArtifactExportResult(job_id=job_id)
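
A sketch of the job-handle surface, assuming `job_id` refers to an export launched elsewhere in the SDK:

    from kumoai.artifact_export import ArtifactExportJob

    job = ArtifactExportJob(job_id)  # wraps the async poll in a Future
    print(job.status())    # one-off JobStatus poll
    result = job.result()  # blocks until a terminal state is reached;
                           # raises RuntimeError if the status is not DONE
    print(result)          # ArtifactExportResult(job_id=...)
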
kumoai/client/__init__.py ADDED
@@ -0,0 +1,5 @@
+from .client import KumoClient
+
+__all__ = [
+    'KumoClient',
+]