kumoai 2.14.0.dev202512181731__cp312-cp312-macosx_11_0_arm64.whl → 2.14.0.dev202601041732__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +23 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +24 -0
- kumoai/connector/utils.py +21 -7
- kumoai/experimental/rfm/__init__.py +24 -22
- kumoai/experimental/rfm/backend/local/graph_store.py +12 -21
- kumoai/experimental/rfm/backend/local/sampler.py +0 -3
- kumoai/experimental/rfm/backend/local/table.py +25 -24
- kumoai/experimental/rfm/backend/snow/sampler.py +106 -61
- kumoai/experimental/rfm/backend/snow/table.py +146 -51
- kumoai/experimental/rfm/backend/sqlite/sampler.py +127 -78
- kumoai/experimental/rfm/backend/sqlite/table.py +94 -47
- kumoai/experimental/rfm/base/__init__.py +6 -7
- kumoai/experimental/rfm/base/column.py +97 -5
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +5 -17
- kumoai/experimental/rfm/base/source.py +1 -1
- kumoai/experimental/rfm/base/sql_sampler.py +68 -9
- kumoai/experimental/rfm/base/table.py +291 -126
- kumoai/experimental/rfm/graph.py +139 -86
- kumoai/experimental/rfm/infer/__init__.py +6 -4
- kumoai/experimental/rfm/infer/dtype.py +6 -1
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +30 -42
- kumoai/experimental/rfm/task_table.py +247 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/display.py +51 -0
- {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202601041732.dist-info}/METADATA +1 -1
- {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202601041732.dist-info}/RECORD +35 -31
- kumoai/experimental/rfm/base/column_expression.py +0 -16
- kumoai/experimental/rfm/base/sql_table.py +0 -113
- {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202601041732.dist-info}/WHEEL +0 -0
- {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202601041732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202601041732.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,247 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
from collections.abc import Sequence
|
|
3
|
+
|
|
4
|
+
import pandas as pd
|
|
5
|
+
from kumoapi.task import TaskType
|
|
6
|
+
from kumoapi.typing import Dtype, Stype
|
|
7
|
+
from typing_extensions import Self
|
|
8
|
+
|
|
9
|
+
from kumoai.experimental.rfm.base import Column
|
|
10
|
+
from kumoai.experimental.rfm.infer import contains_timestamp, infer_dtype
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TaskTable:
    r"""A :class:`TaskTable` fully specifies the task, *i.e.* its context and
    prediction examples with entity IDs, targets and timestamps.

    Args:
        task_type: The task type. Currently, only binary classification,
            multi-class classification, regression and temporal link
            prediction are supported.
        context_df: The data frame holding context examples.
        pred_df: The data frame holding prediction examples.
        entity_table_name: The entity table to predict for. For link prediction
            tasks, needs to hold both entity and target table names.
        entity_column: The name of the entity column.
        target_column: The name of the target column.
        time_column: The name of the time column, if it exists.

    Raises:
        ValueError: If the task type is not supported, if
            :obj:`entity_table_name` is a sequence that does not hold one or
            two names, or if any of the given columns fails validation (see
            the respective column setters).
    """
    def __init__(
        self,
        task_type: TaskType,
        context_df: pd.DataFrame,
        pred_df: pd.DataFrame,
        entity_table_name: str | Sequence[str],
        entity_column: str,
        target_column: str,
        time_column: str | None = None,
    ) -> None:

        task_type = TaskType(task_type)
        if task_type not in {  # Currently supported task types:
                TaskType.BINARY_CLASSIFICATION,
                TaskType.MULTICLASS_CLASSIFICATION,
                TaskType.REGRESSION,
                TaskType.TEMPORAL_LINK_PREDICTION,
        }:
            raise ValueError(f"Task type '{task_type}' is not supported")
        self._task_type = task_type

        # TODO Check dfs (unify from local table)
        # Shallow copies: later re-assignments on this object do not touch
        # the caller's data frame objects.
        self._context_df = context_df.copy(deep=False)
        self._pred_df = pred_df.copy(deep=False)

        # Data types are inferred from the context examples only:
        self._dtype_dict: dict[str, Dtype] = {
            column_name: infer_dtype(context_df[column_name])
            for column_name in context_df.columns
        }

        # Normalize the entity table name(s) into a tuple of length 1 or 2:
        self._entity_table_names: tuple[str] | tuple[str, str]
        if isinstance(entity_table_name, str):
            self._entity_table_names = (entity_table_name, )
        elif len(entity_table_name) == 1:
            self._entity_table_names = (entity_table_name[0], )
        elif len(entity_table_name) == 2:
            self._entity_table_names = (
                entity_table_name[0],
                entity_table_name[1],
            )
        else:
            raise ValueError(
                f"'entity_table_name' needs to be a single table name or a "
                f"sequence of one or two table names (got "
                f"{len(entity_table_name)} table names)")

        # Initialize with placeholders first since each column setter
        # cross-checks the new name against the other two columns:
        self._entity_column: str = ''
        self._target_column: str = ''
        self._time_column: str | None = None

        self.entity_column = entity_column
        self.target_column = target_column
        if time_column is not None:
            self.time_column = time_column

    @property
    def task_type(self) -> TaskType:
        r"""The task type."""
        return self._task_type

    def narrow(self, start: int, length: int) -> Self:
        r"""Returns a new :class:`TaskTable` that holds a narrowed version of
        prediction examples.

        The returned object is a shallow copy of this task table; only its
        prediction data frame is replaced by the selected (re-indexed) rows.

        Args:
            start: Index of the prediction examples to start narrowing.
            length: Length of the prediction examples.
        """
        out = copy.copy(self)
        df = out._pred_df.iloc[start:start + length].reset_index(drop=True)
        out._pred_df = df
        return out

    # Entity column ###########################################################

    @property
    def entity_table_name(self) -> str:
        r"""The name of the (first) entity table."""
        return self._entity_table_names[0]

    @property
    def entity_table_names(self) -> tuple[str] | tuple[str, str]:
        r"""All entity table names, *i.e.* a tuple of length one or two."""
        return self._entity_table_names

    @property
    def entity_column(self) -> Column:
        r"""The entity column of this task.

        The getter returns the entity column as an ID :class:`Column`.

        The setter sets a column as the entity column, and raises a
        :class:`ValueError` if the column is already used as target or time
        column, if it is missing from the context or prediction data frame,
        or if its data type is not supported for IDs.
        """
        return Column(
            name=self._entity_column,
            expr=None,
            dtype=self._dtype_dict[self._entity_column],
            stype=Stype.ID,
        )

    @entity_column.setter
    def entity_column(self, name: str) -> None:
        if name in {self._target_column, self._time_column}:
            raise ValueError(
                f"Column '{name}' is already used as the target or time "
                f"column and cannot be used as the entity column")
        if name not in self._context_df:
            raise ValueError(
                f"Entity column '{name}' does not exist in the context "
                f"data frame")
        if name not in self._pred_df:
            raise ValueError(
                f"Entity column '{name}' does not exist in the prediction "
                f"data frame")
        dtype = self._dtype_dict[name]
        if not Stype.ID.supports_dtype(dtype):
            raise ValueError(
                f"Entity column '{name}' has data type '{dtype}' which is "
                f"not supported for ID columns")

        self._entity_column = name

    # Target column ###########################################################

    @property
    def _target_stype(self) -> Stype:
        # Maps the task type to the semantic type of its target column.
        if self.task_type in {
                TaskType.BINARY_CLASSIFICATION,
                TaskType.MULTICLASS_CLASSIFICATION,
        }:
            return Stype.categorical
        if self.task_type in {TaskType.REGRESSION}:
            return Stype.numerical
        if self.task_type.is_link_pred:
            return Stype.multicategorical
        raise ValueError(
            f"Cannot determine the target semantic type for task type "
            f"'{self.task_type}'")

    @property
    def target_column(self) -> Column:
        r"""The target column of this task.

        The getter returns the target column as a :class:`Column` whose
        semantic type is derived from the task type.

        The setter sets a column as the target column, and raises a
        :class:`ValueError` if the column is already used as entity or time
        column, if it is missing from the context data frame, or if its data
        type is not supported by the semantic type of the target.
        """
        return Column(
            name=self._target_column,
            expr=None,
            dtype=self._dtype_dict[self._target_column],
            stype=self._target_stype,
        )

    @target_column.setter
    def target_column(self, name: str) -> None:
        if name in {self._entity_column, self._time_column}:
            raise ValueError(
                f"Column '{name}' is already used as the entity or time "
                f"column and cannot be used as the target column")
        # NOTE Targets are only required for context examples:
        if name not in self._context_df:
            raise ValueError(
                f"Target column '{name}' does not exist in the context "
                f"data frame")
        stype = self._target_stype
        dtype = self._dtype_dict[name]
        if not stype.supports_dtype(dtype):
            raise ValueError(
                f"Target column '{name}' has data type '{dtype}' which is "
                f"not supported for semantic type '{stype}' of task type "
                f"'{self.task_type}'")

        self._target_column = name

    # Time column #############################################################

    def has_time_column(self) -> bool:
        r"""Returns ``True`` if this task has a time column; ``False``
        otherwise.
        """
        return self._time_column is not None

    @property
    def time_column(self) -> Column | None:
        r"""The time column of this task.

        The getter returns the time column of this task, or ``None`` if no
        such time column is present.

        The setter sets a column as a time column for this task, and raises a
        :class:`ValueError` if the time column has a non-timestamp compatible
        data type or if the column name does not match a column in the data
        frame. Setting it to ``None`` removes the time column.
        """
        if self._time_column is None:
            return None
        return Column(
            name=self._time_column,
            expr=None,
            dtype=self._dtype_dict[self._time_column],
            stype=Stype.timestamp,
        )

    @time_column.setter
    def time_column(self, name: str | None) -> None:
        if name is None:
            self._time_column = None
            return

        if name in {self._entity_column, self._target_column}:
            raise ValueError(
                f"Column '{name}' is already used as the entity or target "
                f"column and cannot be used as the time column")
        if name not in self._context_df:
            raise ValueError(
                f"Time column '{name}' does not exist in the context "
                f"data frame")
        if name not in self._pred_df:
            raise ValueError(
                f"Time column '{name}' does not exist in the prediction "
                f"data frame")
        dtype = self._dtype_dict[name]
        if not contains_timestamp(
                ser=self._context_df[name],
                column_name=name,
                dtype=dtype,
        ):
            raise ValueError(
                f"Time column '{name}' (data type '{dtype}') does not hold "
                f"timestamp-compatible data")

        self._time_column = name

    # Metadata ################################################################

    @property
    def metadata(self) -> pd.DataFrame:
        r"""Not implemented yet; always raises :class:`NotImplementedError`."""
        raise NotImplementedError

    def print_metadata(self) -> None:
        r"""Not implemented yet; always raises :class:`NotImplementedError`."""
        raise NotImplementedError

    # Python builtins #########################################################

    def __hash__(self) -> int:
        # Hashes the task definition only (task type, entity table names and
        # column names); the data frames do not contribute to the hash.
        return hash((
            self.task_type,
            self.entity_table_names,
            self._entity_column,
            self._target_column,
            self._time_column,
        ))

    def __repr__(self) -> str:
        # Link prediction tasks display all entity table names:
        if self.task_type.is_link_pred:
            entity_table_repr = f'entity_table_names={self.entity_table_names}'
        else:
            entity_table_repr = f'entity_table_name={self.entity_table_name}'
        return (f'{self.__class__.__name__}(\n'
                f' task_type={self.task_type},\n'
                f' num_context_examples={len(self._context_df)},\n'
                f' num_prediction_examples={len(self._pred_df)},\n'
                f' {entity_table_repr},\n'
                f' entity_column={self._entity_column},\n'
                f' target_column={self._target_column},\n'
                f' time_column={self._time_column},\n'
                f')')
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Literal, Mapping, Optional, Union, overload
|
|
3
|
+
|
|
4
|
+
from kumoapi.distilled_model_plan import DistilledModelPlan
|
|
5
|
+
from kumoapi.jobs import DistillationJobRequest, DistillationJobResource
|
|
6
|
+
|
|
7
|
+
from kumoai import global_state
|
|
8
|
+
from kumoai.client.jobs import TrainingJobID
|
|
9
|
+
from kumoai.graph import Graph
|
|
10
|
+
from kumoai.pquery.training_table import TrainingTable, TrainingTableJob
|
|
11
|
+
from kumoai.trainer.job import TrainingJob, TrainingJobResult
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class DistillationTrainer:
    r"""A trainer that supports creating a Kumo machine learning model
    for use in an online serving endpoint. The distillation process involves
    training a shallow model on a :class:`~kumoai.pquery.PredictiveQuery` using
    the embeddings generated by a base model :obj:`base_training_job_id`.

    Args:
        model_plan: The distilled model plan to use for the distillation process.
        base_training_job_id: The ID of the base training job to use for the distillation process.
    """  # noqa: E501

    def __init__(
        self,
        model_plan: DistilledModelPlan,
        base_training_job_id: TrainingJobID,
    ) -> None:
        self.model_plan: DistilledModelPlan = model_plan
        self.base_training_job_id: TrainingJobID = base_training_job_id

        # Cached from backend:
        self._training_job_id: Optional[TrainingJobID] = None

    # Metadata ################################################################

    @property
    def is_trained(self) -> bool:
        r"""Returns ``True`` if this trainer instance has successfully been
        trained (and is therefore ready for prediction); ``False`` otherwise.

        Raises:
            NotImplementedError: Always, since this check is not implemented
                yet.
        """
        raise NotImplementedError(
            "Checking if a distilled trainer is trained is not "
            "implemented yet.")

    # NOTE Every overload accepts `custom_tags` so that type checkers allow
    # passing tags regardless of how `non_blocking` is specified.
    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        custom_tags: Optional[Mapping[str, str]] = ...,
    ) -> TrainingJobResult:
        pass

    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: Literal[False],
        custom_tags: Optional[Mapping[str, str]] = ...,
    ) -> TrainingJobResult:
        pass

    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: Literal[True],
        custom_tags: Optional[Mapping[str, str]] = ...,
    ) -> TrainingJob:
        pass

    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: bool,
        custom_tags: Optional[Mapping[str, str]] = ...,
    ) -> Union[TrainingJob, TrainingJobResult]:
        pass

    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: bool = False,
        custom_tags: Optional[Mapping[str, str]] = None,
    ) -> Union[TrainingJob, TrainingJobResult]:
        r"""Fits a model to the specified graph and training table, with the
        strategy defined by :class:`DistillationTrainer`'s :obj:`model_plan`.

        Args:
            graph: The :class:`~kumoai.graph.Graph` object that represents the
                tables and relationships that Kumo will learn from.
            train_table: The :class:`~kumoai.pquery.TrainingTable`, or
                in-progress :class:`~kumoai.pquery.TrainingTableJob`, that
                represents the training data produced by a
                :class:`~kumoai.pquery.PredictiveQuery` on :obj:`graph`.
            non_blocking: Whether this operation should return immediately
                after launching the training job, or await completion of the
                training job.
            custom_tags: Additional, customer defined k-v tags to be associated
                with the job to be launched. Job tags are useful for grouping
                and searching jobs. If ``None`` (the default), no tags are
                attached.

        Returns:
            Union[TrainingJobResult, TrainingJob]:
                If ``non_blocking=False``, returns the training job result
                obtained by attaching to the launched job. If
                ``non_blocking=True``, returns a training job future object.
        """
        # Copy the tags so that the request never aliases the caller's
        # mapping; `None` translates to "no tags":
        tags = {} if custom_tags is None else dict(custom_tags)

        # TODO(manan, siyang): remove soon:
        job_id = train_table.job_id
        assert job_id is not None

        train_table_job_api = global_state.client.generate_train_table_job_api
        pq_id = train_table_job_api.get(job_id).config.pquery_id
        assert pq_id is not None

        # Only a materialized training table can carry a custom table:
        custom_table = None
        if isinstance(train_table, TrainingTable):
            custom_table = train_table._custom_train_table

        # NOTE the backend implementation currently handles sequentialization
        # between a training table future and a training job; that is, if the
        # training table future is still executing, the backend will wait on
        # the job ID completion before executing a training job. This preserves
        # semantics for both futures, ensures that Kumo works as expected if
        # used only via REST API, and allows us to avoid chaining callbacks
        # in an ugly way here:
        api = global_state.client.distillation_job_api
        self._training_job_id = api.create(
            DistillationJobRequest(
                tags,
                pquery_id=pq_id,
                base_training_job_id=self.base_training_job_id,
                distilled_model_plan=self.model_plan,
                graph_snapshot_id=graph.snapshot(non_blocking=non_blocking),
                train_table_job_id=job_id,
                custom_train_table=custom_table,
            ))

        out = TrainingJob(job_id=self._training_job_id)
        if non_blocking:
            return out
        return out.attach()

    @classmethod
    def _load_from_job(
        cls,
        job: DistillationJobResource,
    ) -> 'DistillationTrainer':
        # Re-creates a trainer from the configuration of an existing job and
        # remembers the job's ID as the cached training job ID.
        trainer = cls(job.config.distilled_model_plan,
                      job.config.base_training_job_id)
        trainer._training_job_id = job.job_id
        return trainer

    @classmethod
    def load(cls, job_id: TrainingJobID) -> 'DistillationTrainer':
        r"""Creates a :class:`DistillationTrainer` instance from a training
        job ID.

        Raises:
            NotImplementedError: Always, since loading is not implemented
                yet.
        """
        raise NotImplementedError(
            "Loading a distilled trainer from a job ID is not implemented yet."
        )

    @classmethod
    def load_from_tags(cls, tags: Mapping[str, str]) -> 'DistillationTrainer':
        r"""Creates a :class:`DistillationTrainer` instance from job tags.

        Raises:
            NotImplementedError: Always, since loading is not implemented
                yet.
        """
        raise NotImplementedError(
            "Loading a distilled trainer from tags is not implemented yet.")
