kumoai 2.12.0.dev202511111731__cp311-cp311-macosx_11_0_arm64.whl → 2.13.0.dev202512091732__cp311-cp311-macosx_11_0_arm64.whl
This diff shows the changes between two publicly released versions of this package, as published to their respective public registries. It is provided for informational purposes only.
- kumoai/__init__.py +18 -9
- kumoai/_version.py +1 -1
- kumoai/client/client.py +9 -13
- kumoai/connector/utils.py +23 -2
- kumoai/experimental/rfm/__init__.py +162 -46
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +20 -30
- kumoai/experimental/rfm/backend/local/sampler.py +242 -0
- kumoai/experimental/rfm/backend/local/table.py +109 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +35 -0
- kumoai/experimental/rfm/backend/snow/table.py +117 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +30 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +101 -0
- kumoai/experimental/rfm/base/__init__.py +14 -0
- kumoai/experimental/rfm/base/column.py +66 -0
- kumoai/experimental/rfm/base/sampler.py +374 -0
- kumoai/experimental/rfm/base/source.py +18 -0
- kumoai/experimental/rfm/{local_table.py → base/table.py} +139 -139
- kumoai/experimental/rfm/{local_graph.py → graph.py} +334 -79
- kumoai/experimental/rfm/infer/__init__.py +6 -0
- kumoai/experimental/rfm/infer/dtype.py +79 -0
- kumoai/experimental/rfm/infer/pkey.py +126 -0
- kumoai/experimental/rfm/infer/time_col.py +62 -0
- kumoai/experimental/rfm/local_graph_sampler.py +43 -4
- kumoai/experimental/rfm/local_pquery_driver.py +1 -1
- kumoai/experimental/rfm/pquery/pandas_executor.py +1 -1
- kumoai/experimental/rfm/rfm.py +28 -27
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/spcs.py +1 -3
- kumoai/testing/decorators.py +1 -1
- {kumoai-2.12.0.dev202511111731.dist-info → kumoai-2.13.0.dev202512091732.dist-info}/METADATA +12 -2
- {kumoai-2.12.0.dev202511111731.dist-info → kumoai-2.13.0.dev202512091732.dist-info}/RECORD +36 -21
- kumoai/experimental/rfm/utils.py +0 -344
- {kumoai-2.12.0.dev202511111731.dist-info → kumoai-2.13.0.dev202512091732.dist-info}/WHEEL +0 -0
- {kumoai-2.12.0.dev202511111731.dist-info → kumoai-2.13.0.dev202512091732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.12.0.dev202511111731.dist-info → kumoai-2.13.0.dev202512091732.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/base/column.py (new file)

```diff
@@ -0,0 +1,66 @@
+from dataclasses import dataclass
+from typing import Any
+
+from kumoapi.typing import Dtype, Stype
+
+
+@dataclass(init=False, repr=False, eq=False)
+class Column:
+    stype: Stype
+
+    def __init__(
+        self,
+        name: str,
+        dtype: Dtype,
+        stype: Stype,
+        is_primary_key: bool = False,
+        is_time_column: bool = False,
+        is_end_time_column: bool = False,
+    ) -> None:
+        self._name = name
+        self._dtype = Dtype(dtype)
+        self._is_primary_key = is_primary_key
+        self._is_time_column = is_time_column
+        self._is_end_time_column = is_end_time_column
+        self.stype = Stype(stype)
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    @property
+    def dtype(self) -> Dtype:
+        return self._dtype
+
+    def __setattr__(self, key: str, val: Any) -> None:
+        if key == 'stype':
+            if isinstance(val, str):
+                val = Stype(val)
+            assert isinstance(val, Stype)
+            if not val.supports_dtype(self.dtype):
+                raise ValueError(f"Column '{self.name}' received an "
+                                 f"incompatible semantic type (got "
+                                 f"dtype='{self.dtype}' and stype='{val}')")
+            if self._is_primary_key and val != Stype.ID:
+                raise ValueError(f"Primary key '{self.name}' must have 'ID' "
+                                 f"semantic type (got '{val}')")
+            if self._is_time_column and val != Stype.timestamp:
+                raise ValueError(f"Time column '{self.name}' must have "
+                                 f"'timestamp' semantic type (got '{val}')")
+            if self._is_end_time_column and val != Stype.timestamp:
+                raise ValueError(f"End time column '{self.name}' must have "
+                                 f"'timestamp' semantic type (got '{val}')")
+
+        super().__setattr__(key, val)
+
+    def __hash__(self) -> int:
+        return hash((self.name, self.stype, self.dtype))
+
+    def __eq__(self, other: Any) -> bool:
+        if not isinstance(other, Column):
+            return False
+        return hash(self) == hash(other)
+
+    def __repr__(self) -> str:
+        return (f'{self.__class__.__name__}(name={self.name}, '
+                f'stype={self.stype}, dtype={self.dtype})')
```
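Worth noting on `base/column.py`: every write to `stype` goes through `__setattr__`, so both the constructor (which assigns `self.stype` last, after the dtype and key/time flags are in place) and any later reassignment re-run the compatibility checks. A minimal usage sketch; `Stype.ID` comes straight from the diff above, while `Dtype('int')` and `'numerical'` are illustrative assumptions about the enum values in `kumoapi.typing`:

```python
from kumoapi.typing import Dtype, Stype

from kumoai.experimental.rfm.base.column import Column

# Construction validates the dtype/stype pair and the key/time flags:
user_id = Column(
    name='user_id',
    dtype=Dtype('int'),  # assumed Dtype value, for illustration only
    stype=Stype.ID,
    is_primary_key=True,
)

try:
    # Reassignment is routed through __setattr__ and re-validated; a
    # primary key may only carry the 'ID' semantic type:
    user_id.stype = 'numerical'  # assumed Stype value; strings are coerced
except ValueError as err:
    print(err)
```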
kumoai/experimental/rfm/base/sampler.py (new file)

```diff
@@ -0,0 +1,374 @@
+import copy
+import re
+from abc import ABC, abstractmethod
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Literal
+
+import numpy as np
+import pandas as pd
+from kumoapi.pquery import ValidatedPredictiveQuery
+from kumoapi.pquery.AST import Aggregation, ASTNode
+from kumoapi.rfm.context import EdgeLayout, Link, Subgraph, Table
+from kumoapi.typing import Stype
+
+if TYPE_CHECKING:
+    from kumoai.experimental.rfm import Graph
+
+
+@dataclass
+class SamplerOutput:
+    df_dict: dict[str, pd.DataFrame]
+    inverse_dict: dict[str, np.ndarray]
+    batch_dict: dict[str, np.ndarray]
+    num_sampled_nodes_dict: dict[str, list[int]]
+    row_dict: dict[tuple[str, str, str], np.ndarray]
+    col_dict: dict[tuple[str, str, str], np.ndarray]
+    num_sampled_edges_dict: dict[tuple[str, str, str], list[int]]
+
+
+@dataclass
+class TargetOutput:
+    entity_pkey: pd.Series
+    anchor_time: pd.Series
+    target: pd.Series
+    num_trials: int
+
+
+class Sampler(ABC):
+    def __init__(self, graph: 'Graph') -> None:
+        self._edge_types: list[tuple[str, str, str]] = []
+        for edge in graph.edges:
+            edge_type = (edge.src_table, edge.fkey, edge.dst_table)
+            self._edge_types.append(edge_type)
+            self._edge_types.append(Subgraph.rev_edge_type(edge_type))
+
+        self._primary_key_dict: dict[str, str] = {
+            table.name: table._primary_key
+            for table in graph.tables.values()
+            if table._primary_key is not None
+        }
+
+        self._time_column_dict: dict[str, str] = {
+            table.name: table._time_column
+            for table in graph.tables.values()
+            if table._time_column is not None
+        }
+
+        self._end_time_column_dict: dict[str, str] = {
+            table.name: table._end_time_column
+            for table in graph.tables.values()
+            if table._end_time_column is not None
+        }
+
+        foreign_keys = {(edge.src_table, edge.fkey) for edge in graph.edges}
+        self._table_stype_dict: dict[str, dict[str, Stype]] = {}
+        for table in graph.tables.values():
+            self._table_stype_dict[table.name] = {}
+            for column in table.columns:
+                if column == table.primary_key:
+                    continue
+                if (table.name, column.name) in foreign_keys:
+                    continue
+                self._table_stype_dict[table.name][column.name] = column.stype
+
+    @property
+    def edge_types(self) -> list[tuple[str, str, str]]:
+        return self._edge_types
+
+    @property
+    def primary_key_dict(self) -> dict[str, str]:
+        return self._primary_key_dict
+
+    @property
+    def time_column_dict(self) -> dict[str, str]:
+        return self._time_column_dict
+
+    @property
+    def end_time_column_dict(self) -> dict[str, str]:
+        return self._end_time_column_dict
+
+    @property
+    def table_stype_dict(self) -> dict[str, dict[str, Stype]]:
+        return self._table_stype_dict
+
+    def sample_subgraph(
+        self,
+        entity_table_names: tuple[str, ...],
+        entity_pkey: pd.Series,
+        anchor_time: pd.Series,
+        num_neighbors: list[int],
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+    ) -> Subgraph:
+
+        # Exclude all columns that leak target information:
+        table_stype_dict: dict[str, dict[str, Stype]] = self._table_stype_dict
+        if exclude_cols_dict is not None:
+            table_stype_dict = copy.deepcopy(table_stype_dict)
+            for table_name, exclude_cols in exclude_cols_dict.items():
+                for column_name in exclude_cols:
+                    del table_stype_dict[table_name][column_name]
+
+        # Collect all columns being used as features:
+        columns_dict: dict[str, set[str]] = {
+            table_name: set(stype_dict.keys())
+            for table_name, stype_dict in table_stype_dict.items()
+        }
+        # Make sure to store primary key information for entity tables:
+        for table_name in entity_table_names:
+            columns_dict[table_name].add(self.primary_key_dict[table_name])
+
+        if anchor_time.dtype != 'datetime64[ns]':
+            anchor_time = anchor_time.astype('datetime64[ns]')
+
+        out = self._sample_subgraph(
+            entity_table_name=entity_table_names[0],
+            entity_pkey=entity_pkey,
+            anchor_time=anchor_time,
+            columns_dict=columns_dict,
+            num_neighbors=num_neighbors,
+        )
+
+        subgraph = Subgraph(
+            anchor_time=anchor_time.astype(int).to_numpy(),
+            table_dict={},
+            link_dict={},
+        )
+
+        for table_name, batch in out.batch_dict.items():
+            if len(batch) == 0:
+                continue
+
+            primary_key: str | None = None
+            if table_name in entity_table_names:
+                primary_key = self.primary_key_dict[table_name]
+
+            df = out.df_dict[table_name].reset_index(drop=True)
+            if end_time_column := self.end_time_column_dict.get(table_name):
+                # Set end time to NaT for all values greater than anchor time:
+                assert table_name not in out.inverse_dict
+                ser = df[end_time_column]
+                if ser.dtype != 'datetime64[ns]':
+                    ser = ser.astype('datetime64[ns]')
+                mask = ser > anchor_time.iloc[batch]
+                ser.iloc[mask] = pd.NaT
+                df[end_time_column] = ser
+
+            stype_dict = table_stype_dict[table_name]
+            for column_name, stype in stype_dict.items():
+                if stype == Stype.text:
+                    df[column_name] = _normalize_text(df[column_name])
+
+            subgraph.table_dict[table_name] = Table(
+                df=df,
+                row=out.inverse_dict.get(table_name),
+                batch=batch,
+                num_sampled_nodes=out.num_sampled_nodes_dict[table_name],
+                stype_dict=stype_dict,
+                primary_key=primary_key,
+            )
+
+        for edge_type in out.row_dict.keys():
+            row: np.ndarray | None = out.row_dict[edge_type]
+            col: np.ndarray | None = out.col_dict[edge_type]
+
+            if row is None or col is None or len(row) == 0:
+                continue
+
+            # Do not store reverse edge type if it is an exact replica:
+            rev_edge_type = Subgraph.rev_edge_type(edge_type)
+            if (rev_edge_type in subgraph.link_dict
+                    and np.array_equal(row, out.col_dict[rev_edge_type])
+                    and np.array_equal(col, out.row_dict[rev_edge_type])):
+                subgraph.link_dict[edge_type] = Link(
+                    layout=EdgeLayout.REV,
+                    row=None,
+                    col=None,
+                    num_sampled_edges=out.num_sampled_edges_dict[edge_type],
+                )
+                continue
+
+            # Do not store non-informative edges:
+            layout = EdgeLayout.COO
+            if np.array_equal(row, np.arange(len(row))):
+                row = None
+            if np.array_equal(col, np.arange(len(col))):
+                col = None
+
+            # Store in compressed representation if more efficient:
+            num_cols = subgraph.table_dict[edge_type[2]].num_rows
+            if col is not None and len(col) > num_cols + 1:
+                layout = EdgeLayout.CSC
+                colcount = np.bincount(col, minlength=num_cols)
+                col = np.empty(num_cols + 1, dtype=col.dtype)
+                col[0] = 0
+                np.cumsum(colcount, out=col[1:])
+
+            subgraph.link_dict[edge_type] = Link(
+                layout=layout,
+                row=row,
+                col=col,
+                num_sampled_edges=out.num_sampled_edges_dict[edge_type],
+            )
+
+        return subgraph
+
+    def sample_target(
+        self,
+        query: ValidatedPredictiveQuery,
+        num_examples: int,
+        anchor_time: pd.Timestamp | Literal['entity'],
+        random_seed: int | None = None,
+    ) -> TargetOutput:
+
+        columns_dict: dict[str, set[str]] = defaultdict(set)
+        for fqn in query.all_query_columns + [query.entity_column]:
+            table_name, column_name = fqn.split('.')
+            columns_dict[table_name].add(column_name)
+
+        if column_name := self.time_column_dict.get(query.entity_table):
+            columns_dict[table_name].add(column_name)
+        if column_name := self.end_time_column_dict.get(query.entity_table):
+            columns_dict[table_name].add(column_name)
+
+        time_offset_dict: dict[
+            tuple[str, str, str],
+            tuple[pd.DateOffset | None, pd.DateOffset],
+        ] = {}
+
+        def _add_time_offset(node: ASTNode, num_forecasts: int = 1) -> None:
+            if isinstance(node, Aggregation):
+                table_name = node._get_target_column_name().split('.')[0]
+                columns_dict[table_name].add(self.time_column_dict[table_name])
+
+                edge_types = [
+                    edge_type for edge_type in self.edge_types
+                    if edge_type[0] == table_name
+                    and edge_type[2] == query.entity_table
+                ]
+                if len(edge_types) != 1:
+                    raise ValueError(f"Could not find a unique foreign key "
+                                     f"from table '{table_name}' to "
+                                     f"'{query.entity_table}'")
+                if edge_types[0] not in time_offset_dict:
+                    start = node.aggr_time_range.start_date_offset
+                    end = node.aggr_time_range.end_date_offset * num_forecasts
+                else:
+                    start, end = time_offset_dict[edge_types[0]]
+                    start = min_date_offset(
+                        start,
+                        node.aggr_time_range.start_date_offset,
+                    )
+                    end = max_date_offset(
+                        end,
+                        node.aggr_time_range.end_date_offset * num_forecasts,
+                    )
+                time_offset_dict[edge_types[0]] = (start, end)
+
+            for child in node.children:
+                _add_time_offset(child, num_forecasts)
+
+        _add_time_offset(query.target_ast, query.num_forecasts)
+        _add_time_offset(query.entity_ast)
+        if query.whatif_ast is not None:
+            _add_time_offset(query.whatif_ast)
+
+        return self._sample_target(
+            query=query,
+            num_examples=num_examples,
+            anchor_time=anchor_time,
+            columns_dict=columns_dict,
+            time_offset_dict=time_offset_dict,
+            random_seed=random_seed,
+        )
+
+    # Abstract Methods ########################################################
+
+    @abstractmethod
+    def _sample_subgraph(
+        self,
+        entity_table_name: str,
+        entity_pkey: pd.Series,
+        anchor_time: pd.Series,
+        columns_dict: dict[str, set[str]],
+        num_neighbors: list[int],
+    ) -> SamplerOutput:
+        pass
+
+    @abstractmethod
+    def _sample_target(
+        self,
+        query: ValidatedPredictiveQuery,
+        num_examples: int,
+        anchor_time: pd.Timestamp | Literal['entity'],
+        columns_dict: dict[str, set[str]],
+        time_offset_dict: dict[
+            tuple[str, str, str],
+            tuple[pd.DateOffset | None, pd.DateOffset],
+        ],
+        random_seed: int | None = None,
+    ) -> TargetOutput:
+        pass
+
+
+# Helper Functions ############################################################
+
+PUNCTUATION = re.compile(r"[\'\"\.,\(\)\!\?\;\:]")
+MULTISPACE = re.compile(r"\s+")
+
+
+def _normalize_text(
+    ser: pd.Series,
+    max_words: int | None = 50,
+) -> pd.Series:
+    r"""Normalizes text into a list of lower-case words.
+
+    Args:
+        ser: The :class:`pandas.Series` to normalize.
+        max_words: The maximum number of words to return.
+            This will auto-shrink any large text column to avoid blowing up
+            context size.
+    """
+    if len(ser) == 0 or pd.api.types.is_list_like(ser.iloc[0]):
+        return ser
+
+    def normalize_fn(line: str) -> list[str]:
+        line = PUNCTUATION.sub(" ", line)
+        line = re.sub(r"<br\s*/?>", " ", line)  # Handle <br /> or <br>
+        line = MULTISPACE.sub(" ", line)
+        words = line.split()
+        if max_words is not None:
+            words = words[:max_words]
+        return words
+
+    ser = ser.fillna('').astype(str)
+
+    if max_words is not None:
+        # We estimate the number of words as 5 characters + 1 space in an
+        # English text on average. We need this pre-filter here, as word
+        # splitting on a giant text can be very expensive:
+        ser = ser.str[:6 * max_words]
+
+    ser = ser.str.lower()
+    ser = ser.map(normalize_fn)
+
+    return ser
+
+
+def min_date_offset(*args: pd.DateOffset | None) -> pd.DateOffset | None:
+    if any(arg is None for arg in args):
+        return None
+
+    anchor = pd.Timestamp('2000-01-01')
+    timestamps = [anchor + arg for arg in args]
+    assert len(timestamps) > 0
+    argmin = min(range(len(timestamps)), key=lambda i: timestamps[i])
+    return args[argmin]
+
+
+def max_date_offset(*args: pd.DateOffset) -> pd.DateOffset:
+    anchor = pd.Timestamp('2000-01-01')
+    timestamps = [anchor + arg for arg in args]
+    assert len(timestamps) > 0
+    argmax = max(range(len(timestamps)), key=lambda i: timestamps[i])
+    return args[argmax]
```
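Two implementation details in `base/sampler.py` deserve a closer look. First, `min_date_offset`/`max_date_offset` need to order `pd.DateOffset` objects, which pandas does not make directly comparable; the helpers add each offset to a fixed anchor timestamp and compare the resulting timestamps instead. A standalone illustration of the same trick (toy offsets invented for the example):

```python
import pandas as pd

anchor = pd.Timestamp('2000-01-01')
offsets = [pd.DateOffset(months=1), pd.DateOffset(days=45)]

# Anchoring turns each (otherwise incomparable) offset into a timestamp:
timestamps = [anchor + offset for offset in offsets]
largest = offsets[max(range(len(offsets)), key=lambda i: timestamps[i])]
print(largest)  # DateOffset(days=45): 45 days reaches past one month here
```

Note that this comparison is anchor-dependent for calendar offsets (one month from February covers fewer days than from January); pinning the anchor to `2000-01-01` makes it deterministic.

Second, `sample_subgraph` converts COO column indices to CSC column pointers whenever `len(col) > num_cols + 1`, i.e. whenever the pointer array is the smaller representation. The `bincount`/`cumsum` conversion assumes `col` is grouped by destination column, which the sampler output provides. A toy sketch with invented arrays:

```python
import numpy as np

col = np.array([0, 0, 1, 1, 1, 2])  # COO: 6 edges into 3 destination rows
num_cols = 3

# Equivalent CSC pointers: edges of column j live in colptr[j]:colptr[j+1].
colptr = np.empty(num_cols + 1, dtype=col.dtype)
colptr[0] = 0
np.cumsum(np.bincount(col, minlength=num_cols), out=colptr[1:])
print(colptr)  # [0 2 5 6]
```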
kumoai/experimental/rfm/base/source.py (new file)

```diff
@@ -0,0 +1,18 @@
+from dataclasses import dataclass
+
+from kumoapi.typing import Dtype
+
+
+@dataclass
+class SourceColumn:
+    name: str
+    dtype: Dtype
+    is_primary_key: bool
+    is_unique_key: bool
+
+
+@dataclass
+class SourceForeignKey:
+    name: str
+    dst_table: str
+    primary_key: str
```