kumoai 2.13.0.dev202511271731__cp312-cp312-win_amd64.whl → 2.14.0.dev202512111731__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +12 -0
- kumoai/_version.py +1 -1
- kumoai/connector/utils.py +23 -2
- kumoai/experimental/rfm/__init__.py +20 -45
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +37 -90
- kumoai/experimental/rfm/backend/local/sampler.py +313 -0
- kumoai/experimental/rfm/backend/local/table.py +109 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +35 -0
- kumoai/experimental/rfm/backend/snow/table.py +117 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +30 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +101 -0
- kumoai/experimental/rfm/base/__init__.py +13 -0
- kumoai/experimental/rfm/base/column.py +66 -0
- kumoai/experimental/rfm/base/sampler.py +763 -0
- kumoai/experimental/rfm/base/source.py +18 -0
- kumoai/experimental/rfm/{local_table.py → base/table.py} +139 -139
- kumoai/experimental/rfm/{local_graph.py → graph.py} +334 -79
- kumoai/experimental/rfm/infer/__init__.py +6 -0
- kumoai/experimental/rfm/infer/dtype.py +79 -0
- kumoai/experimental/rfm/infer/pkey.py +126 -0
- kumoai/experimental/rfm/infer/time_col.py +62 -0
- kumoai/experimental/rfm/pquery/pandas_executor.py +1 -1
- kumoai/experimental/rfm/rfm.py +204 -166
- kumoai/experimental/rfm/sagemaker.py +11 -3
- kumoai/kumolib.cp312-win_amd64.pyd +0 -0
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/testing/decorators.py +1 -1
- {kumoai-2.13.0.dev202511271731.dist-info → kumoai-2.14.0.dev202512111731.dist-info}/METADATA +9 -8
- {kumoai-2.13.0.dev202511271731.dist-info → kumoai-2.14.0.dev202512111731.dist-info}/RECORD +34 -22
- kumoai/experimental/rfm/local_graph_sampler.py +0 -182
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- kumoai/experimental/rfm/utils.py +0 -344
- {kumoai-2.13.0.dev202511271731.dist-info → kumoai-2.14.0.dev202512111731.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202511271731.dist-info → kumoai-2.14.0.dev202512111731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202511271731.dist-info → kumoai-2.14.0.dev202512111731.dist-info}/top_level.txt +0 -0
|
@@ -370,9 +370,11 @@ class PredictiveQuery:
|
|
|
370
370
|
train_table_job_api = global_state.client.generate_train_table_job_api
|
|
371
371
|
job_id: GenerateTrainTableJobID = train_table_job_api.create(
|
|
372
372
|
GenerateTrainTableRequest(
|
|
373
|
-
dict(custom_tags),
|
|
374
|
-
|
|
375
|
-
|
|
373
|
+
dict(custom_tags),
|
|
374
|
+
pq_id,
|
|
375
|
+
plan,
|
|
376
|
+
None,
|
|
377
|
+
))
|
|
376
378
|
|
|
377
379
|
self._train_table = TrainingTableJob(job_id=job_id)
|
|
378
380
|
if non_blocking:
|
|
@@ -451,9 +453,11 @@ class PredictiveQuery:
|
|
|
451
453
|
bp_table_api = global_state.client.generate_prediction_table_job_api
|
|
452
454
|
job_id: GeneratePredictionTableJobID = bp_table_api.create(
|
|
453
455
|
GeneratePredictionTableRequest(
|
|
454
|
-
dict(custom_tags),
|
|
455
|
-
|
|
456
|
-
|
|
456
|
+
dict(custom_tags),
|
|
457
|
+
pq_id,
|
|
458
|
+
plan,
|
|
459
|
+
None,
|
|
460
|
+
))
|
|
457
461
|
|
|
458
462
|
self._prediction_table = PredictionTableJob(job_id=job_id)
|
|
459
463
|
if non_blocking:
|
kumoai/testing/decorators.py
CHANGED
|
@@ -25,7 +25,7 @@ def onlyFullTest(func: Callable) -> Callable:
|
|
|
25
25
|
def has_package(package: str) -> bool:
|
|
26
26
|
r"""Returns ``True`` in case ``package`` is installed."""
|
|
27
27
|
req = Requirement(package)
|
|
28
|
-
if importlib.util.find_spec(req.name) is None:
|
|
28
|
+
if importlib.util.find_spec(req.name) is None: # type: ignore
|
|
29
29
|
return False
|
|
30
30
|
|
|
31
31
|
try:
|
{kumoai-2.13.0.dev202511271731.dist-info → kumoai-2.14.0.dev202512111731.dist-info}/METADATA
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: kumoai
|
|
3
|
-
Version: 2.
|
|
3
|
+
Version: 2.14.0.dev202512111731
|
|
4
4
|
Summary: AI on the Modern Data Stack
|
|
5
5
|
Author-email: "Kumo.AI" <hello@kumo.ai>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -23,13 +23,11 @@ Requires-Dist: requests>=2.28.2
|
|
|
23
23
|
Requires-Dist: urllib3
|
|
24
24
|
Requires-Dist: plotly
|
|
25
25
|
Requires-Dist: typing_extensions>=4.5.0
|
|
26
|
-
Requires-Dist: kumo-api==0.
|
|
26
|
+
Requires-Dist: kumo-api==0.49.0
|
|
27
27
|
Requires-Dist: tqdm>=4.66.0
|
|
28
28
|
Requires-Dist: aiohttp>=3.10.0
|
|
29
29
|
Requires-Dist: pydantic>=1.10.21
|
|
30
30
|
Requires-Dist: rich>=9.0.0
|
|
31
|
-
Requires-Dist: mypy-boto3-sagemaker-runtime
|
|
32
|
-
Requires-Dist: boto3
|
|
33
31
|
Provides-Extra: doc
|
|
34
32
|
Requires-Dist: sphinx; extra == "doc"
|
|
35
33
|
Requires-Dist: sphinx-book-theme; extra == "doc"
|
|
@@ -40,13 +38,16 @@ Provides-Extra: test
|
|
|
40
38
|
Requires-Dist: pytest; extra == "test"
|
|
41
39
|
Requires-Dist: pytest-mock; extra == "test"
|
|
42
40
|
Requires-Dist: requests-mock; extra == "test"
|
|
43
|
-
Provides-Extra:
|
|
44
|
-
Requires-Dist:
|
|
45
|
-
|
|
46
|
-
Requires-Dist:
|
|
41
|
+
Provides-Extra: sqlite
|
|
42
|
+
Requires-Dist: adbc_driver_sqlite; extra == "sqlite"
|
|
43
|
+
Provides-Extra: snowflake
|
|
44
|
+
Requires-Dist: snowflake-connector-python; extra == "snowflake"
|
|
45
|
+
Requires-Dist: pyyaml; extra == "snowflake"
|
|
47
46
|
Provides-Extra: sagemaker
|
|
48
47
|
Requires-Dist: boto3<2.0,>=1.30.0; extra == "sagemaker"
|
|
49
48
|
Requires-Dist: mypy-boto3-sagemaker-runtime<2.0,>=1.34.0; extra == "sagemaker"
|
|
49
|
+
Provides-Extra: test-sagemaker
|
|
50
|
+
Requires-Dist: sagemaker<3.0; extra == "test-sagemaker"
|
|
50
51
|
Dynamic: license-file
|
|
51
52
|
Dynamic: requires-dist
|
|
52
53
|
|
|
@@ -1,13 +1,13 @@
|
|
|
1
|
-
kumoai/__init__.py,sha256=
|
|
1
|
+
kumoai/__init__.py,sha256=aDhb7KGetDnOz54u1Fd45zfM2N8oAha6XT2CvJqOvgc,11146
|
|
2
2
|
kumoai/_logging.py,sha256=qL4JbMQwKXri2f-SEJoFB8TY5ALG12S-nobGTNWxW-A,915
|
|
3
3
|
kumoai/_singleton.py,sha256=i2BHWKpccNh5SJGDyU0IXsnYzJAYr8Xb0wz4c6LRbpo,861
|
|
4
|
-
kumoai/_version.py,sha256=
|
|
4
|
+
kumoai/_version.py,sha256=1P4-kFh0Np-63qTmspyV3KEawGdIIMPl6ZARL87NM3s,39
|
|
5
5
|
kumoai/databricks.py,sha256=ahwJz6DWLXMkndT0XwEDBxF-hoqhidFR8wBUQ4TLZ68,490
|
|
6
6
|
kumoai/exceptions.py,sha256=7TMs0SC8xrU009_Pgd4QXtSF9lxJq8MtRbeX9pcQUy4,859
|
|
7
7
|
kumoai/formatting.py,sha256=o3uCnLwXPhe1KI5WV9sBgRrcU7ed4rgu_pf89GL9Nc0,983
|
|
8
8
|
kumoai/futures.py,sha256=J8rtZMEYFzdn5xF_x-LAiKJz3KGL6PT02f6rq_2bOJk,3836
|
|
9
9
|
kumoai/jobs.py,sha256=dCi7BAdfm2tCnonYlGU4WJokJWbh3RzFfaOX2EYCIHU,2576
|
|
10
|
-
kumoai/kumolib.cp312-win_amd64.pyd,sha256=
|
|
10
|
+
kumoai/kumolib.cp312-win_amd64.pyd,sha256=tgu5bJkhFo3Yc524MI6sVvovE0yUuQ-dEnGPbKaFBIE,198144
|
|
11
11
|
kumoai/mixin.py,sha256=IaiB8SAI0VqOoMVzzIaUlqMt53-QPUK6OB0HikG-V9E,840
|
|
12
12
|
kumoai/spcs.py,sha256=KWfENrwSLruprlD-QPh63uU0N6npiNrwkeKfBk3EUyQ,4260
|
|
13
13
|
kumoai/artifact_export/__init__.py,sha256=UXAQI5q92ChBzWAk8o3J6pElzYHudAzFZssQXd4o7i8,247
|
|
@@ -50,37 +50,49 @@ kumoai/connector/glue_connector.py,sha256=kqT2q53Da7PeeaZrvLVzFXC186E7glh5eGitKL
|
|
|
50
50
|
kumoai/connector/s3_connector.py,sha256=AUzENbQ20bYXh3XOXEOsWRKlaGGkm3YrW9JfBLm-LqY,10433
|
|
51
51
|
kumoai/connector/snowflake_connector.py,sha256=tQzIWxC4oDGqxFt0212w5eoIPT4QBP2nuF9SdKRNwNI,9274
|
|
52
52
|
kumoai/connector/source_table.py,sha256=fnqwIKY6qYo4G0EsRzchb6FgZ-dQyU6aRaD9UAxsml0,18010
|
|
53
|
-
kumoai/connector/utils.py,sha256=
|
|
53
|
+
kumoai/connector/utils.py,sha256=5K9BMdWiIP3hhdkUc6Xt1e0xv5YyziXtZ4PnBqq0Ehw,66490
|
|
54
54
|
kumoai/encoder/__init__.py,sha256=8FeP6mUyCeXxr1b8kUIi5dxe5vEXQRft9tPoaV1CBqg,186
|
|
55
55
|
kumoai/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
56
|
-
kumoai/experimental/rfm/__init__.py,sha256=
|
|
56
|
+
kumoai/experimental/rfm/__init__.py,sha256=EFZz6IvvskmeO85Vig6p1m_6jdimS_BkeREOndHuRsc,6247
|
|
57
57
|
kumoai/experimental/rfm/authenticate.py,sha256=G89_4TMeUpr5fG_0VTzMF5sdNhaciitA1oc2loTlTmo,19321
|
|
58
|
-
kumoai/experimental/rfm/
|
|
59
|
-
kumoai/experimental/rfm/
|
|
60
|
-
kumoai/experimental/rfm/
|
|
61
|
-
kumoai/experimental/rfm/
|
|
62
|
-
kumoai/experimental/rfm/
|
|
63
|
-
kumoai/experimental/rfm/
|
|
64
|
-
kumoai/experimental/rfm/
|
|
65
|
-
kumoai/experimental/rfm/
|
|
66
|
-
kumoai/experimental/rfm/
|
|
58
|
+
kumoai/experimental/rfm/graph.py,sha256=SL3-WinoLnkZC6VVjebYGLuQJJyEVFJdCm6h3FNE0e4,40816
|
|
59
|
+
kumoai/experimental/rfm/rfm.py,sha256=jauaV5MuuSPsWDI2VwzMxxnsXpQ6Jp3nhgPTmxUMeh8,50304
|
|
60
|
+
kumoai/experimental/rfm/sagemaker.py,sha256=sEJSyfEFBA3-7wKinBEzSooKHEn0BgPjrgRnPhYo79g,5120
|
|
61
|
+
kumoai/experimental/rfm/backend/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
62
|
+
kumoai/experimental/rfm/backend/local/__init__.py,sha256=8JbLaai0yhtldFcDkddphIJKMiKc0XnodvYBWkrGPXI,1056
|
|
63
|
+
kumoai/experimental/rfm/backend/local/graph_store.py,sha256=m1eHot4hRSl6OY1trPRDEH0U89hVoRqMfWhA9f-SyQU,12234
|
|
64
|
+
kumoai/experimental/rfm/backend/local/sampler.py,sha256=GRMaBtfrSTJtX_9_uwi4MXV0V8SHYAsJJ68d5qrvJJg,11051
|
|
65
|
+
kumoai/experimental/rfm/backend/local/table.py,sha256=1PqNOROzlnK3SaZHNcU2hyzeifs0N4wssQAS3-Z0Myc,3674
|
|
66
|
+
kumoai/experimental/rfm/backend/snow/__init__.py,sha256=viMeR9VWpB1kjRdSWCTNFMdM7a8Mj_Dtck1twJW8dV8,962
|
|
67
|
+
kumoai/experimental/rfm/backend/snow/table.py,sha256=Rf4hUPOUtsjpaIc9vBKWPZ3yz20OOg6DZqCGeih4KC8,4372
|
|
68
|
+
kumoai/experimental/rfm/backend/sqlite/__init__.py,sha256=xw5NNLrWSvUvRkD49X_9hZYjas5EuP1XDANPy0EEjOg,874
|
|
69
|
+
kumoai/experimental/rfm/backend/sqlite/table.py,sha256=mBiZC21gQwfR4demFrP37GmawMHfIm-G82mLQeBqIZo,3901
|
|
70
|
+
kumoai/experimental/rfm/base/__init__.py,sha256=FX1DonqdQLD-q0SBqhNIelr6HUN2LNiutviCmdikDok,282
|
|
71
|
+
kumoai/experimental/rfm/base/column.py,sha256=OE-PRQ8HO4uTq0e3_3eHJFfhp5nzw79zd-43g3iMh4g,2385
|
|
72
|
+
kumoai/experimental/rfm/base/sampler.py,sha256=277THFQGvQK_OkBBKHQ7166oGpnL_zVPFble7Ehlu3c,31805
|
|
73
|
+
kumoai/experimental/rfm/base/source.py,sha256=H5yN9xAwK3i_69EdqOV_x58muPGKQiI8ev5BhHQDZEo,290
|
|
74
|
+
kumoai/experimental/rfm/base/table.py,sha256=6GlWUz3tAu2g1QqTA5idWGmfo2KEJoNApDFZRn8e0pg,20388
|
|
75
|
+
kumoai/experimental/rfm/infer/__init__.py,sha256=qKg8or-SpgTApD6ePw1PJ4aUZPrOLTHLRCmBIJ92hrk,486
|
|
67
76
|
kumoai/experimental/rfm/infer/categorical.py,sha256=bqmfrE5ZCBTcb35lA4SyAkCu3MgttAn29VBJYMBNhVg,893
|
|
77
|
+
kumoai/experimental/rfm/infer/dtype.py,sha256=Hf_drluYNuN59lTSe-8GuXalg20Pv93kCktB6Hb9f74,2686
|
|
68
78
|
kumoai/experimental/rfm/infer/id.py,sha256=xaJBETLZa8ttzZCsDwFSwfyCi3VYsLc_kDWT_t_6Ih4,954
|
|
69
79
|
kumoai/experimental/rfm/infer/multicategorical.py,sha256=D-1KwYRkOSkBrOJr4Xa3eTCoAF9O9hPGa7Vg67V5_HU,1150
|
|
80
|
+
kumoai/experimental/rfm/infer/pkey.py,sha256=Hvztcircd4iGdsnFU9Xi1kq_A5ONMnkAdnrpQT5svSs,4519
|
|
81
|
+
kumoai/experimental/rfm/infer/time_col.py,sha256=G98Cgz1m9G9VA-ApnCmGYnJxEFwp1jfaPf3nCMOz_N0,1882
|
|
70
82
|
kumoai/experimental/rfm/infer/timestamp.py,sha256=L2VxjtYTSyUBYAo4M-L08xSQlPpqnHMAVF5_vxjh3Y0,1135
|
|
71
83
|
kumoai/experimental/rfm/pquery/__init__.py,sha256=RkTn0I74uXOUuOiBpa6S-_QEYctMutkUnBEfF9ztQzI,159
|
|
72
84
|
kumoai/experimental/rfm/pquery/executor.py,sha256=S8wwXbAkH-YSnmEVYB8d6wyJF4JJ003mH_0zFTvOp_I,2843
|
|
73
|
-
kumoai/experimental/rfm/pquery/pandas_executor.py,sha256=
|
|
85
|
+
kumoai/experimental/rfm/pquery/pandas_executor.py,sha256=h5F2hrnyqAPqaqH4RwTdgkeedfCcDTpph9sF0lwIrx4,19024
|
|
74
86
|
kumoai/graph/__init__.py,sha256=QGk3OMwRzQJSGESdcc7hcQH6UDmNVJYTdqnRren4c7Q,240
|
|
75
87
|
kumoai/graph/column.py,sha256=cQhioibTbIKIBZ-bf8-Bt4F4Iblhidps-CYWrkxRPnE,4295
|
|
76
88
|
kumoai/graph/graph.py,sha256=Pq-dxi4MwoDtrrwm3xeyUB9Hl7ryNfHq4rMHuvyNB3c,39239
|
|
77
89
|
kumoai/graph/table.py,sha256=BB-4ezyd7hrrj6QZwRBa80ySH0trwYb4fmhRn3xoK-k,34726
|
|
78
90
|
kumoai/pquery/__init__.py,sha256=FF6QUTG_xrz2ic1I8NcIa8O993Ae98eZ9gkvQ4rapgo,558
|
|
79
91
|
kumoai/pquery/prediction_table.py,sha256=hWG4L_ze4PLgUoxCXNKk8_nkYxVXELQs8_X8KGOE9yk,11063
|
|
80
|
-
kumoai/pquery/predictive_query.py,sha256=
|
|
92
|
+
kumoai/pquery/predictive_query.py,sha256=I5Ntc7YO1qEGxKrLuhAzZO3SySr8Wnjhde8eDbbB7zk,25542
|
|
81
93
|
kumoai/pquery/training_table.py,sha256=L1QjaVlY4SAPD8OUmTaH6YjZzBbPOnS9mnAT69znWv0,16233
|
|
82
94
|
kumoai/testing/__init__.py,sha256=XBQ_Sa3WnOYlpXZ3gUn8w6nVfZt-nfPhytfIBeiPt4w,178
|
|
83
|
-
kumoai/testing/decorators.py,sha256=
|
|
95
|
+
kumoai/testing/decorators.py,sha256=p79ZCQqPY_MHWy0_l7-xQ6wUIqFTn4AbrGWTHLvpbQY,1664
|
|
84
96
|
kumoai/trainer/__init__.py,sha256=uCFXy9bw_byn_wYd3M-BTZCHTVvv4XXr8qRlh-QOvag,981
|
|
85
97
|
kumoai/trainer/baseline_trainer.py,sha256=oXweh8j1sar6KhQfr3A7gmQxcDq7SG0Bx3jIenbtyC4,4117
|
|
86
98
|
kumoai/trainer/config.py,sha256=7_Jv1w1mqaokCQwQdJkqCSgVpmh8GqE3fL1Ky_vvttI,100
|
|
@@ -92,8 +104,8 @@ kumoai/utils/__init__.py,sha256=wAKgmwtMIGuiauW9D_GGKH95K-24Kgwmld27mm4nsro,278
|
|
|
92
104
|
kumoai/utils/datasets.py,sha256=UyAII-oAn7x3ombuvpbSQ41aVF9SYKBjQthTD-vcT2A,3011
|
|
93
105
|
kumoai/utils/forecasting.py,sha256=ZgKeUCbWLOot0giAkoigwU5du8LkrwAicFOi5hVn6wg,7624
|
|
94
106
|
kumoai/utils/progress_logger.py,sha256=MZsWgHd4UZQKCXiJZgQeW-Emi_BmzlCKPLPXOL_HqBo,5239
|
|
95
|
-
kumoai-2.
|
|
96
|
-
kumoai-2.
|
|
97
|
-
kumoai-2.
|
|
98
|
-
kumoai-2.
|
|
99
|
-
kumoai-2.
|
|
107
|
+
kumoai-2.14.0.dev202512111731.dist-info/licenses/LICENSE,sha256=ZUilBDp--4vbhsEr6f_Upw9rnIx09zQ3K9fXQ0rfd6w,1111
|
|
108
|
+
kumoai-2.14.0.dev202512111731.dist-info/METADATA,sha256=BTYN5IeO5MXS1lF5aA-O_p7x32QpVK0kuMMp6Wy2S-A,2580
|
|
109
|
+
kumoai-2.14.0.dev202512111731.dist-info/WHEEL,sha256=8UP9x9puWI0P1V_d7K2oMTBqfeLNm21CTzZ_Ptr0NXU,101
|
|
110
|
+
kumoai-2.14.0.dev202512111731.dist-info/top_level.txt,sha256=YjU6UcmomoDx30vEXLsOU784ED7VztQOsFApk1SFwvs,7
|
|
111
|
+
kumoai-2.14.0.dev202512111731.dist-info/RECORD,,
|
|
@@ -1,182 +0,0 @@
|
|
|
1
|
-
from typing import Dict, List, Optional, Tuple
|
|
2
|
-
|
|
3
|
-
import numpy as np
|
|
4
|
-
import pandas as pd
|
|
5
|
-
from kumoapi.rfm.context import EdgeLayout, Link, Subgraph, Table
|
|
6
|
-
from kumoapi.typing import Stype
|
|
7
|
-
|
|
8
|
-
import kumoai.kumolib as kumolib
|
|
9
|
-
from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
|
|
10
|
-
from kumoai.experimental.rfm.utils import normalize_text
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class LocalGraphSampler:
|
|
14
|
-
def __init__(self, graph_store: LocalGraphStore) -> None:
|
|
15
|
-
self._graph_store = graph_store
|
|
16
|
-
self._sampler = kumolib.NeighborSampler(
|
|
17
|
-
self._graph_store.node_types,
|
|
18
|
-
self._graph_store.edge_types,
|
|
19
|
-
{
|
|
20
|
-
'__'.join(edge_type): colptr
|
|
21
|
-
for edge_type, colptr in self._graph_store.colptr_dict.items()
|
|
22
|
-
},
|
|
23
|
-
{
|
|
24
|
-
'__'.join(edge_type): row
|
|
25
|
-
for edge_type, row in self._graph_store.row_dict.items()
|
|
26
|
-
},
|
|
27
|
-
self._graph_store.time_dict,
|
|
28
|
-
)
|
|
29
|
-
|
|
30
|
-
def __call__(
|
|
31
|
-
self,
|
|
32
|
-
entity_table_names: Tuple[str, ...],
|
|
33
|
-
node: np.ndarray,
|
|
34
|
-
time: np.ndarray,
|
|
35
|
-
num_neighbors: List[int],
|
|
36
|
-
exclude_cols_dict: Dict[str, List[str]],
|
|
37
|
-
) -> Subgraph:
|
|
38
|
-
|
|
39
|
-
(
|
|
40
|
-
row_dict,
|
|
41
|
-
col_dict,
|
|
42
|
-
node_dict,
|
|
43
|
-
batch_dict,
|
|
44
|
-
num_sampled_nodes_dict,
|
|
45
|
-
num_sampled_edges_dict,
|
|
46
|
-
) = self._sampler.sample(
|
|
47
|
-
{
|
|
48
|
-
'__'.join(edge_type): num_neighbors
|
|
49
|
-
for edge_type in self._graph_store.edge_types
|
|
50
|
-
},
|
|
51
|
-
{}, # time interval based sampling
|
|
52
|
-
entity_table_names[0],
|
|
53
|
-
node,
|
|
54
|
-
time // 1000**3, # nanoseconds to seconds
|
|
55
|
-
)
|
|
56
|
-
|
|
57
|
-
table_dict: Dict[str, Table] = {}
|
|
58
|
-
for table_name, node in node_dict.items():
|
|
59
|
-
batch = batch_dict[table_name]
|
|
60
|
-
|
|
61
|
-
if len(node) == 0:
|
|
62
|
-
continue
|
|
63
|
-
|
|
64
|
-
df = self._graph_store.df_dict[table_name]
|
|
65
|
-
|
|
66
|
-
num_sampled_nodes = num_sampled_nodes_dict[table_name].tolist()
|
|
67
|
-
stype_dict = { # Exclude target columns:
|
|
68
|
-
column_name: stype
|
|
69
|
-
for column_name, stype in
|
|
70
|
-
self._graph_store.stype_dict[table_name].items()
|
|
71
|
-
if column_name not in exclude_cols_dict.get(table_name, [])
|
|
72
|
-
}
|
|
73
|
-
primary_key: Optional[str] = None
|
|
74
|
-
if table_name in entity_table_names:
|
|
75
|
-
primary_key = self._graph_store.pkey_name_dict.get(table_name)
|
|
76
|
-
|
|
77
|
-
columns: List[str] = []
|
|
78
|
-
if table_name in entity_table_names:
|
|
79
|
-
columns += [self._graph_store.pkey_name_dict[table_name]]
|
|
80
|
-
columns += list(stype_dict.keys())
|
|
81
|
-
|
|
82
|
-
if len(columns) == 0:
|
|
83
|
-
table_dict[table_name] = Table(
|
|
84
|
-
df=pd.DataFrame(index=range(len(node))),
|
|
85
|
-
row=None,
|
|
86
|
-
batch=batch,
|
|
87
|
-
num_sampled_nodes=num_sampled_nodes,
|
|
88
|
-
stype_dict=stype_dict,
|
|
89
|
-
primary_key=primary_key,
|
|
90
|
-
)
|
|
91
|
-
continue
|
|
92
|
-
|
|
93
|
-
row: Optional[np.ndarray] = None
|
|
94
|
-
if table_name in self._graph_store.end_time_column_dict:
|
|
95
|
-
# Set end time to NaT for all values greater than anchor time:
|
|
96
|
-
df = df.iloc[node].reset_index(drop=True)
|
|
97
|
-
col_name = self._graph_store.end_time_column_dict[table_name]
|
|
98
|
-
ser = df[col_name]
|
|
99
|
-
value = ser.astype('datetime64[ns]').astype(int).to_numpy()
|
|
100
|
-
mask = value > time[batch]
|
|
101
|
-
df.loc[mask, col_name] = pd.NaT
|
|
102
|
-
else:
|
|
103
|
-
# Only store unique rows in `df` above a certain threshold:
|
|
104
|
-
unique_node, inverse = np.unique(node, return_inverse=True)
|
|
105
|
-
if len(node) > 1.05 * len(unique_node):
|
|
106
|
-
df = df.iloc[unique_node].reset_index(drop=True)
|
|
107
|
-
row = inverse
|
|
108
|
-
else:
|
|
109
|
-
df = df.iloc[node].reset_index(drop=True)
|
|
110
|
-
|
|
111
|
-
# Filter data frame to minimal set of columns:
|
|
112
|
-
df = df[columns]
|
|
113
|
-
|
|
114
|
-
# Normalize text (if not already pre-processed):
|
|
115
|
-
for column_name, stype in stype_dict.items():
|
|
116
|
-
if stype == Stype.text:
|
|
117
|
-
df[column_name] = normalize_text(df[column_name])
|
|
118
|
-
|
|
119
|
-
table_dict[table_name] = Table(
|
|
120
|
-
df=df,
|
|
121
|
-
row=row,
|
|
122
|
-
batch=batch,
|
|
123
|
-
num_sampled_nodes=num_sampled_nodes,
|
|
124
|
-
stype_dict=stype_dict,
|
|
125
|
-
primary_key=primary_key,
|
|
126
|
-
)
|
|
127
|
-
|
|
128
|
-
link_dict: Dict[Tuple[str, str, str], Link] = {}
|
|
129
|
-
for edge_type in self._graph_store.edge_types:
|
|
130
|
-
edge_type_str = '__'.join(edge_type)
|
|
131
|
-
|
|
132
|
-
row = row_dict[edge_type_str]
|
|
133
|
-
col = col_dict[edge_type_str]
|
|
134
|
-
|
|
135
|
-
if len(row) == 0:
|
|
136
|
-
continue
|
|
137
|
-
|
|
138
|
-
# Do not store reverse edge type if it is a replica:
|
|
139
|
-
rev_edge_type = Subgraph.rev_edge_type(edge_type)
|
|
140
|
-
rev_edge_type_str = '__'.join(rev_edge_type)
|
|
141
|
-
if (rev_edge_type in link_dict
|
|
142
|
-
and np.array_equal(row, col_dict[rev_edge_type_str])
|
|
143
|
-
and np.array_equal(col, row_dict[rev_edge_type_str])):
|
|
144
|
-
link = Link(
|
|
145
|
-
layout=EdgeLayout.REV,
|
|
146
|
-
row=None,
|
|
147
|
-
col=None,
|
|
148
|
-
num_sampled_edges=(
|
|
149
|
-
num_sampled_edges_dict[edge_type_str].tolist()),
|
|
150
|
-
)
|
|
151
|
-
link_dict[edge_type] = link
|
|
152
|
-
continue
|
|
153
|
-
|
|
154
|
-
layout = EdgeLayout.COO
|
|
155
|
-
if np.array_equal(row, np.arange(len(row))):
|
|
156
|
-
row = None
|
|
157
|
-
if np.array_equal(col, np.arange(len(col))):
|
|
158
|
-
col = None
|
|
159
|
-
|
|
160
|
-
# Store in compressed representation if more efficient:
|
|
161
|
-
num_cols = table_dict[edge_type[2]].num_rows
|
|
162
|
-
if col is not None and len(col) > num_cols + 1:
|
|
163
|
-
layout = EdgeLayout.CSC
|
|
164
|
-
colcount = np.bincount(col, minlength=num_cols)
|
|
165
|
-
col = np.empty(num_cols + 1, dtype=col.dtype)
|
|
166
|
-
col[0] = 0
|
|
167
|
-
np.cumsum(colcount, out=col[1:])
|
|
168
|
-
|
|
169
|
-
link = Link(
|
|
170
|
-
layout=layout,
|
|
171
|
-
row=row,
|
|
172
|
-
col=col,
|
|
173
|
-
num_sampled_edges=(
|
|
174
|
-
num_sampled_edges_dict[edge_type_str].tolist()),
|
|
175
|
-
)
|
|
176
|
-
link_dict[edge_type] = link
|
|
177
|
-
|
|
178
|
-
return Subgraph(
|
|
179
|
-
anchor_time=time,
|
|
180
|
-
table_dict=table_dict,
|
|
181
|
-
link_dict=link_dict,
|
|
182
|
-
)
|