|
kumoai/utils/display.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
|
|
3
|
+
import pandas as pd
|
|
4
|
+
|
|
5
|
+
from kumoai import in_notebook, in_snowflake_notebook
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def message(msg: str) -> None:
    r"""Displays :obj:`msg` via the output channel of the current environment.

    Outside of notebooks, backticks are replaced by single quotes before
    printing. In Snowflake notebooks, the message is rendered via
    :obj:`streamlit`; in other notebooks, it is rendered as Markdown via
    :obj:`IPython`; otherwise it is printed.
    """
    if not in_notebook():
        # No Markdown rendering available, so backticks would show up as-is:
        msg = msg.replace("`", "'")

    if in_snowflake_notebook():
        import streamlit as st
        st.markdown(msg)
        return

    if in_notebook():
        from IPython.display import Markdown, display
        display(Markdown(msg))
        return

    print(msg)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def title(msg: str) -> None:
    r"""Displays :obj:`msg` as a title, *i.e.* as a Markdown heading in
    notebooks and as a line followed by a colon otherwise.
    """
    if in_notebook():
        text = f"### {msg}"
    else:
        text = f"{msg}:"
    message(text)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def italic(msg: str) -> None:
    r"""Displays :obj:`msg` in italics in notebooks, and unchanged
    otherwise.
    """
    if in_notebook():
        text = f"*{msg}*"
    else:
        text = msg
    message(text)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def unordered_list(items: Sequence[str]) -> None:
    r"""Displays :obj:`items` as an unordered list with one item per line.

    In notebooks, items are emitted as Markdown bullets. Otherwise, a bullet
    character is used and backticks are stripped from each item.
    """
    if in_notebook():
        bullets = (f"- {entry}" for entry in items)
    else:
        bullets = (f"• {entry.replace('`', '')}" for entry in items)
    message('\n'.join(bullets))
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def dataframe(df: pd.DataFrame) -> None:
    r"""Displays :obj:`df` without its index.

    In Snowflake notebooks, the data frame is rendered via :obj:`streamlit`;
    in other notebooks, it is rendered as a styled table via :obj:`IPython`;
    otherwise (or if styling raises an :class:`ImportError`) it is printed as
    plain text.
    """
    if in_snowflake_notebook():
        import streamlit as st
        st.dataframe(df, hide_index=True)
        return

    if not in_notebook():
        print(df.to_string(index=False))
        return

    from IPython.display import display
    try:
        styler = df.style
        if hasattr(styler, 'hide'):
            styled = styler.hide(axis='index')  # pandas=2
        else:
            styled = styler.hide_index()  # pandas<1.3
        display(styled)
    except ImportError:
        print(df.to_string(index=False))  # missing jinja2
