kumoai 2.14.0.dev202601011731__cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kumoai might be problematic.

Files changed (122)
  1. kumoai/__init__.py +300 -0
  2. kumoai/_logging.py +29 -0
  3. kumoai/_singleton.py +25 -0
  4. kumoai/_version.py +1 -0
  5. kumoai/artifact_export/__init__.py +9 -0
  6. kumoai/artifact_export/config.py +209 -0
  7. kumoai/artifact_export/job.py +108 -0
  8. kumoai/client/__init__.py +5 -0
  9. kumoai/client/client.py +223 -0
  10. kumoai/client/connector.py +110 -0
  11. kumoai/client/endpoints.py +150 -0
  12. kumoai/client/graph.py +120 -0
  13. kumoai/client/jobs.py +471 -0
  14. kumoai/client/online.py +78 -0
  15. kumoai/client/pquery.py +207 -0
  16. kumoai/client/rfm.py +112 -0
  17. kumoai/client/source_table.py +53 -0
  18. kumoai/client/table.py +101 -0
  19. kumoai/client/utils.py +130 -0
  20. kumoai/codegen/__init__.py +19 -0
  21. kumoai/codegen/cli.py +100 -0
  22. kumoai/codegen/context.py +16 -0
  23. kumoai/codegen/edits.py +473 -0
  24. kumoai/codegen/exceptions.py +10 -0
  25. kumoai/codegen/generate.py +222 -0
  26. kumoai/codegen/handlers/__init__.py +4 -0
  27. kumoai/codegen/handlers/connector.py +118 -0
  28. kumoai/codegen/handlers/graph.py +71 -0
  29. kumoai/codegen/handlers/pquery.py +62 -0
  30. kumoai/codegen/handlers/table.py +109 -0
  31. kumoai/codegen/handlers/utils.py +42 -0
  32. kumoai/codegen/identity.py +114 -0
  33. kumoai/codegen/loader.py +93 -0
  34. kumoai/codegen/naming.py +94 -0
  35. kumoai/codegen/registry.py +121 -0
  36. kumoai/connector/__init__.py +31 -0
  37. kumoai/connector/base.py +153 -0
  38. kumoai/connector/bigquery_connector.py +200 -0
  39. kumoai/connector/databricks_connector.py +213 -0
  40. kumoai/connector/file_upload_connector.py +189 -0
  41. kumoai/connector/glue_connector.py +150 -0
  42. kumoai/connector/s3_connector.py +278 -0
  43. kumoai/connector/snowflake_connector.py +252 -0
  44. kumoai/connector/source_table.py +471 -0
  45. kumoai/connector/utils.py +1796 -0
  46. kumoai/databricks.py +14 -0
  47. kumoai/encoder/__init__.py +4 -0
  48. kumoai/exceptions.py +26 -0
  49. kumoai/experimental/__init__.py +0 -0
  50. kumoai/experimental/rfm/__init__.py +210 -0
  51. kumoai/experimental/rfm/authenticate.py +432 -0
  52. kumoai/experimental/rfm/backend/__init__.py +0 -0
  53. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  54. kumoai/experimental/rfm/backend/local/graph_store.py +297 -0
  55. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  56. kumoai/experimental/rfm/backend/local/table.py +113 -0
  57. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  58. kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
  59. kumoai/experimental/rfm/backend/snow/table.py +242 -0
  60. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  61. kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
  62. kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
  63. kumoai/experimental/rfm/base/__init__.py +30 -0
  64. kumoai/experimental/rfm/base/column.py +152 -0
  65. kumoai/experimental/rfm/base/expression.py +44 -0
  66. kumoai/experimental/rfm/base/sampler.py +761 -0
  67. kumoai/experimental/rfm/base/source.py +19 -0
  68. kumoai/experimental/rfm/base/sql_sampler.py +143 -0
  69. kumoai/experimental/rfm/base/table.py +736 -0
  70. kumoai/experimental/rfm/graph.py +1237 -0
  71. kumoai/experimental/rfm/infer/__init__.py +19 -0
  72. kumoai/experimental/rfm/infer/categorical.py +40 -0
  73. kumoai/experimental/rfm/infer/dtype.py +82 -0
  74. kumoai/experimental/rfm/infer/id.py +46 -0
  75. kumoai/experimental/rfm/infer/multicategorical.py +48 -0
  76. kumoai/experimental/rfm/infer/pkey.py +128 -0
  77. kumoai/experimental/rfm/infer/stype.py +35 -0
  78. kumoai/experimental/rfm/infer/time_col.py +61 -0
  79. kumoai/experimental/rfm/infer/timestamp.py +41 -0
  80. kumoai/experimental/rfm/pquery/__init__.py +7 -0
  81. kumoai/experimental/rfm/pquery/executor.py +102 -0
  82. kumoai/experimental/rfm/pquery/pandas_executor.py +530 -0
  83. kumoai/experimental/rfm/relbench.py +76 -0
  84. kumoai/experimental/rfm/rfm.py +1184 -0
  85. kumoai/experimental/rfm/sagemaker.py +138 -0
  86. kumoai/experimental/rfm/task_table.py +231 -0
  87. kumoai/formatting.py +30 -0
  88. kumoai/futures.py +99 -0
  89. kumoai/graph/__init__.py +12 -0
  90. kumoai/graph/column.py +106 -0
  91. kumoai/graph/graph.py +948 -0
  92. kumoai/graph/table.py +838 -0
  93. kumoai/jobs.py +80 -0
  94. kumoai/kumolib.cpython-310-x86_64-linux-gnu.so +0 -0
  95. kumoai/mixin.py +28 -0
  96. kumoai/pquery/__init__.py +25 -0
  97. kumoai/pquery/prediction_table.py +287 -0
  98. kumoai/pquery/predictive_query.py +641 -0
  99. kumoai/pquery/training_table.py +424 -0
  100. kumoai/spcs.py +121 -0
  101. kumoai/testing/__init__.py +8 -0
  102. kumoai/testing/decorators.py +57 -0
  103. kumoai/testing/snow.py +50 -0
  104. kumoai/trainer/__init__.py +42 -0
  105. kumoai/trainer/baseline_trainer.py +93 -0
  106. kumoai/trainer/config.py +2 -0
  107. kumoai/trainer/distilled_trainer.py +175 -0
  108. kumoai/trainer/job.py +1192 -0
  109. kumoai/trainer/online_serving.py +258 -0
  110. kumoai/trainer/trainer.py +475 -0
  111. kumoai/trainer/util.py +103 -0
  112. kumoai/utils/__init__.py +11 -0
  113. kumoai/utils/datasets.py +83 -0
  114. kumoai/utils/display.py +51 -0
  115. kumoai/utils/forecasting.py +209 -0
  116. kumoai/utils/progress_logger.py +343 -0
  117. kumoai/utils/sql.py +3 -0
  118. kumoai-2.14.0.dev202601011731.dist-info/METADATA +71 -0
  119. kumoai-2.14.0.dev202601011731.dist-info/RECORD +122 -0
  120. kumoai-2.14.0.dev202601011731.dist-info/WHEEL +6 -0
  121. kumoai-2.14.0.dev202601011731.dist-info/licenses/LICENSE +9 -0
  122. kumoai-2.14.0.dev202601011731.dist-info/top_level.txt +1 -0
kumoai/pquery/training_table.py ADDED
@@ -0,0 +1,424 @@
+from __future__ import annotations
+
+import asyncio
+import logging
+import os
+import time
+from concurrent.futures import Future
+from functools import cached_property
+from typing import List, Optional, Tuple, Union
+
+import pandas as pd
+from kumoapi.common import JobStatus
+from kumoapi.jobs import (
+    ArtifactExportRequest,
+    CustomTrainingTable,
+    GenerateTrainTableJobResource,
+    GenerateTrainTableRequest,
+    JobStatusReport,
+    SourceTableType,
+    TrainingTableOutputConfig,
+    TrainingTableSpec,
+    WriteMode,
+)
+from kumoapi.source_table import S3SourceTable
+from tqdm.auto import tqdm
+from typing_extensions import Self, override
+
+from kumoai import global_state
+from kumoai.artifact_export import (
+    ArtifactExportJob,
+    ArtifactExportResult,
+    TrainingTableExportConfig,
+)
+from kumoai.client.jobs import (
+    GenerateTrainTableJobAPI,
+    GenerateTrainTableJobID,
+)
+from kumoai.connector import S3Connector, SourceTable
+from kumoai.databricks import to_db_table_name
+from kumoai.formatting import pretty_print_error_details
+from kumoai.futures import KumoProgressFuture, create_future
+from kumoai.jobs import JobInterface
+
+logger = logging.getLogger(__name__)
+
+_DEFAULT_INTERVAL_S = 20
+
+
+class TrainingTable:
+    r"""A training table in the Kumo platform. A training table can be
+    initialized from the job ID of a completed training table generation
+    job.
+
+    .. code-block:: python
+
+        import kumoai
+
+        # Create a training table from a training table generation job. Note
+        # that the job ID passed here must be in a completed state:
+        training_table = kumoai.TrainingTable("gen-traintable-job-...")
+
+        # Read the training table as a Pandas DataFrame:
+        training_df = training_table.data_df()
+
+        # Get URLs to download the training table:
+        training_download_urls = training_table.data_urls()
+
+        # Add a weight column to the training table; see
+        # `kumo-sdk.examples.datasets.weighted_train_table.py` for a more
+        # detailed example.
+        # 1. Export the training table:
+        connector = kumoai.S3Connector("s3_path")
+        training_table.export(TrainingTableExportConfig(
+            output_types={'training_table'},
+            output_connector=connector,
+            output_table_name="<any_name>"))
+        # 2. Assume the weight column was added to the training table, and
+        # that it was saved to the same S3 path as "<mod_table>":
+        training_table.update(SourceTable("<mod_table>", connector),
+                              TrainingTableSpec(weight_col="weight"))
+
+    Args:
+        job_id: ID of the training table generation job which generated this
+            training table.
+    """
+    def __init__(self, job_id: GenerateTrainTableJobID):
+        self.job_id = job_id
+        status = _get_status(job_id).status
+        self._custom_train_table: Optional[CustomTrainingTable] = None
+        if status != JobStatus.DONE:
+            raise ValueError(
+                f"Job {job_id} is not yet complete (status: {status}). If you "
+                f"would like to create a future (waiting for training table "
+                f"generation success), please use `TrainingTableJob`.")
+
+    def data_urls(self) -> List[str]:
+        r"""Returns a list of URLs that can be used to view generated
+        training table data. The list will contain more than one element if
+        the table is partitioned; paths will be relative to the location of
+        the Kumo data plane.
+        """
+        api: GenerateTrainTableJobAPI = (
+            global_state.client.generate_train_table_job_api)
+        return api._get_table_data(self.job_id, presigned=True, raw_path=True)
+
+    def data_df(self) -> pd.DataFrame:
+        r"""Returns a :class:`~pandas.DataFrame` object representing the
+        generated training data.
+
+        .. warning::
+
+            This method will load the full training table into memory as a
+            :class:`~pandas.DataFrame` object. If you are working on a
+            machine with limited resources, please use
+            :meth:`~kumoai.pquery.TrainingTable.data_urls` instead to
+            download the data and perform analysis per-partition.
+        """
+        urls = self.data_urls()
+        if global_state.is_spcs:
+            from kumoai.spcs import _parquet_dataset_to_df
+
+            # TODO(dm): return type hint is wrong
+            return _parquet_dataset_to_df(urls)
+
+        try:
+            return pd.concat([pd.read_parquet(x) for x in urls])
+        except Exception as e:
+            raise ValueError(
+                f"Could not create a Pandas DataFrame object from data paths "
+                f"{urls}. Please construct the DataFrame manually.") from e
+
+    def __repr__(self) -> str:
+        return f'{self.__class__.__name__}(job_id={self.job_id})'
+
+    def _to_s3_api_source_table(self,
+                                source_table: SourceTable) -> S3SourceTable:
+        assert isinstance(source_table.connector, S3Connector)
+        source_type = source_table._to_api_source_table()
+        root_dir: str = source_table.connector.root_dir  # type: ignore
+        if root_dir[-1] != os.sep:
+            root_dir = root_dir + os.sep
+        return S3SourceTable(
+            s3_path=root_dir,
+            source_table_name=source_table.name,
+            file_type=source_type.file_type,
+        )
+
+    def export(
+        self,
+        output_config: TrainingTableExportConfig,
+        non_blocking: bool = True,
+    ) -> Union[ArtifactExportJob, ArtifactExportResult]:
+        r"""Exports the training table to the connector specified in the
+        output config. Use the exported table to add a weight column, then
+        use :meth:`update` to update the training table.
+
+        Args:
+            output_config: The output configuration to write the training
+                table.
+            non_blocking: If ``True``, the method will return a future
+                object `ArtifactExportJob` representing the export job. If
+                ``False``, the method will block until the export job is
+                complete and return `ArtifactExportResult`.
+        """
+        assert output_config.output_connector is not None
+        assert output_config.output_types == {'training_table'}
+        output_table_name = to_db_table_name(output_config.output_table_name)
+        assert output_table_name is not None
+        s3_path = None
+        connector_id = None
+        table_name = ""
+        write_mode = WriteMode.OVERWRITE
+
+        if isinstance(output_config.output_connector, S3Connector):
+            assert output_config.output_connector.root_dir is not None
+            s3_path = output_config.output_connector.root_dir
+            s3_path = os.path.join(s3_path, output_table_name)
+        else:
+            connector_id = output_config.output_connector.name
+            table_name = output_table_name
+            if output_config.connector_specific_config:
+                write_mode = output_config.connector_specific_config.write_mode
+
+        api = global_state.client.artifact_export_api
+        output_config = TrainingTableOutputConfig(
+            s3_path=s3_path,
+            connector_id=connector_id,
+            table_name=table_name,
+            write_mode=write_mode,
+        )
+
+        request = ArtifactExportRequest(job_id=self.job_id,
+                                        training_table_output=output_config)
+        job_id = api.create(request)
+        if non_blocking:
+            return ArtifactExportJob(job_id)
+        return ArtifactExportJob(job_id).attach()
+
+    def validate_custom_table(
+        self,
+        source_table_type: SourceTableType,
+        train_table_mod: TrainingTableSpec,
+    ) -> None:
+        r"""Validates the modified training table.
+
+        Args:
+            source_table_type: The source table to be used as the modified
+                training table.
+            train_table_mod: The modification specification.
+
+        Raises:
+            ValueError: If the modified training table is invalid.
+        """
+        api: GenerateTrainTableJobAPI = (
+            global_state.client.generate_train_table_job_api)
+        response = api.validate_custom_train_table(self.job_id,
+                                                   source_table_type,
+                                                   train_table_mod)
+        if not response.ok:
+            raise ValueError("Invalid weighted train table",
+                             response.error_message)
+
+    def update(
+        self,
+        source_table: SourceTable,
+        train_table_mod: TrainingTableSpec,
+        validate: bool = True,
+    ) -> Self:
+        r"""Sets the `source_table` as the modified training table.
+
+        .. note::
+            The only allowed modification is the addition of a weight
+            column; any other modification may lead to unintended errors
+            during training. Further, negative/NA weight values are not
+            supported.
+
+            The custom training table is ingested during `trainer.fit()`
+            and is used as the training table.
+
+        Args:
+            source_table: The source table to be used as the modified
+                training table.
+            train_table_mod: The modification specification.
+            validate: Whether to validate the modified training table. This
+                can be slow for large tables.
+        """
+        if isinstance(source_table.connector, S3Connector):
+            # Special handling for S3, as `source_table._to_api_source_table`
+            # concatenates the root directory and file name, but the backend
+            # expects these to be separate:
+            source_table_type = self._to_s3_api_source_table(source_table)
+        else:
+            source_table_type = source_table._to_api_source_table()
+        if validate:
+            self.validate_custom_table(source_table_type, train_table_mod)
+        self._custom_train_table = CustomTrainingTable(
+            source_table=source_table_type, table_mod_spec=train_table_mod,
+            validated=validate)
+        return self
+
+
+# Training Table Future #######################################################
+
+
+class TrainingTableJob(JobInterface[GenerateTrainTableJobID,
+                                   GenerateTrainTableRequest,
+                                   GenerateTrainTableJobResource],
+                       KumoProgressFuture[TrainingTable]):
+    r"""A representation of an ongoing training table generation job in the
+    Kumo platform.
+
+    .. code-block:: python
+
+        import kumoai
+
+        # See `PredictiveQuery` documentation:
+        pquery = kumoai.PredictiveQuery(...)
+
+        # If a training table is generated in non-blocking mode, the
+        # response will be of type `TrainingTableJob`:
+        training_table_job = pquery.generate_training_table(non_blocking=True)
+
+        # You can also construct a `TrainingTableJob` from a job ID, e.g.
+        # one that is present in the Kumo Jobs page:
+        training_table_job = kumoai.TrainingTableJob("trainingjob-...")
+
+        # Get the status of the job:
+        print(training_table_job.status())
+
+        # Attach to the job, and poll progress updates:
+        training_table_job.attach()
+
+        # Cancel the job:
+        training_table_job.cancel()
+
+        # Wait for the job to complete, and return a `TrainingTable`:
+        training_table_job.result()
+
+    Args:
+        job_id: ID of the training table generation job.
+    """
+    @override
+    @staticmethod
+    def _api() -> GenerateTrainTableJobAPI:
+        return global_state.client.generate_train_table_job_api
+
+    def __init__(
+        self,
+        job_id: GenerateTrainTableJobID,
+    ) -> None:
+        self.job_id = job_id
+
+    @cached_property
+    def _fut(self) -> Future[TrainingTable]:
+        return create_future(_poll(self.job_id))
+
+    @override
+    @property
+    def id(self) -> GenerateTrainTableJobID:
+        r"""The unique ID of this training table generation process."""
+        return self.job_id
+
+    @override
+    def result(self) -> TrainingTable:
+        return self._fut.result()
+
+    @override
+    def future(self) -> Future[TrainingTable]:
+        return self._fut
+
+    @override
+    def status(self) -> JobStatusReport:
+        r"""Returns the status of a running training table generation job."""
+        return _get_status(self.job_id)
+
+    @override
+    def _attach_internal(self, interval_s: float = 20.0) -> TrainingTable:
+        assert interval_s >= 4.0
+        print(f"Attaching to training table generation job {self.job_id}. "
+              f"Tracking this job in the Kumo UI is coming soon. To detach "
+              f"from this job, please enter Ctrl+C (the job will continue to "
+              f"run, and you can re-attach anytime).")
+
+        def _get_progress() -> Optional[Tuple[int, int]]:
+            progress = self._api().get_progress(self.job_id)
+            if len(progress) == 0:
+                return None
+            expected_iter = progress['num_expected_iterations']
+            completed_iter = progress['num_finished_iterations']
+            return (expected_iter, completed_iter)
+
+        # Print progress bar:
+        print("Training table generation is in progress. If your task is "
+              "temporal, progress per timeframe will be loaded shortly.")
+
+        # Wait for either timeframes to become available, or the job to
+        # finish:
+        progress = _get_progress()
+        while progress is None:
+            progress = _get_progress()
+            # Not a temporal task, and it's done:
+            if self.status().status.is_terminal:
+                return self.result()
+            time.sleep(interval_s)
+
+        # Wait for timeframes to become available:
+        progress = _get_progress()
+        assert progress is not None
+        total, prog = progress
+        pbar = tqdm(total=total, unit="timeframe",
+                    desc="Generating Training Table")
+        pbar.update(prog - pbar.n)
+        while not self.done():
+            progress = _get_progress()
+            assert progress is not None
+            total, prog = progress
+            pbar.reset(total)
+            pbar.update(prog)
+            time.sleep(interval_s)
+        pbar.update(pbar.total - pbar.n)
+        pbar.close()
+
+        # Future is done:
+        return self.result()
+
+    def cancel(self) -> None:
+        r"""Cancels a running training table generation job, and raises an
+        error if cancellation failed.
+        """
+        return self._api().cancel(self.job_id)
+
+    @override
+    def load_config(self) -> GenerateTrainTableRequest:
+        r"""Loads the full configuration for this training table generation
+        job.
+
+        Returns:
+            GenerateTrainTableRequest: Complete configuration including plan,
+            pquery_id, graph_snapshot_id, etc.
+        """
+        return self._api().get_config(self.job_id)
+
+
+def _get_status(job_id: str) -> JobStatusReport:
+    api = global_state.client.generate_train_table_job_api
+    resource: GenerateTrainTableJobResource = api.get(job_id)
+    return resource.job_status_report
+
+
+async def _poll(job_id: str) -> TrainingTable:
+    # TODO(manan): make asynchronous natively with aiohttp:
+    status = _get_status(job_id).status
+    while not status.is_terminal:
+        await asyncio.sleep(_DEFAULT_INTERVAL_S)
+        status = _get_status(job_id).status
+
+    if status != JobStatus.DONE:
+        api = global_state.client.generate_train_table_job_api
+        error_details = api.get_job_error(job_id)
+        error_str = pretty_print_error_details(error_details)
+        raise RuntimeError(
+            f"Training table generation for job {job_id} failed with "
+            f"job status {status}. Encountered the following error(s): "
+            f"{error_str}")
+
+    return TrainingTable(job_id)
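
The weight-column workflow that `export` and `update` implement above can be pieced together from the docstrings; a minimal sketch follows (the S3 prefix, table names, and the out-of-band weight computation are illustrative placeholders, not part of the package):

    import kumoai
    from kumoai.artifact_export import TrainingTableExportConfig
    from kumoai.connector import S3Connector, SourceTable
    from kumoapi.jobs import TrainingTableSpec

    # Assumes a completed training table generation job:
    table = kumoai.TrainingTable("gen-traintable-job-...")

    # 1. Export the generated table to S3 (placeholder bucket/prefix):
    connector = S3Connector("s3://my-bucket/train-tables/")
    table.export(TrainingTableExportConfig(
        output_types={'training_table'},
        output_connector=connector,
        output_table_name="train_table_raw",
    ), non_blocking=False)

    # 2. Out of band: add a non-negative `weight` column and write the
    # result back to the same S3 prefix as `train_table_mod` (placeholder).

    # 3. Register the modified table; it is validated here and ingested
    # later during `trainer.fit()`:
    table.update(
        SourceTable("train_table_mod", connector),
        TrainingTableSpec(weight_col="weight"),
    )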
kumoai/spcs.py ADDED
@@ -0,0 +1,121 @@
+import asyncio
+import os
+from functools import reduce
+from typing import TYPE_CHECKING, Dict, List, Optional
+
+if TYPE_CHECKING:
+    from snowflake.snowpark import DataFrame, Session
+
+
+def _get_spcs_token(snowflake_credentials: Dict[str, str]) -> str:
+    r"""Fetches a token to access a Kumo application deployed in Snowflake
+    Snowpark Container Services (SPCS). This token is valid for 1 hour, after
+    which it must be re-generated.
+    """
+    # Create a request to the ingress endpoint with authorization:
+    active_session = _get_active_session()
+    if active_session is not None:
+        ctx = active_session.connection
+    else:
+        user = snowflake_credentials["user"]
+        password = snowflake_credentials["password"]
+        account = snowflake_credentials["account"]
+        import snowflake.connector
+        ctx = snowflake.connector.connect(
+            user=user,
+            password=password,
+            account=account,
+            session_parameters={
+                'PYTHON_CONNECTOR_QUERY_RESULT_FORMAT': 'json'
+            },
+        )
+
+    # Obtain a session token:
+    token_data = ctx._rest._token_request('ISSUE')
+    token_extract = token_data['data']['sessionToken']
+    return f'"{token_extract}"'
+
+
+def _refresh_spcs_token() -> None:
+    r"""Refreshes the SPCS token in global state to avoid expiration."""
+    from kumoai import KumoClient, global_state
+    if (not global_state.initialized
+            or (not global_state._snowflake_credentials
+                and not global_state._snowpark_session)):
+        raise ValueError(
+            "Please initialize the Kumo application with Snowflake "
+            "credentials before attempting to refresh this token.")
+    spcs_token = _get_spcs_token(global_state._snowflake_credentials or {})
+
+    # Verify token validity:
+    assert global_state._url is not None
+    client = KumoClient(
+        url=global_state._url,
+        api_key=global_state._api_key,
+        spcs_token=spcs_token,
+    )
+    client.authenticate()
+
+    # Update state:
+    global_state.set_spcs_token(spcs_token)
+
+
+async def _run_refresh_spcs_token(minutes: int) -> None:
+    r"""Runs the SPCS token refresh loop every `minutes` minutes."""
+    while True:
+        await asyncio.sleep(minutes * 60)
+        _refresh_spcs_token()
+
+
+def _get_active_session() -> 'Optional[Session]':
+    try:
+        from snowflake.snowpark.context import get_active_session
+        return get_active_session()
+    except Exception:
+        return None
+
+
+def _get_session() -> 'Session':
+    import snowflake.snowpark as snowpark
+
+    from kumoai import global_state
+    params = global_state._snowflake_credentials
+    assert params is not None
+
+    database = os.getenv("SNOWFLAKE_DATABASE")
+    schema = os.getenv("SNOWFLAKE_SCHEMA")
+    if not database or not schema:
+        raise ValueError("Please set the SNOWFLAKE_DATABASE and "
+                         "SNOWFLAKE_SCHEMA environment variables.")
+    params['database'] = database
+    params['schema'] = schema
+    params['client_session_keep_alive'] = True
+
+    return snowpark.Session.builder.configs(params).create()
+
+
+def _remove_path(session: 'Session', stage_path: str, file_path: str) -> None:
+    stage_prefix = '.'.join(stage_path.split('.')[:2])
+    name_remove = '.'.join([stage_prefix, file_path])
+    session.sql(f"REMOVE {name_remove}").collect()
+
+
+def _parquet_to_df(path: str) -> 'DataFrame':
+    r"""Reads parquet from the given path and returns a Snowpark DataFrame."""
+    session = _get_session()
+    if not path.endswith(os.path.sep):
+        path += os.path.sep
+    file_list = session.sql(f"LIST {path}").collect()
+    for file_row in file_list:
+        # Remove any non-parquet files from the stage before reading:
+        if file_row.name.endswith('.parquet'):
+            continue
+        _remove_path(session, path, file_row.name)
+    df = session.read.parquet(path)
+    return df
+
+
+def _parquet_dataset_to_df(paths: List[str]) -> 'DataFrame':
+    r"""Reads parquet from the given paths and returns a Snowpark DataFrame."""
+    from snowflake.snowpark import DataFrame
+    df_list = [_parquet_to_df(url) for url in paths]
+    return reduce(DataFrame.union_all, df_list)
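
`_run_refresh_spcs_token` above is an async coroutine that never returns on its own; a hedged sketch of how it might be driven from an event loop (the 45-minute cadence is an assumption, chosen to stay under the 1-hour token lifetime noted in `_get_spcs_token`):

    import asyncio

    from kumoai.spcs import _run_refresh_spcs_token

    async def main() -> None:
        # Refresh every 45 minutes (assumed cadence; tokens expire after
        # 1 hour per `_get_spcs_token`):
        refresher = asyncio.create_task(_run_refresh_spcs_token(45))
        try:
            ...  # application work against the SPCS-deployed Kumo app
        finally:
            refresher.cancel()

    asyncio.run(main())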
kumoai/testing/__init__.py ADDED
@@ -0,0 +1,8 @@
+from .decorators import has_package, withPackage, is_full_test, onlyFullTest
+
+__all__ = [
+    'is_full_test',
+    'onlyFullTest',
+    'has_package',
+    'withPackage',
+]
kumoai/testing/decorators.py ADDED
@@ -0,0 +1,57 @@
+import importlib
+import os
+from typing import Callable
+
+import packaging
+from packaging.requirements import Requirement
+
+
+def is_full_test() -> bool:
+    r"""Whether to run the full but time-consuming test suite."""
+    return os.getenv('FULL_TEST', '0') == '1'
+
+
+def onlyFullTest(func: Callable) -> Callable:
+    r"""A decorator to specify that this function belongs to the full test
+    suite.
+    """
+    import pytest
+    return pytest.mark.skipif(
+        not is_full_test(),
+        reason="Fast test run",
+    )(func)
+
+
+def has_package(package: str) -> bool:
+    r"""Returns ``True`` in case ``package`` is installed."""
+    req = Requirement(package)
+    if importlib.util.find_spec(req.name) is None:  # type: ignore
+        return False
+
+    try:
+        module = importlib.import_module(req.name)
+        if not hasattr(module, '__version__'):
+            return True
+
+        version = packaging.version.Version(module.__version__).base_version
+        return version in req.specifier
+    except Exception:
+        return False
+
+
+def withPackage(*args: str) -> Callable:
+    r"""A decorator to skip tests if certain packages are not installed.
+    Also supports version specification.
+    """
+    na_packages = {package for package in args if not has_package(package)}
+
+    if len(na_packages) == 1:
+        reason = f"Package '{list(na_packages)[0]}' not found"
+    else:
+        reason = f"Packages {na_packages} not found"
+
+    def decorator(func: Callable) -> Callable:
+        import pytest
+        return pytest.mark.skipif(len(na_packages) > 0, reason=reason)(func)
+
+    return decorator
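
A short usage sketch for the two pytest decorators above; the test bodies and the `torch>=2.0` specifier are illustrative assumptions:

    from kumoai.testing import onlyFullTest, withPackage

    @withPackage('torch>=2.0')  # skipped unless torch>=2.0 is installed
    def test_torch_integration() -> None:
        import torch
        assert torch.ones(2).sum().item() == 2.0

    @onlyFullTest  # runs only when FULL_TEST=1 is set in the environment
    def test_expensive_end_to_end() -> None:
        ...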
kumoai/testing/snow.py ADDED
@@ -0,0 +1,50 @@
+import json
+import os
+
+from kumoai.experimental.rfm.backend.snow import Connection
+from kumoai.experimental.rfm.backend.snow import connect as _connect
+
+
+def connect(
+    region: str,
+    id: str,
+    account: str,
+    user: str,
+    warehouse: str,
+    database: str | None = None,
+    schema: str | None = None,
+) -> Connection:
+
+    kwargs = dict(password=os.getenv('SNOWFLAKE_PASSWORD'))
+    if kwargs['password'] is None:
+        # Fall back to a key pair stored in AWS Secrets Manager:
+        import boto3
+        from cryptography.hazmat.primitives import serialization
+
+        client = boto3.client(
+            service_name='secretsmanager',
+            region_name=region,
+        )
+        secret_id = (f'arn:aws:secretsmanager:{region}:{id}:secret:'
                     f'{account}.snowflakecomputing.com')
+        response = client.get_secret_value(SecretId=secret_id)['SecretString']
+        secret = json.loads(response)
+
+        private_key = serialization.load_pem_private_key(
+            secret['kumo_user_secretkey'].encode(),
+            password=None,
+        )
+        kwargs['private_key'] = private_key.private_bytes(
+            encoding=serialization.Encoding.DER,
+            format=serialization.PrivateFormat.PKCS8,
+            encryption_algorithm=serialization.NoEncryption(),
+        )
+
+    return _connect(
+        account=account,
+        user=user,
+        warehouse=warehouse,
+        database=database,
+        schema=schema,
+        session_parameters=dict(CLIENT_TELEMETRY_ENABLED=False),
+        **kwargs,
+    )
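
A hedged sketch of calling this test helper; the region, AWS account ID, and Snowflake identifiers are placeholders, and the password-less branch assumes the Secrets Manager entry described above exists:

    from kumoai.testing.snow import connect

    # Uses $SNOWFLAKE_PASSWORD if set; otherwise falls back to the key pair
    # stored in AWS Secrets Manager (all identifiers below are placeholders):
    conn = connect(
        region='us-west-2',
        id='123456789012',          # AWS account ID
        account='myorg-myaccount',  # Snowflake account
        user='KUMO_USER',
        warehouse='WH_XS',
        database='KUMO',
    )
    # Assumes `Connection` follows the DB-API interface:
    conn.cursor().execute('SELECT CURRENT_VERSION()')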