kumoai 2.14.0.dev202512211732__cp313-cp313-win_amd64.whl → 2.15.0.dev202601151732__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. kumoai/__init__.py +23 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +26 -0
  5. kumoai/connector/utils.py +21 -7
  6. kumoai/experimental/rfm/__init__.py +24 -22
  7. kumoai/experimental/rfm/backend/local/graph_store.py +12 -21
  8. kumoai/experimental/rfm/backend/local/sampler.py +0 -3
  9. kumoai/experimental/rfm/backend/local/table.py +24 -25
  10. kumoai/experimental/rfm/backend/snow/sampler.py +235 -80
  11. kumoai/experimental/rfm/backend/snow/table.py +146 -70
  12. kumoai/experimental/rfm/backend/sqlite/sampler.py +196 -89
  13. kumoai/experimental/rfm/backend/sqlite/table.py +85 -55
  14. kumoai/experimental/rfm/base/__init__.py +6 -9
  15. kumoai/experimental/rfm/base/column.py +95 -11
  16. kumoai/experimental/rfm/base/expression.py +44 -0
  17. kumoai/experimental/rfm/base/mapper.py +69 -0
  18. kumoai/experimental/rfm/base/sampler.py +28 -18
  19. kumoai/experimental/rfm/base/source.py +1 -1
  20. kumoai/experimental/rfm/base/sql_sampler.py +320 -19
  21. kumoai/experimental/rfm/base/table.py +256 -109
  22. kumoai/experimental/rfm/base/utils.py +36 -0
  23. kumoai/experimental/rfm/graph.py +130 -110
  24. kumoai/experimental/rfm/infer/dtype.py +7 -2
  25. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  26. kumoai/experimental/rfm/infer/time_col.py +4 -2
  27. kumoai/experimental/rfm/relbench.py +76 -0
  28. kumoai/experimental/rfm/rfm.py +540 -306
  29. kumoai/experimental/rfm/task_table.py +292 -0
  30. kumoai/kumolib.cp313-win_amd64.pyd +0 -0
  31. kumoai/pquery/training_table.py +16 -2
  32. kumoai/testing/snow.py +3 -3
  33. kumoai/trainer/distilled_trainer.py +175 -0
  34. kumoai/utils/display.py +87 -0
  35. kumoai/utils/progress_logger.py +15 -2
  36. kumoai/utils/sql.py +2 -2
  37. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/METADATA +2 -2
  38. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/RECORD +41 -36
  39. kumoai/experimental/rfm/base/column_expression.py +0 -50
  40. kumoai/experimental/rfm/base/sql_table.py +0 -229
  41. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/WHEEL +0 -0
  42. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/licenses/LICENSE +0 -0
  43. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/task_table.py ADDED
@@ -0,0 +1,292 @@
+ import copy
+ from collections.abc import Sequence
+
+ import pandas as pd
+ from kumoapi.task import TaskType
+ from kumoapi.typing import Dtype, Stype
+ from typing_extensions import Self
+
+ from kumoai.experimental.rfm.base import Column
+ from kumoai.experimental.rfm.infer import contains_timestamp, infer_dtype
+
+
+ class TaskTable:
+     r"""A :class:`TaskTable` fully specifies the task, *i.e.* its context and
+     prediction examples with entity IDs, targets and timestamps.
+
+     Args:
+         task_type: The task type.
+         context_df: The data frame holding context examples.
+         pred_df: The data frame holding prediction examples.
+         entity_table_name: The entity table to predict for. For link
+             prediction tasks, needs to hold both entity and target table
+             names.
+         entity_column: The name of the entity column.
+         target_column: The name of the target column.
+         time_column: The name of the time column to use as anchor time. If
+             ``TaskTable.ENTITY_TIME``, use the timestamp of the entity table
+             as anchor time.
+     """
+     ENTITY_TIME = '__entity_time__'
+
+     def __init__(
+         self,
+         task_type: TaskType,
+         context_df: pd.DataFrame,
+         pred_df: pd.DataFrame,
+         entity_table_name: str | Sequence[str],
+         entity_column: str,
+         target_column: str,
+         time_column: str | None = None,
+     ) -> None:
+
+         task_type = TaskType(task_type)
+         if task_type not in {  # Currently supported task types:
+                 TaskType.BINARY_CLASSIFICATION,
+                 TaskType.MULTICLASS_CLASSIFICATION,
+                 TaskType.REGRESSION,
+                 TaskType.TEMPORAL_LINK_PREDICTION,
+         }:
+             raise ValueError  # TODO
+         self._task_type = task_type
+
+         # TODO Binary classification and regression checks
+
+         # TODO Check dfs (unify from local table)
+         if context_df.empty:
+             raise ValueError("No context examples given")
+         self._context_df = context_df.copy(deep=False)
+
+         if pred_df.empty:
+             raise ValueError("Provide at least one entity to predict for")
+         self._pred_df = pred_df.copy(deep=False)
+
+         self._dtype_dict: dict[str, Dtype] = {
+             column_name: infer_dtype(context_df[column_name])
+             for column_name in context_df.columns
+         }
+
+         self._entity_table_names: tuple[str] | tuple[str, str]
+         if isinstance(entity_table_name, str):
+             self._entity_table_names = (entity_table_name, )
+         elif len(entity_table_name) == 1:
+             self._entity_table_names = (entity_table_name[0], )
+         elif len(entity_table_name) == 2:
+             self._entity_table_names = (
+                 entity_table_name[0],
+                 entity_table_name[1],
+             )
+         else:
+             raise ValueError  # TODO
+
+         self._entity_column: str = ''
+         self._target_column: str = ''
+         self._time_column: str | None = None
+
+         self.entity_column = entity_column
+         self.target_column = target_column
+         if time_column is not None:
+             self.time_column = time_column
+
+         self._query: str = ''  # A description of the task, e.g., for XAI.
+
+     @property
+     def num_context_examples(self) -> int:
+         return len(self._context_df)
+
+     @property
+     def num_prediction_examples(self) -> int:
+         return len(self._pred_df)
+
+     @property
+     def task_type(self) -> TaskType:
+         r"""The task type."""
+         return self._task_type
+
+     def narrow_context(self, start: int, length: int) -> Self:
+         r"""Returns a new :class:`TaskTable` that holds a narrowed version of
+         context examples.
+
+         Args:
+             start: Index of the context examples to start narrowing.
+             length: Number of context examples to keep.
+         """
+         out = copy.copy(self)
+         df = out._context_df.iloc[start:start + length].reset_index(drop=True)
+         out._context_df = df
+         return out
+
+     def narrow_prediction(self, start: int, length: int) -> Self:
+         r"""Returns a new :class:`TaskTable` that holds a narrowed version of
+         prediction examples.
+
+         Args:
+             start: Index of the prediction examples to start narrowing.
+             length: Number of prediction examples to keep.
+         """
+         out = copy.copy(self)
+         df = out._pred_df.iloc[start:start + length].reset_index(drop=True)
+         out._pred_df = df
+         return out
+
+     # Entity column ##########################################################
+
+     @property
+     def entity_table_name(self) -> str:
+         return self._entity_table_names[0]
+
+     @property
+     def entity_table_names(self) -> tuple[str] | tuple[str, str]:
+         return self._entity_table_names
+
+     @property
+     def entity_column(self) -> Column:
+         return Column(
+             name=self._entity_column,
+             expr=None,
+             dtype=self._dtype_dict[self._entity_column],
+             stype=Stype.ID,
+         )
+
+     @entity_column.setter
+     def entity_column(self, name: str) -> None:
+         if name not in self._context_df:
+             raise ValueError  # TODO
+         if name not in self._pred_df:
+             raise ValueError  # TODO
+         if not Stype.ID.supports_dtype(self._dtype_dict[name]):
+             raise ValueError  # TODO
+
+         self._entity_column = name
+
+     # Target column ##########################################################
+
+     @property
+     def evaluate(self) -> bool:
+         r"""Returns ``True`` if this task can be used for model evaluation."""
+         return self._target_column in self._pred_df
+
+     @property
+     def _target_stype(self) -> Stype:
+         if self.task_type in {
+                 TaskType.BINARY_CLASSIFICATION,
+                 TaskType.MULTICLASS_CLASSIFICATION,
+         }:
+             return Stype.categorical
+         if self.task_type in {TaskType.REGRESSION}:
+             return Stype.numerical
+         if self.task_type.is_link_pred:
+             return Stype.multicategorical
+         raise ValueError
+
+     @property
+     def target_column(self) -> Column:
+         return Column(
+             name=self._target_column,
+             expr=None,
+             dtype=self._dtype_dict[self._target_column],
+             stype=self._target_stype,
+         )
+
+     @target_column.setter
+     def target_column(self, name: str) -> None:
+         if name not in self._context_df:
+             raise ValueError  # TODO
+         if not self._target_stype.supports_dtype(self._dtype_dict[name]):
+             raise ValueError  # TODO
+
+         self._target_column = name
+
+     # Time column ############################################################
+
+     def has_time_column(self) -> bool:
+         r"""Returns ``True`` if this task has a time column; ``False``
+         otherwise.
+         """
+         return self._time_column not in {None, self.ENTITY_TIME}
+
+     @property
+     def use_entity_time(self) -> bool:
+         r"""Whether to use the timestamp of the entity table as anchor time."""
+         return self._time_column == self.ENTITY_TIME
+
+     @property
+     def time_column(self) -> Column | None:
+         r"""The time column of this task.
+
+         The getter returns the time column of this task, or ``None`` if no
+         such time column is present.
+
+         The setter sets a column as the time column for this task, and raises
+         a :class:`ValueError` if the time column has a non-timestamp
+         compatible data type or if the column name does not match a column in
+         the data frame.
+         """
+         if not self.has_time_column():
+             return None
+         assert self._time_column is not None
+         return Column(
+             name=self._time_column,
+             expr=None,
+             dtype=self._dtype_dict[self._time_column],
+             stype=Stype.timestamp,
+         )
+
+     @time_column.setter
+     def time_column(self, name: str | None) -> None:
+         if name is None or name == self.ENTITY_TIME:
+             self._time_column = name
+             return
+
+         if name not in self._context_df:
+             raise ValueError  # TODO
+         if name not in self._pred_df:
+             raise ValueError  # TODO
+         if not contains_timestamp(
+                 ser=self._context_df[name],
+                 column_name=name,
+                 dtype=self._dtype_dict[name],
+         ):
+             raise ValueError  # TODO
+
+         self._time_column = name
+
+     # Metadata ###############################################################
+
+     @property
+     def metadata(self) -> pd.DataFrame:
+         raise NotImplementedError
+
+     def print_metadata(self) -> None:
+         raise NotImplementedError
+
+     # Python builtins ########################################################
+
+     def __hash__(self) -> int:
+         return hash((
+             self.task_type,
+             self.entity_table_names,
+             self._entity_column,
+             self._target_column,
+             self._time_column,
+         ))
+
+     def __repr__(self) -> str:
+         if self.task_type.is_link_pred:
+             entity_table_repr = f'entity_table_names={self.entity_table_names}'
+         else:
+             entity_table_repr = f'entity_table_name={self.entity_table_name}'
+
+         if self.use_entity_time:
+             time_repr = 'use_entity_time=True'
+         else:
+             time_repr = f'time_column={self._time_column}'
+
+         return (f'{self.__class__.__name__}(\n'
+                 f'  task_type={self.task_type},\n'
+                 f'  num_context_examples={self.num_context_examples},\n'
+                 f'  num_prediction_examples={self.num_prediction_examples},\n'
+                 f'  {entity_table_repr},\n'
+                 f'  entity_column={self._entity_column},\n'
+                 f'  target_column={self._target_column},\n'
+                 f'  {time_repr},\n'
+                 f')')
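For orientation, a minimal sketch of how the new `TaskTable` might be constructed and narrowed, based only on the signature above; the data frames, table name, and column names are illustrative placeholders:

    import pandas as pd
    from kumoapi.task import TaskType
    from kumoai.experimental.rfm.task_table import TaskTable

    # Illustrative frames; real tasks would use data tied to a graph:
    context_df = pd.DataFrame({
        'user_id': [1, 2],
        'churned': [0, 1],
        'ts': pd.to_datetime(['2024-01-01', '2024-01-02']),
    })
    pred_df = pd.DataFrame({
        'user_id': [3],
        'ts': pd.to_datetime(['2024-01-03']),
    })

    task = TaskTable(
        task_type=TaskType.BINARY_CLASSIFICATION,
        context_df=context_df,
        pred_df=pred_df,
        entity_table_name='users',  # hypothetical table name
        entity_column='user_id',
        target_column='churned',
        time_column='ts',
    )
    first = task.narrow_context(start=0, length=1)  # shallow copy, 1 example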
kumoai/kumolib.cp313-win_amd64.pyd CHANGED (binary file; contents not shown)
kumoai/pquery/training_table.py CHANGED
@@ -199,6 +199,7 @@ class TrainingTable:
          self,
          source_table_type: SourceTableType,
          train_table_mod: TrainingTableSpec,
+         extensive_validation: bool = False,
      ) -> None:
          r"""Validates the modified training table.

@@ -206,6 +207,8 @@ class TrainingTable:
              source_table_type: The source table to be used as the modified
                  training table.
              train_table_mod: The modification specification.
+             extensive_validation: Whether to run extensive validation on the
+                 custom table.

          Raises:
              ValueError: If the modified training table is invalid.
@@ -215,7 +218,8 @@ class TrainingTable:
              global_state.client.generate_train_table_job_api)
          response = api.validate_custom_train_table(self.job_id,
                                                     source_table_type,
-                                                    train_table_mod)
+                                                    train_table_mod,
+                                                    extensive_validation)
          if not response.ok:
              raise ValueError("Invalid weighted train table",
                               response.error_message)
@@ -225,6 +229,7 @@ class TrainingTable:
          source_table: SourceTable,
          train_table_mod: TrainingTableSpec,
          validate: bool = True,
+         extensive_validation: bool = False,
      ) -> Self:
          r"""Sets the `source_table` as the modified training table.

@@ -243,6 +248,9 @@ class TrainingTable:
              train_table_mod: The modification specification.
              validate: Whether to validate the modified training table. This
                  can be slow for large tables.
+             extensive_validation: Whether to additionally validate the number
+                 of rows in the existing and modified training tables. This
+                 can be slow for large tables.
          """
          if isinstance(source_table.connector, S3Connector):
              # Special handling for s3 as `source_table._to_api_source_table`
@@ -252,7 +260,13 @@ class TrainingTable:
          else:
              source_table_type = source_table._to_api_source_table()
          if validate:
-             self.validate_custom_table(source_table_type, train_table_mod)
+             if extensive_validation:
+                 logger.warning(
+                     "You have opted in to perform extensive validation on"
+                     " your custom training table."
+                     " This operation can be slow for large tables.")
+             self.validate_custom_table(source_table_type, train_table_mod,
+                                        extensive_validation)
          self._custom_train_table = CustomTrainingTable(
              source_table=source_table_type, table_mod_spec=train_table_mod,
              validated=validate)
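For reference, a sketch of how the new flag flows through the public validation path; `train_table` is an existing `TrainingTable`, and the other arguments are placeholder objects, not values from this diff:

    # Opting in to the slower row-count validation (placeholder objects):
    train_table.validate_custom_table(
        source_table_type,          # SourceTableType of the replacement table
        train_table_mod,            # TrainingTableSpec describing the change
        extensive_validation=True,  # also compares row counts; can be slow
    )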
kumoai/testing/snow.py CHANGED
@@ -10,7 +10,7 @@ def connect(
      id: str,
      account: str,
      user: str,
-     warehouse: str,
+     warehouse: str | None = None,
      database: str | None = None,
      schema: str | None = None,
  ) -> Connection:
@@ -42,8 +42,8 @@ def connect(
      return _connect(
          account=account,
          user=user,
-         warehouse='WH_XS',
-         database='KUMO',
+         warehouse=warehouse or 'WH_XS',
+         database=database or 'KUMO',
          schema=schema,
          session_parameters=dict(CLIENT_TELEMETRY_ENABLED=False),
          **kwargs,
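With the relaxed signature, test connections can rely on the module defaults or override them; a sketch with placeholder credentials:

    # Falls back to warehouse='WH_XS' and database='KUMO' when omitted:
    conn = connect(id='ci', account='my_account', user='my_user')
    # Explicit override, now possible without editing the helper:
    conn = connect(id='ci', account='my_account', user='my_user',
                   warehouse='WH_M', database='ANALYTICS')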
kumoai/trainer/distilled_trainer.py ADDED
@@ -0,0 +1,175 @@
+ import logging
+ from typing import Literal, Mapping, Optional, Union, overload
+
+ from kumoapi.distilled_model_plan import DistilledModelPlan
+ from kumoapi.jobs import DistillationJobRequest, DistillationJobResource
+
+ from kumoai import global_state
+ from kumoai.client.jobs import TrainingJobID
+ from kumoai.graph import Graph
+ from kumoai.pquery.training_table import TrainingTable, TrainingTableJob
+ from kumoai.trainer.job import TrainingJob, TrainingJobResult
+
+ logger = logging.getLogger(__name__)
+
+
+ class DistillationTrainer:
+     r"""A trainer that supports creating a Kumo machine learning model for
+     use in an online serving endpoint. The distillation process involves
+     training a shallow model on a :class:`~kumoai.pquery.PredictiveQuery`
+     using the embeddings generated by a base model
+     (:obj:`base_training_job_id`).
+
+     Args:
+         model_plan: The distilled model plan to use for the distillation
+             process.
+         base_training_job_id: The ID of the base training job to use for the
+             distillation process.
+     """
+
+     def __init__(
+         self,
+         model_plan: DistilledModelPlan,
+         base_training_job_id: TrainingJobID,
+     ) -> None:
+         self.model_plan: DistilledModelPlan = model_plan
+         self.base_training_job_id: TrainingJobID = base_training_job_id
+
+         # Cached from backend:
+         self._training_job_id: Optional[TrainingJobID] = None
+
+     # Metadata ###############################################################
+
+     @property
+     def is_trained(self) -> bool:
+         r"""Returns ``True`` if this trainer instance has successfully been
+         trained (and is therefore ready for prediction); ``False`` otherwise.
+         """
+         raise NotImplementedError(
+             "Checking if a distilled trainer is trained is not "
+             "implemented yet.")
+
+     @overload
+     def fit(
+         self,
+         graph: Graph,
+         train_table: Union[TrainingTable, TrainingTableJob],
+     ) -> TrainingJobResult:
+         pass
+
+     @overload
+     def fit(
+         self,
+         graph: Graph,
+         train_table: Union[TrainingTable, TrainingTableJob],
+         *,
+         non_blocking: Literal[False],
+     ) -> TrainingJobResult:
+         pass
+
+     @overload
+     def fit(
+         self,
+         graph: Graph,
+         train_table: Union[TrainingTable, TrainingTableJob],
+         *,
+         non_blocking: Literal[True],
+     ) -> TrainingJob:
+         pass
+
+     @overload
+     def fit(
+         self,
+         graph: Graph,
+         train_table: Union[TrainingTable, TrainingTableJob],
+         *,
+         non_blocking: bool,
+     ) -> Union[TrainingJob, TrainingJobResult]:
+         pass
+
+     def fit(
+         self,
+         graph: Graph,
+         train_table: Union[TrainingTable, TrainingTableJob],
+         *,
+         non_blocking: bool = False,
+         custom_tags: Mapping[str, str] = {},
+     ) -> Union[TrainingJob, TrainingJobResult]:
+         r"""Fits a model to the specified graph and training table, with the
+         strategy defined by this :class:`DistillationTrainer`'s
+         :obj:`model_plan`.
+
+         Args:
+             graph: The :class:`~kumoai.graph.Graph` object that represents the
+                 tables and relationships that Kumo will learn from.
+             train_table: The :class:`~kumoai.pquery.TrainingTable`, or
+                 in-progress :class:`~kumoai.pquery.TrainingTableJob`, that
+                 represents the training data produced by a
+                 :class:`~kumoai.pquery.PredictiveQuery` on :obj:`graph`.
+             non_blocking: Whether this operation should return immediately
+                 after launching the training job, or await completion of the
+                 training job.
+             custom_tags: Additional, customer-defined key-value tags to be
+                 associated with the job to be launched. Job tags are useful
+                 for grouping and searching jobs.
+
+         Returns:
+             Union[TrainingJobResult, TrainingJob]:
+                 If ``non_blocking=False``, returns a training job object. If
+                 ``non_blocking=True``, returns a training job future object.
+         """
+         # TODO(manan, siyang): remove soon:
+         job_id = train_table.job_id
+         assert job_id is not None
+
+         train_table_job_api = global_state.client.generate_train_table_job_api
+         pq_id = train_table_job_api.get(job_id).config.pquery_id
+         assert pq_id is not None
+
+         custom_table = None
+         if isinstance(train_table, TrainingTable):
+             custom_table = train_table._custom_train_table
+
+         # NOTE the backend implementation currently handles sequentialization
+         # between a training table future and a training job; that is, if the
+         # training table future is still executing, the backend will wait on
+         # the job ID completion before executing a training job. This
+         # preserves semantics for both futures, ensures that Kumo works as
+         # expected if used only via REST API, and allows us to avoid chaining
+         # callbacks in an ugly way here:
+         api = global_state.client.distillation_job_api
+         self._training_job_id = api.create(
+             DistillationJobRequest(
+                 dict(custom_tags),
+                 pquery_id=pq_id,
+                 base_training_job_id=self.base_training_job_id,
+                 distilled_model_plan=self.model_plan,
+                 graph_snapshot_id=graph.snapshot(non_blocking=non_blocking),
+                 train_table_job_id=job_id,
+                 custom_train_table=custom_table,
+             ))
+
+         out = TrainingJob(job_id=self._training_job_id)
+         if non_blocking:
+             return out
+         return out.attach()
+
+     @classmethod
+     def _load_from_job(
+         cls,
+         job: DistillationJobResource,
+     ) -> 'DistillationTrainer':
+         trainer = cls(job.config.distilled_model_plan,
+                       job.config.base_training_job_id)
+         trainer._training_job_id = job.job_id
+         return trainer
+
+     @classmethod
+     def load(cls, job_id: TrainingJobID) -> 'DistillationTrainer':
+         r"""Creates a :class:`DistillationTrainer` instance from a training
+         job ID.
+         """
+         raise NotImplementedError(
+             "Loading a distilled trainer from a job ID is not implemented "
+             "yet.")
+
+     @classmethod
+     def load_from_tags(cls, tags: Mapping[str, str]) -> 'DistillationTrainer':
+         raise NotImplementedError(
+             "Loading a distilled trainer from tags is not implemented yet.")
kumoai/utils/display.py ADDED
@@ -0,0 +1,87 @@
+ from collections.abc import Sequence
+
+ import pandas as pd
+ from rich import box
+ from rich.console import Console
+ from rich.table import Table
+ from rich.text import Text
+
+ from kumoai import in_notebook, in_snowflake_notebook
+
+
+ def message(msg: str) -> None:
+     if in_snowflake_notebook():
+         import streamlit as st
+         st.markdown(msg)
+     elif in_notebook():
+         from IPython.display import Markdown, display
+         display(Markdown(msg))
+     else:
+         print(msg.replace("`", "'"))
+
+
+ def title(msg: str) -> None:
+     if in_notebook():
+         message(f"### {msg}")
+     else:
+         msg = msg.replace("`", "'")
+         Console().print(f"[bold]{msg}[/bold]", highlight=False)
+
+
+ def italic(msg: str) -> None:
+     if in_notebook():
+         message(f"*{msg}*")
+     else:
+         msg = msg.replace("`", "'")
+         Console().print(
+             f"[italic]{msg}[/italic]",
+             highlight=False,
+             style='dim',
+         )
+
+
+ def unordered_list(items: Sequence[str]) -> None:
+     if in_notebook():
+         msg = '\n'.join([f"- {item}" for item in items])
+         message(msg)
+     else:
+         text = Text('\n').join(
+             Text.assemble(
+                 Text(' • ', style='yellow'),
+                 Text(item.replace('`', '')),
+             ) for item in items)
+         Console().print(text, highlight=False)
+
+
+ def dataframe(df: pd.DataFrame) -> None:
+     if in_snowflake_notebook():
+         import streamlit as st
+         st.dataframe(df, hide_index=True)
+     elif in_notebook():
+         from IPython.display import display
+         try:
+             if hasattr(df.style, 'hide'):
+                 display(df.style.hide(axis='index'))  # pandas>=1.4
+             else:
+                 display(df.style.hide_index())  # pandas<1.4
+         except ImportError:
+             print(df.to_string(index=False))  # missing jinja2
+     else:
+         Console().print(to_rich_table(df))
+
+
+ def to_rich_table(df: pd.DataFrame) -> Table:
+     table = Table(box=box.ROUNDED)
+     for column in df.columns:
+         table.add_column(str(column))
+     for _, row in df.iterrows():
+         values: list[str | Text] = []
+         for value in row:
+             if str(value) == 'True':
+                 values.append('✅')
+             elif str(value) in {'False', '-'}:
+                 values.append(Text('-', style='dim'))
+             else:
+                 values.append(str(value))
+         table.add_row(*values)
+     return table
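A quick sketch of how the new display helpers might be used; the sample frame is illustrative:

    import pandas as pd
    from kumoai.utils import display

    display.title("Link validation")
    display.unordered_list(["`users` resolved", "`orders` resolved"])
    # Renders via Streamlit, IPython, or rich depending on the environment:
    display.dataframe(pd.DataFrame({'table': ['users', 'orders'],
                                    'valid': [True, False]}))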
kumoai/utils/progress_logger.py CHANGED
@@ -24,6 +24,7 @@ class ProgressLogger:
      def __init__(self, msg: str, verbose: bool = True) -> None:
          self.msg = msg
          self.verbose = verbose
+         self.depth = 0

          self.logs: list[str] = []

@@ -55,10 +56,13 @@ class ProgressLogger:
          pass

      def __enter__(self) -> Self:
-         self.start_time = time.perf_counter()
+         self.depth += 1
+         if self.depth == 1:
+             self.start_time = time.perf_counter()
          return self

      def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+         self.depth -= 1
          self.end_time = time.perf_counter()

      def __repr__(self) -> str:
@@ -123,6 +127,9 @@ class RichProgressLogger(ProgressLogger):

          super().__enter__()

+         if self.depth > 1:
+             return self
+
          if not in_notebook():  # Render progress bar in TUI.
              sys.stdout.write("\x1b]9;4;3\x07")
              sys.stdout.flush()
@@ -142,6 +149,9 @@ class RichProgressLogger(ProgressLogger):

          super().__exit__(exc_type, exc_val, exc_tb)

+         if self.depth > 1:
+             return
+
          if exc_type is not None:
              self._exception = True

@@ -213,6 +223,9 @@ class StreamlitProgressLogger(ProgressLogger):

          import streamlit as st

+         if self.depth > 1:
+             return self
+
          # Adjust layout for prettier output:
          st.markdown(STREAMLIT_CSS, unsafe_allow_html=True)

@@ -253,7 +266,7 @@ class StreamlitProgressLogger(ProgressLogger):
      def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
          super().__exit__(exc_type, exc_val, exc_tb)

-         if not self.verbose or self._status is None:
+         if not self.verbose or self._status is None or self.depth > 1:
              return

          label = f'{self._sanitize_text(self.msg)} ({self.duration:.2f}s)'
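The new `depth` counter makes the loggers reentrant: only the outermost `with` block starts the timer and renders output, while nested entries are no-ops. A minimal sketch (the worker functions are placeholders):

    logger = ProgressLogger("Materializing graph")
    with logger:            # depth 1: starts the timer, renders UI
        load_tables()       # hypothetical work
        with logger:        # depth 2: skipped by the `depth > 1` guards
            infer_metadata()
    # Timing and rendering reflect the outermost block only.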