|
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
kumoai/_logging.py,sha256=U2_5ROdyk92P4xO4H2WJV8EC7dr6YxmmnM-b7QX9M7I,886
|
|
2
2
|
kumoai/mixin.py,sha256=MP413xzuCqWhxAPUHmloLA3j4ZyF1tEtfi516b_hOXQ,812
|
|
3
|
-
kumoai/_version.py,sha256=
|
|
3
|
+
kumoai/_version.py,sha256=ZgEOgzhMS-JHEcx_StKbib3QyH7C44Vgs2zNQ7IM43A,39
|
|
4
4
|
kumoai/kumolib.cpython-312-darwin.so,sha256=xQvdWHx9xmQ11y3F3ywxJv6A0sDk6D3-2fQbxSdM1z4,232576
|
|
5
|
-
kumoai/__init__.py,sha256=
|
|
5
|
+
kumoai/__init__.py,sha256=x6Emn6VesHQz0wR7ZnbddPRYO9A5-0JTHDkzJ3Ocq6w,10907
|
|
6
6
|
kumoai/formatting.py,sha256=jA_rLDCGKZI8WWCha-vtuLenVKTZvli99Tqpurz1H84,953
|
|
7
7
|
kumoai/futures.py,sha256=oJFIfdCM_3nWIqQteBKYMY4fPhoYlYWE_JA2o6tx-ng,3737
|
|
8
8
|
kumoai/jobs.py,sha256=NrdLEFNo7oeCYSy-kj2nAvCFrz9BZ_xrhkqHFHk5ksY,2496
|
|
@@ -11,41 +11,43 @@ kumoai/databricks.py,sha256=e6E4lOFvZHXFwh4CO1kXU1zzDU3AapLQYMxjiHPC-HQ,476
|
|
|
11
11
|
kumoai/spcs.py,sha256=N31d7rLa-bgYh8e2J4YzX1ScxGLqiVXrqJnCl1y4Mts,4139
|
|
12
12
|
kumoai/_singleton.py,sha256=UTwrbDkoZSGB8ZelorvprPDDv9uZkUi1q_SrmsyngpQ,836
|
|
13
13
|
kumoai/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
14
|
-
kumoai/experimental/rfm/
|
|
15
|
-
kumoai/experimental/rfm/
|
|
14
|
+
kumoai/experimental/rfm/relbench.py,sha256=cVsxxV3TIL3PLEoYb-8tAVW3GSef6NQAd3rxdHJL63I,2276
|
|
15
|
+
kumoai/experimental/rfm/graph.py,sha256=H9lIQLDkL5zJMwEHh7PgruvMUxWsjpynXUT7gnmTTUM,46351
|
|
16
|
+
kumoai/experimental/rfm/__init__.py,sha256=bW2XyYtkbdiu_iICYFF2Fu1Fx5fyGbqne6m_6c1P-fY,7016
|
|
16
17
|
kumoai/experimental/rfm/sagemaker.py,sha256=6fyXO1Jd_scq-DH7kcv6JcV8QPyTbh4ceqwQDPADlZ0,4963
|
|
17
|
-
kumoai/experimental/rfm/rfm.py,sha256=
|
|
18
|
+
kumoai/experimental/rfm/rfm.py,sha256=M5Q_vasSVrAX-5ucFPrdWoTZdpE4VYxKjKZz7BccITc,49672
|
|
18
19
|
kumoai/experimental/rfm/authenticate.py,sha256=G2RkRWznMVQUzvhvbKhn0bMCY7VmoNYxluz3THRqSdE,18851
|
|
20
|
+
kumoai/experimental/rfm/task_table.py,sha256=SPwkEdKRTwHHFcLqbvC1cDkyLXN2-3DpY5ujAyHRE-Q,8377
|
|
19
21
|
kumoai/experimental/rfm/backend/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
20
22
|
kumoai/experimental/rfm/backend/sqlite/__init__.py,sha256=jl-DBbhsqQ-dUXyWhyQTM1AU2qNAtXCmi1mokdhtBTg,902
|
|
21
|
-
kumoai/experimental/rfm/backend/sqlite/table.py,sha256=
|
|
22
|
-
kumoai/experimental/rfm/backend/sqlite/sampler.py,sha256=
|
|
23
|
+
kumoai/experimental/rfm/backend/sqlite/table.py,sha256=WqYtd_rwlawItRMXZUfv14qdyU6huQmODuFjDo483dI,6683
|
|
24
|
+
kumoai/experimental/rfm/backend/sqlite/sampler.py,sha256=_D9C5mj3oL4J2qZFap3emvTy2jxzth3dEWZPfr4dmEE,16201
|
|
23
25
|
kumoai/experimental/rfm/backend/local/__init__.py,sha256=2s9sSA-E-8pfkkzCH4XPuaSxSznEURMfMgwEIfYYPsg,1014
|
|
24
|
-
kumoai/experimental/rfm/backend/local/table.py,sha256=
|
|
25
|
-
kumoai/experimental/rfm/backend/local/graph_store.py,sha256=
|
|
26
|
-
kumoai/experimental/rfm/backend/local/sampler.py,sha256=
|
|
26
|
+
kumoai/experimental/rfm/backend/local/table.py,sha256=GKeYGcu52ztCU8EBMqp5UVj85E145Ug41xiCPiTCXq4,3489
|
|
27
|
+
kumoai/experimental/rfm/backend/local/graph_store.py,sha256=RHhkI13KpdPxqb4vXkwEwuFiX5DkrEsfZsOLywNnrvU,11294
|
|
28
|
+
kumoai/experimental/rfm/backend/local/sampler.py,sha256=UKxTjsYs00sYuV_LAlDuZOvQq0BZzPCzZK1Fki2Fd70,10726
|
|
27
29
|
kumoai/experimental/rfm/backend/snow/__init__.py,sha256=BYfsiuJ4Ee30GjG9EuUtitMHXnRfvVKi85zNlIwldV4,993
|
|
28
|
-
kumoai/experimental/rfm/backend/snow/table.py,sha256
|
|
29
|
-
kumoai/experimental/rfm/backend/snow/sampler.py,sha256=
|
|
30
|
+
kumoai/experimental/rfm/backend/snow/table.py,sha256=9N7TOcXX8hhAjCawnhuvQCArBFTCdng3gBakunUxg90,8892
|
|
31
|
+
kumoai/experimental/rfm/backend/snow/sampler.py,sha256=zvPsgVnDfvskcnPWsIcqxw-Fn9DsCLfdoLE-m3bjeww,11483
|
|
30
32
|
kumoai/experimental/rfm/pquery/__init__.py,sha256=X0O3EIq5SMfBEE-ii5Cq6iDhR3s3XMXB52Cx5htoePw,152
|
|
31
33
|
kumoai/experimental/rfm/pquery/pandas_executor.py,sha256=MwSvFRwLq-z19LEdF0G0AT7Gj9tCqu-XLEA7mNbqXwc,18454
|
|
32
34
|
kumoai/experimental/rfm/pquery/executor.py,sha256=gs5AVNaA50ci8zXOBD3qt5szdTReSwTs4BGuEyx4BEE,2728
|
|
33
|
-
kumoai/experimental/rfm/infer/multicategorical.py,sha256=
|
|
35
|
+
kumoai/experimental/rfm/infer/multicategorical.py,sha256=lNO_8aJw1whO6QVEMB3PRWMNlEEiX44g3v4tP88TSQY,1119
|
|
34
36
|
kumoai/experimental/rfm/infer/categorical.py,sha256=VwNaKwKbRYkTxEJ1R6gziffC8dGsEThcDEfbi-KqW5c,853
|
|
35
37
|
kumoai/experimental/rfm/infer/time_col.py,sha256=oNenUK6P7ql8uwShodtQ73uG1x3fbFWT78jRcF9DLTI,1789
|
|
36
38
|
kumoai/experimental/rfm/infer/pkey.py,sha256=IaJI5GHK8ds_a3AOr3YYVgUlSmYYEgr4Nu92s2RyBV4,4412
|
|
37
39
|
kumoai/experimental/rfm/infer/id.py,sha256=ZIO0DWIoiEoS_8MVc5lkqBfkTWWQ0yGCgjkwLdaYa_Q,908
|
|
38
|
-
kumoai/experimental/rfm/infer/dtype.py,sha256=
|
|
39
|
-
kumoai/experimental/rfm/infer/__init__.py,sha256=
|
|
40
|
+
kumoai/experimental/rfm/infer/dtype.py,sha256=FyAqvtrOWQC9hGrhQ7sC4BAI6c9k6ew-fo8ClS1sewM,2782
|
|
41
|
+
kumoai/experimental/rfm/infer/__init__.py,sha256=8GDxQKd0pxZULdk7mpwl3CsOpL4v2HPuPEsbi2t_vzc,519
|
|
40
42
|
kumoai/experimental/rfm/infer/timestamp.py,sha256=vM9--7eStzaGG13Y-oLYlpNJyhL6f9dp17HDXwtl_DM,1094
|
|
41
|
-
kumoai/experimental/rfm/
|
|
42
|
-
kumoai/experimental/rfm/base/
|
|
43
|
-
kumoai/experimental/rfm/base/
|
|
44
|
-
kumoai/experimental/rfm/base/table.py,sha256=
|
|
45
|
-
kumoai/experimental/rfm/base/
|
|
46
|
-
kumoai/experimental/rfm/base/
|
|
47
|
-
kumoai/experimental/rfm/base/source.py,sha256=
|
|
48
|
-
kumoai/experimental/rfm/base/column.py,sha256=
|
|
43
|
+
kumoai/experimental/rfm/infer/stype.py,sha256=fu4zsOB-C7jNeMnq6dsK4bOZSewe7PtZe_AkohSRLoM,894
|
|
44
|
+
kumoai/experimental/rfm/base/sql_sampler.py,sha256=qurkEVlMhDZw3d9SM2uGud6TMv_Wx_iqWoCgEKd_g9o,5094
|
|
45
|
+
kumoai/experimental/rfm/base/__init__.py,sha256=rjmMux5lG8srw1bjQGcFQFv6zET9e5riP81nPkw28Jg,724
|
|
46
|
+
kumoai/experimental/rfm/base/table.py,sha256=6qZeTMfnQejrn6TwqQeJGzJG7C0dSjJ7-NMLX38dvns,26563
|
|
47
|
+
kumoai/experimental/rfm/base/sampler.py,sha256=tXYnVEyKC5NjSIpe8pNYp0V3Qbg-KbUE_QB0Emy2YiQ,30882
|
|
48
|
+
kumoai/experimental/rfm/base/expression.py,sha256=Y7NtLTnKlx6euG_N3fLTcrFKheB6P5KS_jhCfoXV9DE,1252
|
|
49
|
+
kumoai/experimental/rfm/base/source.py,sha256=bwu3GU2TvIXR2fwKAmJ1-5BDoNXMnI1SU3Fgdk8lWnc,301
|
|
50
|
+
kumoai/experimental/rfm/base/column.py,sha256=GXzLC-VpShr6PecUzaj1MJKc_PHzfW5Jn9bOYPA8fFA,4965
|
|
49
51
|
kumoai/encoder/__init__.py,sha256=VPGs4miBC_WfwWeOXeHhFomOUocERFavhKf5fqITcds,182
|
|
50
52
|
kumoai/graph/graph.py,sha256=iyp4klPIMn2ttuEqMJvsrxKb_tmz_DTnvziIhCegduM,38291
|
|
51
53
|
kumoai/graph/__init__.py,sha256=n8X4X8luox4hPBHTRC9R-3JzvYYMoR8n7lF1H4w4Hzc,228
|
|
@@ -56,6 +58,7 @@ kumoai/artifact_export/job.py,sha256=GEisSwvcjK_35RgOfsLXGgxMTXIWm765B_BW_Kgs-V0
|
|
|
56
58
|
kumoai/artifact_export/__init__.py,sha256=BsfDrc3mCHpO9-BqvqKm8qrXDIwfdaoH5UIoG4eQkc4,238
|
|
57
59
|
kumoai/utils/datasets.py,sha256=ptKIUoBONVD55pTVNdRCkQT3NWdN_r9UAUu4xewPa3U,2928
|
|
58
60
|
kumoai/utils/__init__.py,sha256=6S-UtwjeLpnCYRCCIEWhkitPYGaqOGXC1ChE13DzXiU,256
|
|
61
|
+
kumoai/utils/display.py,sha256=eXlw4B72y6zEruWYOfwvfqxfMBTL9AsPtWfw3BjaWqQ,1397
|
|
59
62
|
kumoai/utils/progress_logger.py,sha256=3aYOoVSbQv5i9m2T8IqMydofKf6iNB1jxsl1uGjHZz8,9265
|
|
60
63
|
kumoai/utils/sql.py,sha256=f6lR6rBEW7Dtk0NdM26dOZXUHDizEHb1WPlBCJrwoq0,118
|
|
61
64
|
kumoai/utils/forecasting.py,sha256=-nDS6ucKNfQhTQOfebjefj0wwWH3-KYNslIomxwwMBM,7415
|
|
@@ -85,7 +88,7 @@ kumoai/connector/bigquery_connector.py,sha256=IkyRqvF8Cg96kApUuuz86eYnl-BqBmDX1f
|
|
|
85
88
|
kumoai/connector/source_table.py,sha256=QLT8bEYaxeMwy-b168url0VfnkTrs5K6VKLbxTI4hEY,17539
|
|
86
89
|
kumoai/connector/__init__.py,sha256=9g6oNJ0qHWFlL5enTSoK4_SSH_5hP74xUDZx-9SggC4,842
|
|
87
90
|
kumoai/connector/file_upload_connector.py,sha256=swp03HgChOvmNPJetuujBSAqADe7NRmS_T0F3o9it4w,7008
|
|
88
|
-
kumoai/connector/utils.py,sha256=
|
|
91
|
+
kumoai/connector/utils.py,sha256=sD3_Dmf42FobMfVayzMVkDHIfXzPN-htD3RHd6Kw8hQ,65055
|
|
89
92
|
kumoai/connector/s3_connector.py,sha256=3kbv-h7DwD8O260Q0h1GPm5wwQpLt-Tb3d_CBSaie44,10155
|
|
90
93
|
kumoai/connector/base.py,sha256=cujXSZF3zAfuxNuEw54DSL1T7XCuR4t0shSMDuPUagQ,5291
|
|
91
94
|
kumoai/pquery/__init__.py,sha256=uTXr7t1eXcVfM-ETaM_1ImfEqhrmaj8BjiIvy1YZTL8,533
|
|
@@ -93,12 +96,12 @@ kumoai/pquery/predictive_query.py,sha256=UXn1s8ztubYZMNGl4ijaeidMiGlFveb1TGw9qI5
|
|
|
93
96
|
kumoai/pquery/prediction_table.py,sha256=QPDH22X1UB0NIufY7qGuV2XW7brG3Pv--FbjNezzM2g,10776
|
|
94
97
|
kumoai/pquery/training_table.py,sha256=elmPDZx11kPiC_dkOhJcBUGtHKgL32GCBvZ9k6U0pMg,15809
|
|
95
98
|
kumoai/client/pquery.py,sha256=IQ8As-OOJOkuMoMosphOsA5hxQYLCbzOQJO7RezK8uY,7091
|
|
96
|
-
kumoai/client/client.py,sha256=
|
|
99
|
+
kumoai/client/client.py,sha256=npTLooBtmZ9xOo7AbEiYQTh9wFktsGSEpSEfdB7vdB4,8715
|
|
97
100
|
kumoai/client/graph.py,sha256=zvLEDExLT_RVbUMHqVl0m6tO6s2gXmYSoWmPF6YMlnA,3831
|
|
98
101
|
kumoai/client/online.py,sha256=pkBBh_DEC3GAnPcNw6bopNRlGe7EUbIFe7_seQqZRaw,2720
|
|
99
102
|
kumoai/client/source_table.py,sha256=VCsCcM7KYcnjGP7HLTb-AOSEGEVsJTWjk8bMg1JdgPU,2101
|
|
100
103
|
kumoai/client/__init__.py,sha256=MkyOuMaHQ2c8GPxjBDQSVFhfRE2d2_6CXQ6rxj4ps4w,64
|
|
101
|
-
kumoai/client/jobs.py,sha256=
|
|
104
|
+
kumoai/client/jobs.py,sha256=z3By5MWvWdJ_wYFyJA34pD4NueOXvXEqrAANWEpp4Pk,18066
|
|
102
105
|
kumoai/client/utils.py,sha256=lz1NubwMDHCwzQRowRXm7mjAoYRd5UjRQIwXdtWAl90,3849
|
|
103
106
|
kumoai/client/connector.py,sha256=x3i2aBTJTEMZvYRcWkY-UfWVOANZjqAso4GBbcshFjw,3920
|
|
104
107
|
kumoai/client/table.py,sha256=cQG-RPm-e91idEgse1IPJDvBmzddIDGDkuyrR1rq4wU,3235
|
|
@@ -110,9 +113,10 @@ kumoai/trainer/job.py,sha256=Wk69nzFhbvuA3nEvtCstI04z5CxkgvQ6tHnGchE0Lkg,44938
|
|
|
110
113
|
kumoai/trainer/baseline_trainer.py,sha256=LlfViNOmswNv4c6zJJLsyv0pC2mM2WKMGYx06ogtEVc,4024
|
|
111
114
|
kumoai/trainer/__init__.py,sha256=zUdFl-f-sBWmm2x8R-rdVzPBeU2FaMzUY5mkcgoTa1k,939
|
|
112
115
|
kumoai/trainer/online_serving.py,sha256=9cddb5paeZaCgbUeceQdAOxysCtV5XP-KcsgFz_XR5w,9566
|
|
116
|
+
kumoai/trainer/distilled_trainer.py,sha256=2pPs5clakNxkLfaak7uqPJOrpTWe1RVVM7ztDSqQZvU,6484
|
|
113
117
|
kumoai/trainer/trainer.py,sha256=hBXO7gwpo3t59zKFTeIkK65B8QRmWCwO33sbDuEAPlY,20133
|
|
114
|
-
kumoai-2.14.0.
|
|
115
|
-
kumoai-2.14.0.
|
|
116
|
-
kumoai-2.14.0.
|
|
117
|
-
kumoai-2.14.0.
|
|
118
|
-
kumoai-2.14.0.
|
|
118
|
+
kumoai-2.14.0.dev202601041732.dist-info/RECORD,,
|
|
119
|
+
kumoai-2.14.0.dev202601041732.dist-info/WHEEL,sha256=V1loQ6TpxABu1APUg0MoTRBOzSKT5xVc3skizX-ovCU,136
|
|
120
|
+
kumoai-2.14.0.dev202601041732.dist-info/top_level.txt,sha256=YjU6UcmomoDx30vEXLsOU784ED7VztQOsFApk1SFwvs,7
|
|
121
|
+
kumoai-2.14.0.dev202601041732.dist-info/METADATA,sha256=dmLN7vMtkp6iM92XWX0BtKZCF8yU3RmoAKhJ-gwc4ME,2557
|
|
122
|
+
kumoai-2.14.0.dev202601041732.dist-info/licenses/LICENSE,sha256=TbWlyqRmhq9PEzCaTI0H0nWLQCCOywQM8wYH8MbjfLo,1102
|
|
@@ -1,16 +0,0 @@
|
|
|
1
|
-
from dataclasses import dataclass
|
|
2
|
-
from typing import Any, TypeAlias
|
|
3
|
-
|
|
4
|
-
from kumoapi.typing import Dtype
|
|
5
|
-
|
|
6
|
-
from kumoai.mixin import CastMixin
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
# Frozen (immutable) dataclass describing a single column expression.
# NOTE(review): every line of this definition sits on the removed ("-") side
# of the diff; the summary lists `base/column_expression.py +0 -16`.
@dataclass(frozen=True)
|
|
10
|
-
class ColumnExpressionSpec(CastMixin):
|
|
11
|
-
# Required name associated with the expression.
name: str
|
|
12
|
-
# Required expression string (presumably SQL; TODO confirm against callers).
expr: str
|
|
13
|
-
# Optional data type; defaults to `None` when not specified.
dtype: Dtype | None = None
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
ColumnExpressionType: TypeAlias = ColumnExpressionSpec | dict[str, Any]
|