kumoai 2.14.0.dev202601051732__cp311-cp311-macosx_11_0_arm64.whl → 2.15.0.dev202601141731__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,7 +8,6 @@ from contextlib import contextmanager
8
8
  from dataclasses import dataclass, replace
9
9
  from typing import Any, Literal, overload
10
10
 
11
- import numpy as np
12
11
  import pandas as pd
13
12
  from kumoapi.model_plan import RunMode
14
13
  from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
@@ -28,7 +27,10 @@ from kumoapi.rfm import (
28
27
  )
29
28
  from kumoapi.task import TaskType
30
29
  from kumoapi.typing import AggregationType, Stype
30
+ from rich.console import Console
31
+ from rich.markdown import Markdown
31
32
 
33
+ from kumoai import in_notebook
32
34
  from kumoai.client.rfm import RFMAPI
33
35
  from kumoai.exceptions import HTTPException
34
36
  from kumoai.experimental.rfm import Graph, TaskTable
@@ -106,10 +108,20 @@ class Explanation:
106
108
  def __repr__(self) -> str:
107
109
  return str((self.prediction, self.summary))
108
110
 
111
+ def __str__(self) -> str:
112
+ console = Console(soft_wrap=True)
113
+ with console.capture() as cap:
114
+ console.print(display.to_rich_table(self.prediction))
115
+ console.print(Markdown(self.summary))
116
+ return cap.get()[:-1]
117
+
109
118
  def print(self) -> None:
110
119
  r"""Prints the explanation."""
111
- display.dataframe(self.prediction)
112
- display.message(self.summary)
120
+ if in_notebook():
121
+ display.dataframe(self.prediction)
122
+ display.message(self.summary)
123
+ else:
124
+ print(self)
113
125
 
114
126
  def _ipython_display_(self) -> None:
115
127
  self.print()
@@ -180,7 +192,7 @@ class KumoRFM:
180
192
  self._client: RFMAPI | None = None
181
193
 
182
194
  self._batch_size: int | Literal['max'] | None = None
183
- self.num_retries: int = 0
195
+ self._num_retries: int = 0
184
196
 
185
197
  @property
186
198
  def _api_client(self) -> RFMAPI:
@@ -194,6 +206,30 @@ class KumoRFM:
194
206
  def __repr__(self) -> str:
195
207
  return f'{self.__class__.__name__}()'
196
208
 
209
+ @contextmanager
210
+ def retry(
211
+ self,
212
+ num_retries: int = 1,
213
+ ) -> Generator[None, None, None]:
214
+ """Context manager to retry failed queries due to unexpected server
215
+ issues.
216
+
217
+ .. code-block:: python
218
+
219
+ with model.retry(num_retries=1):
220
+ df = model.predict(query, indices=...)
221
+
222
+ Args:
223
+ num_retries: The maximum number of retries.
224
+ """
225
+ if num_retries < 0:
226
+ raise ValueError(f"'num_retries' must be greater than or equal to "
227
+ f"zero (got {num_retries})")
228
+
229
+ self._num_retries = num_retries
230
+ yield
231
+ self._num_retries = 0
232
+
197
233
  @contextmanager
