kumoai 2.13.0.dev202512091732__cp311-cp311-macosx_11_0_arm64.whl → 2.14.0.dev202601051732__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +23 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +24 -0
- kumoai/client/pquery.py +6 -2
- kumoai/connector/utils.py +21 -7
- kumoai/experimental/rfm/__init__.py +51 -24
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/local/graph_store.py +52 -104
- kumoai/experimental/rfm/backend/local/sampler.py +125 -55
- kumoai/experimental/rfm/backend/local/table.py +35 -31
- kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
- kumoai/experimental/rfm/backend/snow/table.py +174 -49
- kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
- kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +131 -48
- kumoai/experimental/rfm/base/__init__.py +21 -5
- kumoai/experimental/rfm/base/column.py +96 -10
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +422 -35
- kumoai/experimental/rfm/base/source.py +2 -1
- kumoai/experimental/rfm/base/sql_sampler.py +144 -0
- kumoai/experimental/rfm/base/table.py +386 -195
- kumoai/experimental/rfm/graph.py +350 -178
- kumoai/experimental/rfm/infer/__init__.py +6 -4
- kumoai/experimental/rfm/infer/dtype.py +7 -4
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/pkey.py +4 -2
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +1 -2
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +29 -31
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +630 -408
- kumoai/experimental/rfm/sagemaker.py +4 -4
- kumoai/experimental/rfm/task_table.py +290 -0
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/display.py +51 -0
- kumoai/utils/progress_logger.py +190 -12
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/METADATA +3 -2
- {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/RECORD +49 -40
- kumoai/experimental/rfm/local_graph_sampler.py +0 -223
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/top_level.txt +0 -0
kumoai/utils/progress_logger.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
|
+
import re
|
|
1
2
|
import sys
|
|
2
3
|
import time
|
|
3
|
-
from typing import Any
|
|
4
|
+
from typing import Any
|
|
4
5
|
|
|
5
6
|
from rich.console import Console, ConsoleOptions, RenderResult
|
|
6
7
|
from rich.live import Live
|
|
@@ -20,12 +21,23 @@ from typing_extensions import Self
|
|
|
20
21
|
|
|
21
22
|
|
|
22
23
|
class ProgressLogger:
|
|
23
|
-
def __init__(self, msg: str) -> None:
|
|
24
|
+
def __init__(self, msg: str, verbose: bool = True) -> None:
|
|
24
25
|
self.msg = msg
|
|
25
|
-
self.
|
|
26
|
+
self.verbose = verbose
|
|
27
|
+
self.depth = 0
|
|
28
|
+
|
|
29
|
+
self.logs: list[str] = []
|
|
30
|
+
|
|
31
|
+
self.start_time: float | None = None
|
|
32
|
+
self.end_time: float | None = None
|
|
33
|
+
|
|
34
|
+
@classmethod
|
|
35
|
+
def default(cls, msg: str, verbose: bool = True) -> 'ProgressLogger':
|
|
36
|
+
from kumoai import in_snowflake_notebook
|
|
26
37
|
|
|
27
|
-
|
|
28
|
-
|
|
38
|
+
if in_snowflake_notebook():
|
|
39
|
+
return StreamlitProgressLogger(msg, verbose)
|
|
40
|
+
return RichProgressLogger(msg, verbose)
|
|
29
41
|
|
|
30
42
|
@property
|
|
31
43
|
def duration(self) -> float:
|
|
@@ -37,11 +49,19 @@ class ProgressLogger:
|
|
|
37
49
|
def log(self, msg: str) -> None:
|
|
38
50
|
self.logs.append(msg)
|
|
39
51
|
|
|
52
|
+
def init_progress(self, total: int, description: str) -> None:
|
|
53
|
+
pass
|
|
54
|
+
|
|
55
|
+
def step(self) -> None:
|
|
56
|
+
pass
|
|
57
|
+
|
|
40
58
|
def __enter__(self) -> Self:
|
|
59
|
+
self.depth += 1
|
|
41
60
|
self.start_time = time.perf_counter()
|
|
42
61
|
return self
|
|
43
62
|
|
|
44
63
|
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
|
64
|
+
self.depth -= 1
|
|
45
65
|
self.end_time = time.perf_counter()
|
|
46
66
|
|
|
47
67
|
def __repr__(self) -> str:
|
|
@@ -66,22 +86,21 @@ class ColoredTimeRemainingColumn(TimeRemainingColumn):
|
|
|
66
86
|
return Text(str(super().render(task)), style=self.style)
|
|
67
87
|
|
|
68
88
|
|
|
69
|
-
class
|
|
89
|
+
class RichProgressLogger(ProgressLogger):
|
|
70
90
|
def __init__(
|
|
71
91
|
self,
|
|
72
92
|
msg: str,
|
|
73
93
|
verbose: bool = True,
|
|
74
94
|
refresh_per_second: int = 10,
|
|
75
95
|
) -> None:
|
|
76
|
-
super().__init__(msg=msg)
|
|
96
|
+
super().__init__(msg=msg, verbose=verbose)
|
|
77
97
|
|
|
78
|
-
self.verbose = verbose
|
|
79
98
|
self.refresh_per_second = refresh_per_second
|
|
80
99
|
|
|
81
|
-
self._progress:
|
|
82
|
-
self._task:
|
|
100
|
+
self._progress: Progress | None = None
|
|
101
|
+
self._task: int | None = None
|
|
83
102
|
|
|
84
|
-
self._live:
|
|
103
|
+
self._live: Live | None = None
|
|
85
104
|
self._exception: bool = False
|
|
86
105
|
|
|
87
106
|
def init_progress(self, total: int, description: str) -> None:
|
|
@@ -107,6 +126,9 @@ class InteractiveProgressLogger(ProgressLogger):
|
|
|
107
126
|
|
|
108
127
|
super().__enter__()
|
|
109
128
|
|
|
129
|
+
if self.depth > 1:
|
|
130
|
+
return self
|
|
131
|
+
|
|
110
132
|
if not in_notebook(): # Render progress bar in TUI.
|
|
111
133
|
sys.stdout.write("\x1b]9;4;3\x07")
|
|
112
134
|
sys.stdout.flush()
|
|
@@ -126,6 +148,9 @@ class InteractiveProgressLogger(ProgressLogger):
|
|
|
126
148
|
|
|
127
149
|
super().__exit__(exc_type, exc_val, exc_tb)
|
|
128
150
|
|
|
151
|
+
if self.depth > 1:
|
|
152
|
+
return
|
|
153
|
+
|
|
129
154
|
if exc_type is not None:
|
|
130
155
|
self._exception = True
|
|
131
156
|
|
|
@@ -151,7 +176,7 @@ class InteractiveProgressLogger(ProgressLogger):
|
|
|
151
176
|
|
|
152
177
|
table = Table.grid(padding=(0, 1))
|
|
153
178
|
|
|
154
|
-
icon:
|
|
179
|
+
icon: Text | Padding
|
|
155
180
|
if self._exception:
|
|
156
181
|
style = 'red'
|
|
157
182
|
icon = Text('❌', style=style)
|
|
@@ -175,3 +200,156 @@ class InteractiveProgressLogger(ProgressLogger):
|
|
|
175
200
|
|
|
176
201
|
if self.verbose and self._progress is not None:
|
|
177
202
|
yield self._progress.get_renderable()
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
class StreamlitProgressLogger(ProgressLogger):
|
|
206
|
+
def __init__(
|
|
207
|
+
self,
|
|
208
|
+
msg: str,
|
|
209
|
+
verbose: bool = True,
|
|
210
|
+
) -> None:
|
|
211
|
+
super().__init__(msg=msg, verbose=verbose)
|
|
212
|
+
|
|
213
|
+
self._status: Any = None
|
|
214
|
+
|
|
215
|
+
self._total = 0
|
|
216
|
+
self._current = 0
|
|
217
|
+
self._description: str = ''
|
|
218
|
+
self._progress: Any = None
|
|
219
|
+
|
|
220
|
+
def __enter__(self) -> Self:
|
|
221
|
+
super().__enter__()
|
|
222
|
+
|
|
223
|
+
import streamlit as st
|
|
224
|
+
|
|
225
|
+
if self.depth > 1:
|
|
226
|
+
return self
|
|
227
|
+
|
|
228
|
+
# Adjust layout for prettier output:
|
|
229
|
+
st.markdown(STREAMLIT_CSS, unsafe_allow_html=True)
|
|
230
|
+
|
|
231
|
+
if self.verbose:
|
|
232
|
+
self._status = st.status(
|
|
233
|
+
f':blue[{self._sanitize_text(self.msg)}]',
|
|
234
|
+
expanded=True,
|
|
235
|
+
)
|
|
236
|
+
|
|
237
|
+
return self
|
|
238
|
+
|
|
239
|
+
def log(self, msg: str) -> None:
|
|
240
|
+
super().log(msg)
|
|
241
|
+
if self.verbose and self._status is not None:
|
|
242
|
+
self._status.write(self._sanitize_text(msg))
|
|
243
|
+
|
|
244
|
+
def init_progress(self, total: int, description: str) -> None:
|
|
245
|
+
if self.verbose and self._status is not None:
|
|
246
|
+
self._total = total
|
|
247
|
+
self._current = 0
|
|
248
|
+
self._description = self._sanitize_text(description)
|
|
249
|
+
percent = min(self._current / self._total, 1.0)
|
|
250
|
+
self._progress = self._status.progress(
|
|
251
|
+
value=percent,
|
|
252
|
+
text=f'{self._description} [{self._current}/{self._total}]',
|
|
253
|
+
)
|
|
254
|
+
|
|
255
|
+
def step(self) -> None:
|
|
256
|
+
self._current += 1
|
|
257
|
+
|
|
258
|
+
if self.verbose and self._progress is not None:
|
|
259
|
+
percent = min(self._current / self._total, 1.0)
|
|
260
|
+
self._progress.progress(
|
|
261
|
+
value=percent,
|
|
262
|
+
text=f'{self._description} [{self._current}/{self._total}]',
|
|
263
|
+
)
|
|
264
|
+
|
|
265
|
+
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
|
266
|
+
super().__exit__(exc_type, exc_val, exc_tb)
|
|
267
|
+
|
|
268
|
+
if not self.verbose or self._status is None or self.depth > 1:
|
|
269
|
+
return
|
|
270
|
+
|
|
271
|
+
label = f'{self._sanitize_text(self.msg)} ({self.duration:.2f}s)'
|
|
272
|
+
|
|
273
|
+
if exc_type is not None:
|
|
274
|
+
self._status.update(
|
|
275
|
+
label=f':red[{label}]',
|
|
276
|
+
state='error',
|
|
277
|
+
expanded=True,
|
|
278
|
+
)
|
|
279
|
+
else:
|
|
280
|
+
self._status.update(
|
|
281
|
+
label=f':green[{label}]',
|
|
282
|
+
state='complete',
|
|
283
|
+
expanded=True,
|
|
284
|
+
)
|
|
285
|
+
|
|
286
|
+
@staticmethod
|
|
287
|
+
def _sanitize_text(msg: str) -> str:
|
|
288
|
+
return re.sub(r'\[/?bold\]', '**', msg)
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
STREAMLIT_CSS = """
|
|
292
|
+
<style>
|
|
293
|
+
/* Fix horizontal scrollbar */
|
|
294
|
+
.stExpander summary {
|
|
295
|
+
width: auto;
|
|
296
|
+
}
|
|
297
|
+
|
|
298
|
+
/* Fix paddings/margins */
|
|
299
|
+
.stExpander summary {
|
|
300
|
+
padding: 0.75rem 1rem 0.5rem;
|
|
301
|
+
}
|
|
302
|
+
.stExpander p {
|
|
303
|
+
margin: 0px 0px 0.2rem;
|
|
304
|
+
}
|
|
305
|
+
.stExpander [data-testid="stExpanderDetails"] {
|
|
306
|
+
padding-bottom: 1.45rem;
|
|
307
|
+
}
|
|
308
|
+
.stExpander .stProgress div:first-child {
|
|
309
|
+
padding-bottom: 4px;
|
|
310
|
+
}
|
|
311
|
+
|
|
312
|
+
/* Fix expand icon position */
|
|
313
|
+
.stExpander summary svg {
|
|
314
|
+
height: 1.5rem;
|
|
315
|
+
}
|
|
316
|
+
|
|
317
|
+
/* Fix summary icons */
|
|
318
|
+
.stExpander summary [data-testid="stExpanderIconCheck"] {
|
|
319
|
+
font-size: 1.8rem;
|
|
320
|
+
margin-top: -3px;
|
|
321
|
+
color: rgb(21, 130, 55);
|
|
322
|
+
}
|
|
323
|
+
.stExpander summary [data-testid="stExpanderIconError"] {
|
|
324
|
+
font-size: 1.8rem;
|
|
325
|
+
margin-top: -3px;
|
|
326
|
+
color: rgb(255, 43, 43);
|
|
327
|
+
}
|
|
328
|
+
.stExpander summary span:first-child span:first-child {
|
|
329
|
+
width: 1.6rem;
|
|
330
|
+
}
|
|
331
|
+
|
|
332
|
+
/* Add border between title and content */
|
|
333
|
+
.stExpander [data-testid="stExpanderDetails"] {
|
|
334
|
+
border-top: 1px solid rgba(30, 37, 47, 0.2);
|
|
335
|
+
padding-top: 0.5rem;
|
|
336
|
+
}
|
|
337
|
+
|
|
338
|
+
/* Fix title font size */
|
|
339
|
+
.stExpander summary p {
|
|
340
|
+
font-size: 1rem;
|
|
341
|
+
}
|
|
342
|
+
|
|
343
|
+
/* Gray out content */
|
|
344
|
+
.stExpander [data-testid="stExpanderDetails"] {
|
|
345
|
+
color: rgba(30, 37, 47, 0.5);
|
|
346
|
+
}
|
|
347
|
+
|
|
348
|
+
/* Fix progress bar font size */
|
|
349
|
+
.stExpander .stProgress p {
|
|
350
|
+
line-height: 1.6;
|
|
351
|
+
font-size: 1rem;
|
|
352
|
+
color: rgba(30, 37, 47, 0.5);
|
|
353
|
+
}
|
|
354
|
+
</style>
|
|
355
|
+
"""
|
kumoai/utils/sql.py
ADDED
{kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/METADATA
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: kumoai
|
|
3
|
-
Version: 2.
|
|
3
|
+
Version: 2.14.0.dev202601051732
|
|
4
4
|
Summary: AI on the Modern Data Stack
|
|
5
5
|
Author-email: "Kumo.AI" <hello@kumo.ai>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -23,7 +23,7 @@ Requires-Dist: requests>=2.28.2
|
|
|
23
23
|
Requires-Dist: urllib3
|
|
24
24
|
Requires-Dist: plotly
|
|
25
25
|
Requires-Dist: typing_extensions>=4.5.0
|
|
26
|
-
Requires-Dist: kumo-api==0.
|
|
26
|
+
Requires-Dist: kumo-api==0.49.0
|
|
27
27
|
Requires-Dist: tqdm>=4.66.0
|
|
28
28
|
Requires-Dist: aiohttp>=3.10.0
|
|
29
29
|
Requires-Dist: pydantic>=1.10.21
|
|
@@ -41,6 +41,7 @@ Requires-Dist: requests-mock; extra == "test"
|
|
|
41
41
|
Provides-Extra: sqlite
|
|
42
42
|
Requires-Dist: adbc_driver_sqlite; extra == "sqlite"
|
|
43
43
|
Provides-Extra: snowflake
|
|
44
|
+
Requires-Dist: numpy<2.0; extra == "snowflake"
|
|
44
45
|
Requires-Dist: snowflake-connector-python; extra == "snowflake"
|
|
45
46
|
Requires-Dist: pyyaml; extra == "snowflake"
|
|
46
47
|
Provides-Extra: sagemaker
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
kumoai/_logging.py,sha256=U2_5ROdyk92P4xO4H2WJV8EC7dr6YxmmnM-b7QX9M7I,886
|
|
2
2
|
kumoai/mixin.py,sha256=MP413xzuCqWhxAPUHmloLA3j4ZyF1tEtfi516b_hOXQ,812
|
|
3
|
-
kumoai/_version.py,sha256=
|
|
4
|
-
kumoai/__init__.py,sha256=
|
|
3
|
+
kumoai/_version.py,sha256=DlzSXtmzrqDNbHnCm1VKEOGMvKKcB2xEkXcSe0tIMyI,39
|
|
4
|
+
kumoai/__init__.py,sha256=x6Emn6VesHQz0wR7ZnbddPRYO9A5-0JTHDkzJ3Ocq6w,10907
|
|
5
5
|
kumoai/formatting.py,sha256=jA_rLDCGKZI8WWCha-vtuLenVKTZvli99Tqpurz1H84,953
|
|
6
6
|
kumoai/futures.py,sha256=oJFIfdCM_3nWIqQteBKYMY4fPhoYlYWE_JA2o6tx-ng,3737
|
|
7
7
|
kumoai/kumolib.cpython-311-darwin.so,sha256=AmB_Fysmud1y7Gm5CuBQ5lWDuSzpxVDV_iTA2cjH1s8,232544
|
|
@@ -11,38 +11,43 @@ kumoai/databricks.py,sha256=e6E4lOFvZHXFwh4CO1kXU1zzDU3AapLQYMxjiHPC-HQ,476
|
|
|
11
11
|
kumoai/spcs.py,sha256=N31d7rLa-bgYh8e2J4YzX1ScxGLqiVXrqJnCl1y4Mts,4139
|
|
12
12
|
kumoai/_singleton.py,sha256=UTwrbDkoZSGB8ZelorvprPDDv9uZkUi1q_SrmsyngpQ,836
|
|
13
13
|
kumoai/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
14
|
-
kumoai/experimental/rfm/
|
|
15
|
-
kumoai/experimental/rfm/
|
|
16
|
-
kumoai/experimental/rfm/
|
|
17
|
-
kumoai/experimental/rfm/
|
|
18
|
-
kumoai/experimental/rfm/
|
|
19
|
-
kumoai/experimental/rfm/
|
|
20
|
-
kumoai/experimental/rfm/
|
|
14
|
+
kumoai/experimental/rfm/relbench.py,sha256=cVsxxV3TIL3PLEoYb-8tAVW3GSef6NQAd3rxdHJL63I,2276
|
|
15
|
+
kumoai/experimental/rfm/graph.py,sha256=H9lIQLDkL5zJMwEHh7PgruvMUxWsjpynXUT7gnmTTUM,46351
|
|
16
|
+
kumoai/experimental/rfm/__init__.py,sha256=bW2XyYtkbdiu_iICYFF2Fu1Fx5fyGbqne6m_6c1P-fY,7016
|
|
17
|
+
kumoai/experimental/rfm/sagemaker.py,sha256=6fyXO1Jd_scq-DH7kcv6JcV8QPyTbh4ceqwQDPADlZ0,4963
|
|
18
|
+
kumoai/experimental/rfm/rfm.py,sha256=6XCx_OeJI0X5LhRKypc1r6dHKieSYFYvo-8OnG3M9UE,57545
|
|
19
|
+
kumoai/experimental/rfm/authenticate.py,sha256=G2RkRWznMVQUzvhvbKhn0bMCY7VmoNYxluz3THRqSdE,18851
|
|
20
|
+
kumoai/experimental/rfm/task_table.py,sha256=rzea9WTVx4zs6Y2QZdWG15C5GG9T2IQsxYPlsR1UFSs,9771
|
|
21
21
|
kumoai/experimental/rfm/backend/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
22
|
-
kumoai/experimental/rfm/backend/sqlite/__init__.py,sha256=
|
|
23
|
-
kumoai/experimental/rfm/backend/sqlite/table.py,sha256=
|
|
22
|
+
kumoai/experimental/rfm/backend/sqlite/__init__.py,sha256=jl-DBbhsqQ-dUXyWhyQTM1AU2qNAtXCmi1mokdhtBTg,902
|
|
23
|
+
kumoai/experimental/rfm/backend/sqlite/table.py,sha256=WqYtd_rwlawItRMXZUfv14qdyU6huQmODuFjDo483dI,6683
|
|
24
|
+
kumoai/experimental/rfm/backend/sqlite/sampler.py,sha256=_D9C5mj3oL4J2qZFap3emvTy2jxzth3dEWZPfr4dmEE,16201
|
|
24
25
|
kumoai/experimental/rfm/backend/local/__init__.py,sha256=2s9sSA-E-8pfkkzCH4XPuaSxSznEURMfMgwEIfYYPsg,1014
|
|
25
|
-
kumoai/experimental/rfm/backend/local/table.py,sha256=
|
|
26
|
-
kumoai/experimental/rfm/backend/local/graph_store.py,sha256=
|
|
27
|
-
kumoai/experimental/rfm/backend/local/sampler.py,sha256=
|
|
28
|
-
kumoai/experimental/rfm/backend/snow/__init__.py,sha256=
|
|
29
|
-
kumoai/experimental/rfm/backend/snow/table.py,sha256=
|
|
26
|
+
kumoai/experimental/rfm/backend/local/table.py,sha256=GKeYGcu52ztCU8EBMqp5UVj85E145Ug41xiCPiTCXq4,3489
|
|
27
|
+
kumoai/experimental/rfm/backend/local/graph_store.py,sha256=RHhkI13KpdPxqb4vXkwEwuFiX5DkrEsfZsOLywNnrvU,11294
|
|
28
|
+
kumoai/experimental/rfm/backend/local/sampler.py,sha256=UKxTjsYs00sYuV_LAlDuZOvQq0BZzPCzZK1Fki2Fd70,10726
|
|
29
|
+
kumoai/experimental/rfm/backend/snow/__init__.py,sha256=BYfsiuJ4Ee30GjG9EuUtitMHXnRfvVKi85zNlIwldV4,993
|
|
30
|
+
kumoai/experimental/rfm/backend/snow/table.py,sha256=9N7TOcXX8hhAjCawnhuvQCArBFTCdng3gBakunUxg90,8892
|
|
31
|
+
kumoai/experimental/rfm/backend/snow/sampler.py,sha256=zvPsgVnDfvskcnPWsIcqxw-Fn9DsCLfdoLE-m3bjeww,11483
|
|
30
32
|
kumoai/experimental/rfm/pquery/__init__.py,sha256=X0O3EIq5SMfBEE-ii5Cq6iDhR3s3XMXB52Cx5htoePw,152
|
|
31
|
-
kumoai/experimental/rfm/pquery/pandas_executor.py,sha256=
|
|
32
|
-
kumoai/experimental/rfm/pquery/executor.py,sha256=
|
|
33
|
-
kumoai/experimental/rfm/infer/multicategorical.py,sha256=
|
|
33
|
+
kumoai/experimental/rfm/pquery/pandas_executor.py,sha256=MwSvFRwLq-z19LEdF0G0AT7Gj9tCqu-XLEA7mNbqXwc,18454
|
|
34
|
+
kumoai/experimental/rfm/pquery/executor.py,sha256=gs5AVNaA50ci8zXOBD3qt5szdTReSwTs4BGuEyx4BEE,2728
|
|
35
|
+
kumoai/experimental/rfm/infer/multicategorical.py,sha256=lNO_8aJw1whO6QVEMB3PRWMNlEEiX44g3v4tP88TSQY,1119
|
|
34
36
|
kumoai/experimental/rfm/infer/categorical.py,sha256=VwNaKwKbRYkTxEJ1R6gziffC8dGsEThcDEfbi-KqW5c,853
|
|
35
|
-
kumoai/experimental/rfm/infer/time_col.py,sha256=
|
|
36
|
-
kumoai/experimental/rfm/infer/pkey.py,sha256=
|
|
37
|
+
kumoai/experimental/rfm/infer/time_col.py,sha256=oNenUK6P7ql8uwShodtQ73uG1x3fbFWT78jRcF9DLTI,1789
|
|
38
|
+
kumoai/experimental/rfm/infer/pkey.py,sha256=IaJI5GHK8ds_a3AOr3YYVgUlSmYYEgr4Nu92s2RyBV4,4412
|
|
37
39
|
kumoai/experimental/rfm/infer/id.py,sha256=ZIO0DWIoiEoS_8MVc5lkqBfkTWWQ0yGCgjkwLdaYa_Q,908
|
|
38
|
-
kumoai/experimental/rfm/infer/dtype.py,sha256=
|
|
39
|
-
kumoai/experimental/rfm/infer/__init__.py,sha256=
|
|
40
|
+
kumoai/experimental/rfm/infer/dtype.py,sha256=FyAqvtrOWQC9hGrhQ7sC4BAI6c9k6ew-fo8ClS1sewM,2782
|
|
41
|
+
kumoai/experimental/rfm/infer/__init__.py,sha256=8GDxQKd0pxZULdk7mpwl3CsOpL4v2HPuPEsbi2t_vzc,519
|
|
40
42
|
kumoai/experimental/rfm/infer/timestamp.py,sha256=vM9--7eStzaGG13Y-oLYlpNJyhL6f9dp17HDXwtl_DM,1094
|
|
41
|
-
kumoai/experimental/rfm/
|
|
42
|
-
kumoai/experimental/rfm/base/
|
|
43
|
-
kumoai/experimental/rfm/base/
|
|
44
|
-
kumoai/experimental/rfm/base/
|
|
45
|
-
kumoai/experimental/rfm/base/
|
|
43
|
+
kumoai/experimental/rfm/infer/stype.py,sha256=fu4zsOB-C7jNeMnq6dsK4bOZSewe7PtZe_AkohSRLoM,894
|
|
44
|
+
kumoai/experimental/rfm/base/sql_sampler.py,sha256=1M0B2qSUT2JmiR87xdivrLXk75jn9sy_Y3DUYqsjeK4,5151
|
|
45
|
+
kumoai/experimental/rfm/base/__init__.py,sha256=rjmMux5lG8srw1bjQGcFQFv6zET9e5riP81nPkw28Jg,724
|
|
46
|
+
kumoai/experimental/rfm/base/table.py,sha256=6qZeTMfnQejrn6TwqQeJGzJG7C0dSjJ7-NMLX38dvns,26563
|
|
47
|
+
kumoai/experimental/rfm/base/sampler.py,sha256=tXYnVEyKC5NjSIpe8pNYp0V3Qbg-KbUE_QB0Emy2YiQ,30882
|
|
48
|
+
kumoai/experimental/rfm/base/expression.py,sha256=Y7NtLTnKlx6euG_N3fLTcrFKheB6P5KS_jhCfoXV9DE,1252
|
|
49
|
+
kumoai/experimental/rfm/base/source.py,sha256=bwu3GU2TvIXR2fwKAmJ1-5BDoNXMnI1SU3Fgdk8lWnc,301
|
|
50
|
+
kumoai/experimental/rfm/base/column.py,sha256=GXzLC-VpShr6PecUzaj1MJKc_PHzfW5Jn9bOYPA8fFA,4965
|
|
46
51
|
kumoai/encoder/__init__.py,sha256=VPGs4miBC_WfwWeOXeHhFomOUocERFavhKf5fqITcds,182
|
|
47
52
|
kumoai/graph/graph.py,sha256=iyp4klPIMn2ttuEqMJvsrxKb_tmz_DTnvziIhCegduM,38291
|
|
48
53
|
kumoai/graph/__init__.py,sha256=n8X4X8luox4hPBHTRC9R-3JzvYYMoR8n7lF1H4w4Hzc,228
|
|
@@ -52,8 +57,10 @@ kumoai/artifact_export/config.py,sha256=jOPDduduxv0uuB-7xVlDiZglfpmFF5lzQhhH1SMk
|
|
|
52
57
|
kumoai/artifact_export/job.py,sha256=GEisSwvcjK_35RgOfsLXGgxMTXIWm765B_BW_Kgs-V0,3275
|
|
53
58
|
kumoai/artifact_export/__init__.py,sha256=BsfDrc3mCHpO9-BqvqKm8qrXDIwfdaoH5UIoG4eQkc4,238
|
|
54
59
|
kumoai/utils/datasets.py,sha256=ptKIUoBONVD55pTVNdRCkQT3NWdN_r9UAUu4xewPa3U,2928
|
|
55
|
-
kumoai/utils/__init__.py,sha256=
|
|
56
|
-
kumoai/utils/
|
|
60
|
+
kumoai/utils/__init__.py,sha256=6S-UtwjeLpnCYRCCIEWhkitPYGaqOGXC1ChE13DzXiU,256
|
|
61
|
+
kumoai/utils/display.py,sha256=eXlw4B72y6zEruWYOfwvfqxfMBTL9AsPtWfw3BjaWqQ,1397
|
|
62
|
+
kumoai/utils/progress_logger.py,sha256=rRcfWnfV6uHuvb7cD0mIIfUz3JvnSae0U4SesncODU8,9505
|
|
63
|
+
kumoai/utils/sql.py,sha256=f6lR6rBEW7Dtk0NdM26dOZXUHDizEHb1WPlBCJrwoq0,118
|
|
57
64
|
kumoai/utils/forecasting.py,sha256=-nDS6ucKNfQhTQOfebjefj0wwWH3-KYNslIomxwwMBM,7415
|
|
58
65
|
kumoai/codegen/generate.py,sha256=SvfWWa71xSAOjH9645yQvgoEM-o4BYjupM_EpUxqB_E,7331
|
|
59
66
|
kumoai/codegen/naming.py,sha256=_XVQGxHfuub4bhvyuBKjltD5Lm_oPpibvP_LZteCGk0,3021
|
|
@@ -71,6 +78,7 @@ kumoai/codegen/handlers/__init__.py,sha256=k8TB_Kn-1BycBBi51kqFS2fZHCpCPgR9-3J9g
|
|
|
71
78
|
kumoai/codegen/handlers/utils.py,sha256=58b2GCgaTBUp2aId7BLMXMV0ENrusbNbfw7mlyXAXPE,1447
|
|
72
79
|
kumoai/codegen/handlers/connector.py,sha256=afGf_GreyQ9y6qF3QTgSiM416qtUcP298SatNqUFhvQ,3828
|
|
73
80
|
kumoai/codegen/handlers/table.py,sha256=POHpA-GFYFGTSuerGmtigYablk-Wq1L3EBvsOI-iFMQ,3956
|
|
81
|
+
kumoai/testing/snow.py,sha256=ubx3yJP0UHxsNiar1-jNdv8ZfszKc8Js3_Gg70uf008,1487
|
|
74
82
|
kumoai/testing/__init__.py,sha256=goHIIo3JE7uHV7njo4_aTd89mVVR74BEAZ2uyBaOR0w,170
|
|
75
83
|
kumoai/testing/decorators.py,sha256=83tMifuPTpUqX7zHxMttkj1TDdB62EBtAP-Fjj72Zdo,1607
|
|
76
84
|
kumoai/connector/glue_connector.py,sha256=HivT0QYQ8-XeB4QLgWvghiqXuq7jyBK9G2R1py_NnE4,4697
|
|
@@ -80,20 +88,20 @@ kumoai/connector/bigquery_connector.py,sha256=IkyRqvF8Cg96kApUuuz86eYnl-BqBmDX1f
|
|
|
80
88
|
kumoai/connector/source_table.py,sha256=QLT8bEYaxeMwy-b168url0VfnkTrs5K6VKLbxTI4hEY,17539
|
|
81
89
|
kumoai/connector/__init__.py,sha256=9g6oNJ0qHWFlL5enTSoK4_SSH_5hP74xUDZx-9SggC4,842
|
|
82
90
|
kumoai/connector/file_upload_connector.py,sha256=swp03HgChOvmNPJetuujBSAqADe7NRmS_T0F3o9it4w,7008
|
|
83
|
-
kumoai/connector/utils.py,sha256=
|
|
91
|
+
kumoai/connector/utils.py,sha256=sD3_Dmf42FobMfVayzMVkDHIfXzPN-htD3RHd6Kw8hQ,65055
|
|
84
92
|
kumoai/connector/s3_connector.py,sha256=3kbv-h7DwD8O260Q0h1GPm5wwQpLt-Tb3d_CBSaie44,10155
|
|
85
93
|
kumoai/connector/base.py,sha256=cujXSZF3zAfuxNuEw54DSL1T7XCuR4t0shSMDuPUagQ,5291
|
|
86
94
|
kumoai/pquery/__init__.py,sha256=uTXr7t1eXcVfM-ETaM_1ImfEqhrmaj8BjiIvy1YZTL8,533
|
|
87
|
-
kumoai/pquery/predictive_query.py,sha256=
|
|
95
|
+
kumoai/pquery/predictive_query.py,sha256=UXn1s8ztubYZMNGl4ijaeidMiGlFveb1TGw9qI5-TAo,24901
|
|
88
96
|
kumoai/pquery/prediction_table.py,sha256=QPDH22X1UB0NIufY7qGuV2XW7brG3Pv--FbjNezzM2g,10776
|
|
89
97
|
kumoai/pquery/training_table.py,sha256=elmPDZx11kPiC_dkOhJcBUGtHKgL32GCBvZ9k6U0pMg,15809
|
|
90
|
-
kumoai/client/pquery.py,sha256=
|
|
91
|
-
kumoai/client/client.py,sha256=
|
|
98
|
+
kumoai/client/pquery.py,sha256=IQ8As-OOJOkuMoMosphOsA5hxQYLCbzOQJO7RezK8uY,7091
|
|
99
|
+
kumoai/client/client.py,sha256=npTLooBtmZ9xOo7AbEiYQTh9wFktsGSEpSEfdB7vdB4,8715
|
|
92
100
|
kumoai/client/graph.py,sha256=zvLEDExLT_RVbUMHqVl0m6tO6s2gXmYSoWmPF6YMlnA,3831
|
|
93
101
|
kumoai/client/online.py,sha256=pkBBh_DEC3GAnPcNw6bopNRlGe7EUbIFe7_seQqZRaw,2720
|
|
94
102
|
kumoai/client/source_table.py,sha256=VCsCcM7KYcnjGP7HLTb-AOSEGEVsJTWjk8bMg1JdgPU,2101
|
|
95
103
|
kumoai/client/__init__.py,sha256=MkyOuMaHQ2c8GPxjBDQSVFhfRE2d2_6CXQ6rxj4ps4w,64
|
|
96
|
-
kumoai/client/jobs.py,sha256=
|
|
104
|
+
kumoai/client/jobs.py,sha256=z3By5MWvWdJ_wYFyJA34pD4NueOXvXEqrAANWEpp4Pk,18066
|
|
97
105
|
kumoai/client/utils.py,sha256=lz1NubwMDHCwzQRowRXm7mjAoYRd5UjRQIwXdtWAl90,3849
|
|
98
106
|
kumoai/client/connector.py,sha256=x3i2aBTJTEMZvYRcWkY-UfWVOANZjqAso4GBbcshFjw,3920
|
|
99
107
|
kumoai/client/table.py,sha256=cQG-RPm-e91idEgse1IPJDvBmzddIDGDkuyrR1rq4wU,3235
|
|
@@ -105,9 +113,10 @@ kumoai/trainer/job.py,sha256=Wk69nzFhbvuA3nEvtCstI04z5CxkgvQ6tHnGchE0Lkg,44938
|
|
|
105
113
|
kumoai/trainer/baseline_trainer.py,sha256=LlfViNOmswNv4c6zJJLsyv0pC2mM2WKMGYx06ogtEVc,4024
|
|
106
114
|
kumoai/trainer/__init__.py,sha256=zUdFl-f-sBWmm2x8R-rdVzPBeU2FaMzUY5mkcgoTa1k,939
|
|
107
115
|
kumoai/trainer/online_serving.py,sha256=9cddb5paeZaCgbUeceQdAOxysCtV5XP-KcsgFz_XR5w,9566
|
|
116
|
+
kumoai/trainer/distilled_trainer.py,sha256=2pPs5clakNxkLfaak7uqPJOrpTWe1RVVM7ztDSqQZvU,6484
|
|
108
117
|
kumoai/trainer/trainer.py,sha256=hBXO7gwpo3t59zKFTeIkK65B8QRmWCwO33sbDuEAPlY,20133
|
|
109
|
-
kumoai-2.
|
|
110
|
-
kumoai-2.
|
|
111
|
-
kumoai-2.
|
|
112
|
-
kumoai-2.
|
|
113
|
-
kumoai-2.
|
|
118
|
+
kumoai-2.14.0.dev202601051732.dist-info/RECORD,,
|
|
119
|
+
kumoai-2.14.0.dev202601051732.dist-info/WHEEL,sha256=sunMa2yiYbrNLGeMVDqEA0ayyJbHlex7SCn1TZrEq60,136
|
|
120
|
+
kumoai-2.14.0.dev202601051732.dist-info/top_level.txt,sha256=YjU6UcmomoDx30vEXLsOU784ED7VztQOsFApk1SFwvs,7
|
|
121
|
+
kumoai-2.14.0.dev202601051732.dist-info/METADATA,sha256=JPohnaTwjtH8K7Bx7Rl14fcTQc1JF9fB2sWmmhJZgQw,2557
|
|
122
|
+
kumoai-2.14.0.dev202601051732.dist-info/licenses/LICENSE,sha256=TbWlyqRmhq9PEzCaTI0H0nWLQCCOywQM8wYH8MbjfLo,1102
|
|
@@ -1,223 +0,0 @@
|
|
|
1
|
-
import re
|
|
2
|
-
from typing import Dict, List, Optional, Tuple
|
|
3
|
-
|
|
4
|
-
import numpy as np
|
|
5
|
-
import pandas as pd
|
|
6
|
-
from kumoapi.rfm.context import EdgeLayout, Link, Subgraph, Table
|
|
7
|
-
from kumoapi.typing import Stype
|
|
8
|
-
|
|
9
|
-
import kumoai.kumolib as kumolib
|
|
10
|
-
from kumoai.experimental.rfm.backend.local import LocalGraphStore
|
|
11
|
-
|
|
12
|
-
PUNCTUATION = re.compile(r"[\'\"\.,\(\)\!\?\;\:]")
|
|
13
|
-
MULTISPACE = re.compile(r"\s+")
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
def normalize_text(
|
|
17
|
-
ser: pd.Series,
|
|
18
|
-
max_words: Optional[int] = 50,
|
|
19
|
-
) -> pd.Series:
|
|
20
|
-
r"""Normalizes text into a list of lower-case words.
|
|
21
|
-
|
|
22
|
-
Args:
|
|
23
|
-
ser: The :class:`pandas.Series` to normalize.
|
|
24
|
-
max_words: The maximum number of words to return.
|
|
25
|
-
This will auto-shrink any large text column to avoid blowing up
|
|
26
|
-
context size.
|
|
27
|
-
"""
|
|
28
|
-
if len(ser) == 0 or pd.api.types.is_list_like(ser.iloc[0]):
|
|
29
|
-
return ser
|
|
30
|
-
|
|
31
|
-
def normalize_fn(line: str) -> list[str]:
|
|
32
|
-
line = PUNCTUATION.sub(" ", line)
|
|
33
|
-
line = re.sub(r"<br\s*/?>", " ", line) # Handle <br /> or <br>
|
|
34
|
-
line = MULTISPACE.sub(" ", line)
|
|
35
|
-
words = line.split()
|
|
36
|
-
if max_words is not None:
|
|
37
|
-
words = words[:max_words]
|
|
38
|
-
return words
|
|
39
|
-
|
|
40
|
-
ser = ser.fillna('').astype(str)
|
|
41
|
-
|
|
42
|
-
if max_words is not None:
|
|
43
|
-
# We estimate the number of words as 5 characters + 1 space in an
|
|
44
|
-
# English text on average. We need this pre-filter here, as word
|
|
45
|
-
# splitting on a giant text can be very expensive:
|
|
46
|
-
ser = ser.str[:6 * max_words]
|
|
47
|
-
|
|
48
|
-
ser = ser.str.lower()
|
|
49
|
-
ser = ser.map(normalize_fn)
|
|
50
|
-
|
|
51
|
-
return ser
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
class LocalGraphSampler:
|
|
55
|
-
def __init__(self, graph_store: LocalGraphStore) -> None:
|
|
56
|
-
self._graph_store = graph_store
|
|
57
|
-
self._sampler = kumolib.NeighborSampler(
|
|
58
|
-
self._graph_store.node_types,
|
|
59
|
-
self._graph_store.edge_types,
|
|
60
|
-
{
|
|
61
|
-
'__'.join(edge_type): colptr
|
|
62
|
-
for edge_type, colptr in self._graph_store.colptr_dict.items()
|
|
63
|
-
},
|
|
64
|
-
{
|
|
65
|
-
'__'.join(edge_type): row
|
|
66
|
-
for edge_type, row in self._graph_store.row_dict.items()
|
|
67
|
-
},
|
|
68
|
-
self._graph_store.time_dict,
|
|
69
|
-
)
|
|
70
|
-
|
|
71
|
-
def __call__(
|
|
72
|
-
self,
|
|
73
|
-
entity_table_names: Tuple[str, ...],
|
|
74
|
-
node: np.ndarray,
|
|
75
|
-
time: np.ndarray,
|
|
76
|
-
num_neighbors: List[int],
|
|
77
|
-
exclude_cols_dict: Dict[str, List[str]],
|
|
78
|
-
) -> Subgraph:
|
|
79
|
-
|
|
80
|
-
(
|
|
81
|
-
row_dict,
|
|
82
|
-
col_dict,
|
|
83
|
-
node_dict,
|
|
84
|
-
batch_dict,
|
|
85
|
-
num_sampled_nodes_dict,
|
|
86
|
-
num_sampled_edges_dict,
|
|
87
|
-
) = self._sampler.sample(
|
|
88
|
-
{
|
|
89
|
-
'__'.join(edge_type): num_neighbors
|
|
90
|
-
for edge_type in self._graph_store.edge_types
|
|
91
|
-
},
|
|
92
|
-
{}, # time interval based sampling
|
|
93
|
-
entity_table_names[0],
|
|
94
|
-
node,
|
|
95
|
-
time // 1000**3, # nanoseconds to seconds
|
|
96
|
-
)
|
|
97
|
-
|
|
98
|
-
table_dict: Dict[str, Table] = {}
|
|
99
|
-
for table_name, node in node_dict.items():
|
|
100
|
-
batch = batch_dict[table_name]
|
|
101
|
-
|
|
102
|
-
if len(node) == 0:
|
|
103
|
-
continue
|
|
104
|
-
|
|
105
|
-
df = self._graph_store.df_dict[table_name]
|
|
106
|
-
|
|
107
|
-
num_sampled_nodes = num_sampled_nodes_dict[table_name].tolist()
|
|
108
|
-
stype_dict = { # Exclude target columns:
|
|
109
|
-
column_name: stype
|
|
110
|
-
for column_name, stype in
|
|
111
|
-
self._graph_store.stype_dict[table_name].items()
|
|
112
|
-
if column_name not in exclude_cols_dict.get(table_name, [])
|
|
113
|
-
}
|
|
114
|
-
primary_key: Optional[str] = None
|
|
115
|
-
if table_name in entity_table_names:
|
|
116
|
-
primary_key = self._graph_store.pkey_name_dict.get(table_name)
|
|
117
|
-
|
|
118
|
-
columns: List[str] = []
|
|
119
|
-
if table_name in entity_table_names:
|
|
120
|
-
columns += [self._graph_store.pkey_name_dict[table_name]]
|
|
121
|
-
columns += list(stype_dict.keys())
|
|
122
|
-
|
|
123
|
-
if len(columns) == 0:
|
|
124
|
-
table_dict[table_name] = Table(
|
|
125
|
-
df=pd.DataFrame(index=range(len(node))),
|
|
126
|
-
row=None,
|
|
127
|
-
batch=batch,
|
|
128
|
-
num_sampled_nodes=num_sampled_nodes,
|
|
129
|
-
stype_dict=stype_dict,
|
|
130
|
-
primary_key=primary_key,
|
|
131
|
-
)
|
|
132
|
-
continue
|
|
133
|
-
|
|
134
|
-
row: Optional[np.ndarray] = None
|
|
135
|
-
if table_name in self._graph_store.end_time_column_dict:
|
|
136
|
-
# Set end time to NaT for all values greater than anchor time:
|
|
137
|
-
df = df.iloc[node].reset_index(drop=True)
|
|
138
|
-
col_name = self._graph_store.end_time_column_dict[table_name]
|
|
139
|
-
ser = df[col_name]
|
|
140
|
-
value = ser.astype('datetime64[ns]').astype(int).to_numpy()
|
|
141
|
-
mask = value > time[batch]
|
|
142
|
-
df.loc[mask, col_name] = pd.NaT
|
|
143
|
-
else:
|
|
144
|
-
# Only store unique rows in `df` above a certain threshold:
|
|
145
|
-
unique_node, inverse = np.unique(node, return_inverse=True)
|
|
146
|
-
if len(node) > 1.05 * len(unique_node):
|
|
147
|
-
df = df.iloc[unique_node].reset_index(drop=True)
|
|
148
|
-
row = inverse
|
|
149
|
-
else:
|
|
150
|
-
df = df.iloc[node].reset_index(drop=True)
|
|
151
|
-
|
|
152
|
-
# Filter data frame to minimal set of columns:
|
|
153
|
-
df = df[columns]
|
|
154
|
-
|
|
155
|
-
# Normalize text (if not already pre-processed):
|
|
156
|
-
for column_name, stype in stype_dict.items():
|
|
157
|
-
if stype == Stype.text:
|
|
158
|
-
df[column_name] = normalize_text(df[column_name])
|
|
159
|
-
|
|
160
|
-
table_dict[table_name] = Table(
|
|
161
|
-
df=df,
|
|
162
|
-
row=row,
|
|
163
|
-
batch=batch,
|
|
164
|
-
num_sampled_nodes=num_sampled_nodes,
|
|
165
|
-
stype_dict=stype_dict,
|
|
166
|
-
primary_key=primary_key,
|
|
167
|
-
)
|
|
168
|
-
|
|
169
|
-
link_dict: Dict[Tuple[str, str, str], Link] = {}
|
|
170
|
-
for edge_type in self._graph_store.edge_types:
|
|
171
|
-
edge_type_str = '__'.join(edge_type)
|
|
172
|
-
|
|
173
|
-
row = row_dict[edge_type_str]
|
|
174
|
-
col = col_dict[edge_type_str]
|
|
175
|
-
|
|
176
|
-
if len(row) == 0:
|
|
177
|
-
continue
|
|
178
|
-
|
|
179
|
-
# Do not store reverse edge type if it is a replica:
|
|
180
|
-
rev_edge_type = Subgraph.rev_edge_type(edge_type)
|
|
181
|
-
rev_edge_type_str = '__'.join(rev_edge_type)
|
|
182
|
-
if (rev_edge_type in link_dict
|
|
183
|
-
and np.array_equal(row, col_dict[rev_edge_type_str])
|
|
184
|
-
and np.array_equal(col, row_dict[rev_edge_type_str])):
|
|
185
|
-
link = Link(
|
|
186
|
-
layout=EdgeLayout.REV,
|
|
187
|
-
row=None,
|
|
188
|
-
col=None,
|
|
189
|
-
num_sampled_edges=(
|
|
190
|
-
num_sampled_edges_dict[edge_type_str].tolist()),
|
|
191
|
-
)
|
|
192
|
-
link_dict[edge_type] = link
|
|
193
|
-
continue
|
|
194
|
-
|
|
195
|
-
layout = EdgeLayout.COO
|
|
196
|
-
if np.array_equal(row, np.arange(len(row))):
|
|
197
|
-
row = None
|
|
198
|
-
if np.array_equal(col, np.arange(len(col))):
|
|
199
|
-
col = None
|
|
200
|
-
|
|
201
|
-
# Store in compressed representation if more efficient:
|
|
202
|
-
num_cols = table_dict[edge_type[2]].num_rows
|
|
203
|
-
if col is not None and len(col) > num_cols + 1:
|
|
204
|
-
layout = EdgeLayout.CSC
|
|
205
|
-
colcount = np.bincount(col, minlength=num_cols)
|
|
206
|
-
col = np.empty(num_cols + 1, dtype=col.dtype)
|
|
207
|
-
col[0] = 0
|
|
208
|
-
np.cumsum(colcount, out=col[1:])
|
|
209
|
-
|
|
210
|
-
link = Link(
|
|
211
|
-
layout=layout,
|
|
212
|
-
row=row,
|
|
213
|
-
col=col,
|
|
214
|
-
num_sampled_edges=(
|
|
215
|
-
num_sampled_edges_dict[edge_type_str].tolist()),
|
|
216
|
-
)
|
|
217
|
-
link_dict[edge_type] = link
|
|
218
|
-
|
|
219
|
-
return Subgraph(
|
|
220
|
-
anchor_time=time,
|
|
221
|
-
table_dict=table_dict,
|
|
222
|
-
link_dict=link_dict,
|
|
223
|
-
)
|