198
234
  def batch_mode(
199
235
  self,
@@ -217,15 +253,10 @@ class KumoRFM:
217
253
  raise ValueError(f"'batch_size' must be greater than zero "
218
254
  f"(got {batch_size})")
219
255
 
220
- if num_retries < 0:
221
- raise ValueError(f"'num_retries' must be greater than or equal to "
222
- f"zero (got {num_retries})")
223
-
224
256
  self._batch_size = batch_size
225
- self.num_retries = num_retries
226
- yield
257
+ with self.retry(self._num_retries or num_retries):
258
+ yield
227
259
  self._batch_size = None
228
- self.num_retries = 0
229
260
 
230
261
  @overload
231
262
  def predict(
@@ -265,6 +296,25 @@ class KumoRFM:
265
296
  ) -> Explanation:
266
297
  pass
267
298
 
299
+ @overload
300
+ def predict(
301
+ self,
302
+ query: str,
303
+ indices: list[str] | list[float] | list[int] | None = None,
304
+ *,
305
+ explain: bool | ExplainConfig | dict[str, Any] = False,
306
+ anchor_time: pd.Timestamp | Literal['entity'] | None = None,
307
+ context_anchor_time: pd.Timestamp | None = None,
308
+ run_mode: RunMode | str = RunMode.FAST,
309
+ num_neighbors: list[int] | None = None,
310
+ num_hops: int = 2,
311
+ max_pq_iterations: int = 10,
312
+ random_seed: int | None = _RANDOM_SEED,
313
+ verbose: bool | ProgressLogger = True,
314
+ use_prediction_time: bool = False,
315
+ ) -> pd.DataFrame | Explanation:
316
+ pass
317
+
268
318
  def predict(
269
319
  self,
270
320
  query: str,
@@ -288,8 +338,7 @@ class KumoRFM:
288
338
  indices: The entity primary keys to predict for. Will override the
289
339
  indices given as part of the predictive query. Predictions will
290
340
  be generated for all indices, independent of whether they
291
- fulfill entity filter constraints. To pre-filter entities, use
292
- :meth:`~KumoRFM.is_valid_entity`.
341
+ fulfill entity filter constraints.
293
342
  explain: Configuration for explainability.
294
343
  If set to ``True``, will additionally explain the prediction.
295
344
  Passing in an :class:`ExplainConfig` instance provides control
@@ -329,8 +378,11 @@ class KumoRFM:
329
378
  raise ValueError("Cannot find entities to predict for. Please "
330
379
  "pass them via `predict(query, indices=...)`")
331
380
  indices = query_def.get_rfm_entity_id_list()
332
- else:
333
- query_def = replace(query_def, rfm_entity_ids=None)
381
+ query_def = replace(
382
+ query_def,
383
+ for_each='FOR EACH',
384
+ rfm_entity_ids=None,
385
+ )
334
386
 
335
387
  if not isinstance(verbose, ProgressLogger):
336
388
  query_repr = query_def.to_string(rich=True, exclude_predict=True)
@@ -351,11 +403,11 @@ class KumoRFM:
351
403
  random_seed=random_seed,
352
404
  logger=logger,
353
405
  )
354
- task_table._query = query_def.to_string() # type: ignore
406
+ task_table._query = query_def.to_string()
355
407
 
356
408
  return self.predict_task(
357
409
  task_table,
358
- explain=explain, # type: ignore
410
+ explain=explain,
359
411
  run_mode=run_mode,
360
412
  num_neighbors=num_neighbors,
361
413
  num_hops=num_hops,
@@ -397,6 +449,22 @@ class KumoRFM:
397
449
  ) -> Explanation:
398
450
  pass
399
451
 
452
+ @overload
453
+ def predict_task(
454
+ self,
455
+ task: TaskTable,
456
+ *,
457
+ explain: bool | ExplainConfig | dict[str, Any] = False,
458
+ run_mode: RunMode | str = RunMode.FAST,
459
+ num_neighbors: list[int] | None = None,
460
+ num_hops: int = 2,
461
+ verbose: bool | ProgressLogger = True,
462
+ exclude_cols_dict: dict[str, list[str]] | None = None,
463
+ use_prediction_time: bool = False,
464
+ top_k: int | None = None,
465
+ ) -> pd.DataFrame | Explanation:
466
+ pass
467
+
400
468
  def predict_task(
401
469
  self,
402
470
  task: TaskTable,
@@ -477,9 +545,9 @@ class KumoRFM:
477
545
  task_type_repr = str(task.task_type)
478
546
 
479
547
  if explain_config is not None:
480
- msg = f'Explain {task_type_repr} task'
548
+ msg = f"Explaining {task_type_repr} task"
481
549
  else:
482
- msg = f'Predict {task_type_repr} task'
550
+ msg = f"Predicting {task_type_repr} task"
483
551
  verbose = ProgressLogger.default(msg=msg, verbose=verbose)
484
552
 
485
553
  with verbose as logger:
@@ -525,7 +593,7 @@ class KumoRFM:
525
593
  request = RFMPredictRequest(
526
594
  context=context,
527
595
  run_mode=RunMode(run_mode),
528
- query=getattr(task, '_query', ''),
596
+ query=task._query,
529
597
  use_prediction_time=use_prediction_time,
530
598
  )
531
599
  with warnings.catch_warnings():
@@ -544,7 +612,7 @@ class KumoRFM:
544
612
  num = math.ceil(task.num_prediction_examples / batch_size)
545
613
  verbose.init_progress(total=num, description='Predicting')
546
614
 
547
- for attempt in range(self.num_retries + 1):
615
+ for attempt in range(self._num_retries + 1):
548
616
  try:
549
617
  if explain_config is not None:
550
618
  resp = self._api_client.explain(
@@ -582,7 +650,7 @@ class KumoRFM:
582
650
 
583
651
  break
584
652
  except HTTPException as e:
585
- if attempt == self.num_retries:
653
+ if attempt == self._num_retries:
586
654
  try:
587
655
  msg = json.loads(e.detail)['detail']
588
656
  except Exception:
@@ -612,51 +680,6 @@ class KumoRFM:
612
680
 
613
681
  return prediction
614
682
 
615
- def is_valid_entity(
616
- self,
617
- query: str,
618
- indices: list[str] | list[float] | list[int] | None = None,
619
- *,
620
- anchor_time: pd.Timestamp | Literal['entity'] | None = None,
621
- ) -> np.ndarray:
622
- r"""Returns a mask that denotes which entities are valid for the
623
- given predictive query, *i.e.*, which entities fulfill (temporal)
624
- entity filter constraints.
625
-
626
- Args:
627
- query: The predictive query.
628
- indices: The entity primary keys to predict for. Will override the
629
- indices given as part of the predictive query.
630
- anchor_time: The anchor timestamp for the prediction. If set to
631
- ``None``, will use the maximum timestamp in the data.
632
- If set to ``"entity"``, will use the timestamp of the entity.
633
- """
634
- query_def = self._parse_query(query)
635
-
636
- if indices is None:
637
- if query_def.rfm_entity_ids is None:
638
- raise ValueError("Cannot find entities to predict for. Please "
639
- "pass them via "
640
- "`is_valid_entity(query, indices=...)`")
641
- indices = query_def.get_rfm_entity_id_list()
642
-
643
- if len(indices) == 0:
644
- raise ValueError("At least one entity is required")
645
-
646
- if anchor_time is None:
647
- anchor_time = self._get_default_anchor_time(query_def)
648
-
649
- if isinstance(anchor_time, pd.Timestamp):
650
- self._validate_time(query_def, anchor_time, None, False)
651
- else:
652
- assert anchor_time == 'entity'
653
- if query_def.entity_table not in self._sampler.time_column_dict:
654
- raise ValueError(f"Anchor time 'entity' requires the entity "
655
- f"table '{query_def.entity_table}' "
656
- f"to have a time column.")
657
-
658
- raise NotImplementedError
659
-
660
683
  def evaluate(
661
684
  self,
662
685
  query: str,
@@ -701,29 +724,12 @@ class KumoRFM:
701
724
  Returns:
702
725
  The metrics as a :class:`pandas.DataFrame`
703
726
  """
704
- query_def = self._parse_query(query)
705
- if query_def.rfm_entity_ids is not None:
706
- query_def = replace(
707
- query_def,
708
- rfm_entity_ids=None,
709
- )
710
-
711
- task_type = self._get_task_type(
712
- query=query_def,
713
- edge_types=self._sampler.edge_types,
727
+ query_def = replace(
728
+ self._parse_query(query),
729
+ for_each='FOR EACH',
730
+ rfm_entity_ids=None,
714
731
  )
715
732
 
716
- if num_hops != 2 and num_neighbors is not None:
717
- warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
718
- f"custom 'num_hops={num_hops}' option")
719
- if num_neighbors is None:
720
- key = RunMode.FAST if task_type.is_link_pred else run_mode
721
- num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]
722
-
723
- if metrics is not None and len(metrics) > 0:
724
- self._validate_metrics(metrics, task_type)
725
- metrics = list(dict.fromkeys(metrics))
726
-
727
733
  if not isinstance(verbose, ProgressLogger):
728
734
  query_repr = query_def.to_string(rich=True, exclude_predict=True)
729
735
  msg = f'[bold]EVALUATE[/bold] {query_repr}'
@@ -740,11 +746,96 @@ class KumoRFM:
740
746
  random_seed=random_seed,
741
747
  logger=logger,
742
748
  )
743
- context = self._get_context(
744
- task=task_table,
749
+
750
+ return self.evaluate_task(
751
+ task_table,
752
+ metrics=metrics,
745
753
  run_mode=run_mode,
746
754
  num_neighbors=num_neighbors,
755
+ num_hops=num_hops,
756
+ verbose=verbose,
747
757
  exclude_cols_dict=query_def.get_exclude_cols_dict(),
758
+ use_prediction_time=use_prediction_time,
759
+ )
760
+
761
+ def evaluate_task(
762
+ self,
763
+ task: TaskTable,
764
+ *,
765
+ metrics: list[str] | None = None,
766
+ run_mode: RunMode | str = RunMode.FAST,
767
+ num_neighbors: list[int] | None = None,
768
+ num_hops: int = 2,
769
+ verbose: bool | ProgressLogger = True,
770
+ exclude_cols_dict: dict[str, list[str]] | None = None,
771
+ use_prediction_time: bool = False,
772
+ ) -> pd.DataFrame:
773
+ """Evaluates a custom task specification.
774
+
775
+ Args:
776
+ task: The custom :class:`TaskTable`.
777
+ metrics: The metrics to use.
778
+ run_mode: The :class:`RunMode` for the query.
779
+ num_neighbors: The number of neighbors to sample for each hop.
780
+ If specified, the ``num_hops`` option will be ignored.
781
+ num_hops: The number of hops to sample when generating the context.
782
+ verbose: Whether to print verbose output.
783
+ exclude_cols_dict: Any column in any table to exclude from the
784
+ model input.
785
+ use_prediction_time: Whether to use the anchor timestamp as an
786
+ additional feature during prediction. This is typically
787
+ beneficial for time series forecasting tasks.
788
+
789
+ Returns:
790
+ The metrics as a :class:`pandas.DataFrame`
791
+ """
792
+ if num_hops != 2 and num_neighbors is not None:
793
+ warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
794
+ f"custom 'num_hops={num_hops}' option")
795
+ if num_neighbors is None:
796
+ key = RunMode.FAST if task.task_type.is_link_pred else run_mode
797
+ num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]
798
+
799
+ if metrics is not None and len(metrics) > 0:
800
+ self._validate_metrics(metrics, task.task_type)
801
+ metrics = list(dict.fromkeys(metrics))
802
+
803
+ if not isinstance(verbose, ProgressLogger):
804
+ if task.task_type == TaskType.BINARY_CLASSIFICATION:
805
+ task_type_repr = 'binary classification'
806
+ elif task.task_type == TaskType.MULTICLASS_CLASSIFICATION:
807
+ task_type_repr = 'multi-class classification'
808
+ elif task.task_type == TaskType.REGRESSION:
809
+ task_type_repr = 'regression'
810
+ elif task.task_type == TaskType.TEMPORAL_LINK_PREDICTION:
811
+ task_type_repr = 'link prediction'
812
+ else:
813
+ task_type_repr = str(task.task_type)
814
+
815
+ msg = f"Evaluating {task_type_repr} task"
816
+ verbose = ProgressLogger.default(msg=msg, verbose=verbose)
817
+
818
+ with verbose as logger:
819
+ if task.num_context_examples > _MAX_CONTEXT_SIZE[run_mode]:
820
+ logger.log(f"Sub-sampled {_MAX_CONTEXT_SIZE[run_mode]:,} "
821
+ f"out of {task.num_context_examples:,} in-context "
822
+ f"examples")
823
+ task = task.narrow_context(0, _MAX_CONTEXT_SIZE[run_mode])
824
+
825
+ if task.num_prediction_examples > _MAX_TEST_SIZE[task.task_type]:
826
+ logger.log(f"Sub-sampled {_MAX_TEST_SIZE[task.task_type]:,} "
827
+ f"out of {task.num_prediction_examples:,} test "
828
+ f"examples")
829
+ task = task.narrow_prediction(
830
+ start=0,
831
+ length=_MAX_TEST_SIZE[task.task_type],
832
+ )
833
+
834
+ context = self._get_context(
835
+ task=task,
836
+ run_mode=run_mode,
837
+ num_neighbors=num_neighbors,
838
+ exclude_cols_dict=exclude_cols_dict,
748
839
  )
749
840
 
750
841
  request = RFMEvaluateRequest(
@@ -764,12 +855,12 @@ class KumoRFM:
764
855
  stats_msg = Context.get_memory_stats(request_msg.context)
765
856
  raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
766
857
 
767
- for attempt in range(self.num_retries + 1):
858
+ for attempt in range(self._num_retries + 1):
768
859
  try:
769
860
  resp = self._api_client.evaluate(request_bytes)
770
861
  break
771
862
  except HTTPException as e:
772
- if attempt == self.num_retries:
863
+ if attempt == self._num_retries:
773
864
  try:
774
865
  msg = json.loads(e.detail)['detail']
775
866
  except Exception:
@@ -865,12 +956,12 @@ class KumoRFM:
865
956
  graph_definition=self._graph_def,
866
957
  )
867
958
 
868
- for attempt in range(self.num_retries + 1):
959
+ for attempt in range(self._num_retries + 1):
869
960
  try:
870
961
  resp = self._api_client.parse_query(request)
871
962
  break
872
963
  except HTTPException as e:
873
- if attempt == self.num_retries:
964
+ if attempt == self._num_retries:
874
965
  try:
875
966
  msg = json.loads(e.detail)['detail']
876
967
  except Exception:
@@ -953,8 +1044,16 @@ class KumoRFM:
953
1044
  if len(self._sampler.time_column_dict) == 0:
954
1045
  return # Graph without timestamps
955
1046
 
956
- min_time = self._sampler.get_min_time()
957
- max_time = self._sampler.get_max_time()
1047
+ if query.query_type == QueryType.TEMPORAL:
1048
+ aggr_table_names = [
1049
+ aggr._get_target_column_name().split('.')[0]
1050
+ for aggr in query.get_all_target_aggregations()
1051
+ ]
1052
+ min_time = self._sampler.get_min_time(aggr_table_names)
1053
+ max_time = self._sampler.get_max_time(aggr_table_names)
1054
+ else:
1055
+ min_time = self._sampler.get_min_time()
1056
+ max_time = self._sampler.get_max_time()
958
1057
 
959
1058
  if anchor_time < min_time:
960
1059
  raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
@@ -1193,24 +1292,6 @@ class KumoRFM:
1193
1292
  top_k: int | None = None,
1194
1293
  ) -> Context:
1195
1294
 
1196
- # TODO Remove all
1197
- if task.num_context_examples > max(_MAX_CONTEXT_SIZE.values()):
1198
- raise ValueError(f"Cannot process a context with more than "
1199
- f"{max(_MAX_CONTEXT_SIZE.values()):,} samples "
1200
- f"(got {task.num_context_examples:,})")
1201
- if task.evaluate:
1202
- if task.num_prediction_examples > _MAX_TEST_SIZE[task.task_type]:
1203
- raise ValueError(f"Cannot process a test set with more than "
1204
- f"{_MAX_TEST_SIZE[task.task_type]:,} samples "
1205
- f"for evaluation "
1206
- f"(got {task.num_prediction_examples:,})")
1207
- else:
1208
- if task.num_prediction_examples > _MAX_PRED_SIZE[task.task_type]:
1209
- raise ValueError(f"Cannot predict for more than "
1210
- f"{_MAX_PRED_SIZE[task.task_type]:,} "
1211
- f"entities at once "
1212
- f"(got {task.num_prediction_examples:,})")
1213
-
1214
1295
  if num_neighbors is None:
1215
1296
  key = RunMode.FAST if task.task_type.is_link_pred else run_mode
1216
1297
  num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:2]
@@ -87,6 +87,8 @@ class TaskTable:
87
87
  if time_column is not None:
88
88
  self.time_column = time_column
89
89
 
90
+ self._query: str = '' # A description of the task, e.g., for XAI.
91
+
90
92
  @property
91
93
  def num_context_examples(self) -> int:
92
94
  return len(self._context_df)
@@ -199,6 +199,7 @@ class TrainingTable:
199
199
  self,
200
200
  source_table_type: SourceTableType,
201
201
  train_table_mod: TrainingTableSpec,
202
+ extensive_validation: bool = False,
202
203
  ) -> None:
203
204
  r"""Validates the modified training table.
204
205
 
@@ -206,6 +207,8 @@ class TrainingTable:
206
207
  source_table_type: The source table to be used as the modified
207
208
  training table.
208
209
  train_table_mod: The modification specification.
210
+ extensive_validation: Enable extensive validation for custom
211
+ table.
209
212
 
210
213
  Raises:
211
214
  ValueError: If the modified training table is invalid.
@@ -215,7 +218,8 @@ class TrainingTable:
215
218
  global_state.client.generate_train_table_job_api)
216
219
  response = api.validate_custom_train_table(self.job_id,
217
220
  source_table_type,
218
- train_table_mod)
221
+ train_table_mod,
222
+ extensive_validation)
219
223
  if not response.ok:
220
224
  raise ValueError("Invalid weighted train table",
221
225
  response.error_message)
@@ -225,6 +229,7 @@ class TrainingTable:
225
229
  source_table: SourceTable,
226
230
  train_table_mod: TrainingTableSpec,
227
231
  validate: bool = True,
232
+ extensive_validation: bool = False,
228
233
  ) -> Self:
229
234
  r"""Sets the `source_table` as the modified training table.
230
235
 
@@ -243,6 +248,9 @@ class TrainingTable:
243
248
  train_table_mod: The modification specification.
244
249
  validate: Whether to validate the modified training table. This can
245
250
  be slow for large tables.
251
+ extensive_validation: Whether to validate number of rows in
252
+ existing and modified training table.
253
+ It can be slow for large tables.
246
254
  """
247
255
  if isinstance(source_table.connector, S3Connector):
248
256
  # Special handling for s3 as `source_table._to_api_source_table`
@@ -252,7 +260,13 @@ class TrainingTable:
252
260
  else:
253
261
  source_table_type = source_table._to_api_source_table()
254
262
  if validate:
255
- self.validate_custom_table(source_table_type, train_table_mod)
263
+ if extensive_validation:
264
+ logger.warning(
265
+ "You have opted in to perform extensive validation on"
266
+ " your custom training table."
267
+ " This operation can be slow for large tables.")
268
+ self.validate_custom_table(source_table_type, train_table_mod,
269
+ extensive_validation)
256
270
  self._custom_train_table = CustomTrainingTable(
257
271
  source_table=source_table_type, table_mod_spec=train_table_mod,
258
272
  validated=validate)
kumoai/testing/snow.py CHANGED
@@ -10,7 +10,7 @@ def connect(
10
10
  id: str,
11
11
  account: str,
12
12
  user: str,
13
- warehouse: str,
13
+ warehouse: str | None = None,
14
14
  database: str | None = None,
15
15
  schema: str | None = None,
16
16
  ) -> Connection:
@@ -42,8 +42,8 @@ def connect(
42
42
  return _connect(
43
43
  account=account,
44
44
  user=user,
45
- warehouse='WH_XS',
46
- database='KUMO',
45
+ warehouse=warehouse or 'WH_XS',
46
+ database=database or 'KUMO',
47
47
  schema=schema,
48
48
  session_parameters=dict(CLIENT_TELEMETRY_ENABLED=False),
49
49
  **kwargs,
kumoai/utils/display.py CHANGED
@@ -1,13 +1,15 @@
1
1
  from collections.abc import Sequence
2
2
 
3
3
  import pandas as pd
4
+ from rich import box
5
+ from rich.console import Console
6
+ from rich.table import Table
7
+ from rich.text import Text
4
8
 
5
9
  from kumoai import in_notebook, in_snowflake_notebook
6
10
 
7
11
 
8
12
  def message(msg: str) -> None:
9
- msg = msg.replace("`", "'") if not in_notebook() else msg
10
-
11
13
  if in_snowflake_notebook():
12
14
  import streamlit as st
13
15
  st.markdown(msg)
@@ -15,23 +17,40 @@ def message(msg: str) -> None:
15
17
  from IPython.display import Markdown, display
16
18
  display(Markdown(msg))
17
19
  else:
18
- print(msg)
20
+ print(msg.replace("`", "'"))
19
21
 
20
22
 
21
23
  def title(msg: str) -> None:
22
- message(f"### {msg}" if in_notebook() else f"{msg}:")
24
+ if in_notebook():
25
+ message(f"### {msg}")
26
+ else:
27
+ msg = msg.replace("`", "'")
28
+ Console().print(f"[bold]{msg}[/bold]", highlight=False)
23
29
 
24
30
 
25
31
  def italic(msg: str) -> None:
26
- message(f"*{msg}*" if in_notebook() else msg)
32
+ if in_notebook():
33
+ message(f"*{msg}*")
34
+ else:
35
+ msg = msg.replace("`", "'")
36
+ Console().print(
37
+ f"[italic]{msg}[/italic]",
38
+ highlight=False,
39
+ style='dim',
40
+ )
27
41
 
28
42
 
29
43
  def unordered_list(items: Sequence[str]) -> None:
30
44
  if in_notebook():
31
45
  msg = '\n'.join([f"- {item}" for item in items])
46
+ message(msg)
32
47
  else:
33
- msg = '\n'.join([f"• {item.replace('`', '')}" for item in items])
34
- message(msg)
48
+ text = Text('\n').join(
49
+ Text.assemble(
50
+ Text(' • ', style='yellow'),
51
+ Text(item.replace('`', '')),
52
+ ) for item in items)
53
+ Console().print(text, highlight=False)
35
54
 
36
55
 
37
56
  def dataframe(df: pd.DataFrame) -> None:
@@ -48,4 +67,21 @@ def dataframe(df: pd.DataFrame) -> None:
48
67
  except ImportError:
49
68
  print(df.to_string(index=False)) # missing jinja2
50
69
  else:
51
- print(df.to_string(index=False))
70
+ Console().print(to_rich_table(df))
71
+
72
+
73
+ def to_rich_table(df: pd.DataFrame) -> Table:
74
+ table = Table(box=box.ROUNDED)
75
+ for column in df.columns:
76
+ table.add_column(str(column))
77
+ for _, row in df.iterrows():
78
+ values: list[str | Text] = []
79
+ for value in row:
80
+ if str(value) == 'True':
81
+ values.append('✅')
82
+ elif str(value) in {'False', '-'}:
83
+ values.append(Text('-', style='dim'))
84
+ else:
85
+ values.append(str(value))
86
+ table.add_row(*values)
87
+ return table
@@ -57,7 +57,8 @@ class ProgressLogger:
57
57
 
58
58
  def __enter__(self) -> Self:
59
59
  self.depth += 1
60
- self.start_time = time.perf_counter()
60
+ if self.depth == 1:
61
+ self.start_time = time.perf_counter()
61
62
  return self
62
63
 
63
64
  def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
kumoai/utils/sql.py CHANGED
@@ -1,3 +1,3 @@
1
- def quote_ident(name: str) -> str:
1
+ def quote_ident(ident: str, char: str = '"') -> str:
2
2
  r"""Quotes a SQL identifier."""
3
- return '"' + name.replace('"', '""') + '"'
3
+ return char + ident.replace(char, char + char) + char
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: kumoai
3
- Version: 2.14.0.dev202601051732
3
+ Version: 2.15.0.dev202601141731
4
4
  Summary: AI on the Modern Data Stack
5
5
  Author-email: "Kumo.AI" <hello@kumo.ai>
6
6
  License-Expression: MIT
@@ -23,7 +23,7 @@ Requires-Dist: requests>=2.28.2
23
23
  Requires-Dist: urllib3
24
24
  Requires-Dist: plotly
25
25
  Requires-Dist: typing_extensions>=4.5.0
26
- Requires-Dist: kumo-api==0.49.0
26
+ Requires-Dist: kumo-api<1.0.0,>=0.53.0
27
27
  Requires-Dist: tqdm>=4.66.0
28
28
  Requires-Dist: aiohttp>=3.10.0
29
29
  Requires-Dist: pydantic>=1.10.21