kumoai 2.13.0.dev202511271731__cp312-cp312-win_amd64.whl → 2.14.0.dev202512111731__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. kumoai/__init__.py +12 -0
  2. kumoai/_version.py +1 -1
  3. kumoai/connector/utils.py +23 -2
  4. kumoai/experimental/rfm/__init__.py +20 -45
  5. kumoai/experimental/rfm/backend/__init__.py +0 -0
  6. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  7. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +37 -90
  8. kumoai/experimental/rfm/backend/local/sampler.py +313 -0
  9. kumoai/experimental/rfm/backend/local/table.py +109 -0
  10. kumoai/experimental/rfm/backend/snow/__init__.py +35 -0
  11. kumoai/experimental/rfm/backend/snow/table.py +117 -0
  12. kumoai/experimental/rfm/backend/sqlite/__init__.py +30 -0
  13. kumoai/experimental/rfm/backend/sqlite/table.py +101 -0
  14. kumoai/experimental/rfm/base/__init__.py +13 -0
  15. kumoai/experimental/rfm/base/column.py +66 -0
  16. kumoai/experimental/rfm/base/sampler.py +763 -0
  17. kumoai/experimental/rfm/base/source.py +18 -0
  18. kumoai/experimental/rfm/{local_table.py → base/table.py} +139 -139
  19. kumoai/experimental/rfm/{local_graph.py → graph.py} +334 -79
  20. kumoai/experimental/rfm/infer/__init__.py +6 -0
  21. kumoai/experimental/rfm/infer/dtype.py +79 -0
  22. kumoai/experimental/rfm/infer/pkey.py +126 -0
  23. kumoai/experimental/rfm/infer/time_col.py +62 -0
  24. kumoai/experimental/rfm/pquery/pandas_executor.py +1 -1
  25. kumoai/experimental/rfm/rfm.py +204 -166
  26. kumoai/experimental/rfm/sagemaker.py +11 -3
  27. kumoai/kumolib.cp312-win_amd64.pyd +0 -0
  28. kumoai/pquery/predictive_query.py +10 -6
  29. kumoai/testing/decorators.py +1 -1
  30. {kumoai-2.13.0.dev202511271731.dist-info → kumoai-2.14.0.dev202512111731.dist-info}/METADATA +9 -8
  31. {kumoai-2.13.0.dev202511271731.dist-info → kumoai-2.14.0.dev202512111731.dist-info}/RECORD +34 -22
  32. kumoai/experimental/rfm/local_graph_sampler.py +0 -182
  33. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  34. kumoai/experimental/rfm/utils.py +0 -344
  35. {kumoai-2.13.0.dev202511271731.dist-info → kumoai-2.14.0.dev202512111731.dist-info}/WHEEL +0 -0
  36. {kumoai-2.13.0.dev202511271731.dist-info → kumoai-2.14.0.dev202512111731.dist-info}/licenses/LICENSE +0 -0
  37. {kumoai-2.13.0.dev202511271731.dist-info → kumoai-2.14.0.dev202512111731.dist-info}/top_level.txt +0 -0
@@ -370,9 +370,11 @@ class PredictiveQuery:
370
370
  train_table_job_api = global_state.client.generate_train_table_job_api
371
371
  job_id: GenerateTrainTableJobID = train_table_job_api.create(
372
372
  GenerateTrainTableRequest(
373
- dict(custom_tags), pq_id, plan,
374
- graph_snapshot_id=self.graph.snapshot(
375
- non_blocking=non_blocking)))
373
+ dict(custom_tags),
374
+ pq_id,
375
+ plan,
376
+ None,
377
+ ))
376
378
 
377
379
  self._train_table = TrainingTableJob(job_id=job_id)
378
380
  if non_blocking:
@@ -451,9 +453,11 @@ class PredictiveQuery:
451
453
  bp_table_api = global_state.client.generate_prediction_table_job_api
452
454
  job_id: GeneratePredictionTableJobID = bp_table_api.create(
453
455
  GeneratePredictionTableRequest(
454
- dict(custom_tags), pq_id, plan,
455
- graph_snapshot_id=self.graph.snapshot(
456
- non_blocking=non_blocking)))
456
+ dict(custom_tags),
457
+ pq_id,
458
+ plan,
459
+ None,
460
+ ))
457
461
 
458
462
  self._prediction_table = PredictionTableJob(job_id=job_id)
459
463
  if non_blocking:
@@ -25,7 +25,7 @@ def onlyFullTest(func: Callable) -> Callable:
25
25
  def has_package(package: str) -> bool:
26
26
  r"""Returns ``True`` in case ``package`` is installed."""
27
27
  req = Requirement(package)
28
- if importlib.util.find_spec(req.name) is None:
28
+ if importlib.util.find_spec(req.name) is None: # type: ignore
29
29
  return False
30
30
 
31
31
  try:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: kumoai
3
- Version: 2.13.0.dev202511271731
3
+ Version: 2.14.0.dev202512111731
4
4
  Summary: AI on the Modern Data Stack
5
5
  Author-email: "Kumo.AI" <hello@kumo.ai>
6
6
  License-Expression: MIT
@@ -23,13 +23,11 @@ Requires-Dist: requests>=2.28.2
23
23
  Requires-Dist: urllib3
24
24
  Requires-Dist: plotly
25
25
  Requires-Dist: typing_extensions>=4.5.0
26
- Requires-Dist: kumo-api==0.46.0
26
+ Requires-Dist: kumo-api==0.49.0
27
27
  Requires-Dist: tqdm>=4.66.0
28
28
  Requires-Dist: aiohttp>=3.10.0
29
29
  Requires-Dist: pydantic>=1.10.21
30
30
  Requires-Dist: rich>=9.0.0
31
- Requires-Dist: mypy-boto3-sagemaker-runtime
32
- Requires-Dist: boto3
33
31
  Provides-Extra: doc
34
32
  Requires-Dist: sphinx; extra == "doc"
35
33
  Requires-Dist: sphinx-book-theme; extra == "doc"
@@ -40,13 +38,16 @@ Provides-Extra: test
40
38
  Requires-Dist: pytest; extra == "test"
41
39
  Requires-Dist: pytest-mock; extra == "test"
42
40
  Requires-Dist: requests-mock; extra == "test"
43
- Provides-Extra: test-sagemaker
44
- Requires-Dist: sagemaker; extra == "test-sagemaker"
45
- Requires-Dist: pandas==2.1.4; extra == "test-sagemaker"
46
- Requires-Dist: pyarrow==12.0.1; extra == "test-sagemaker"
41
+ Provides-Extra: sqlite
42
+ Requires-Dist: adbc_driver_sqlite; extra == "sqlite"
43
+ Provides-Extra: snowflake
44
+ Requires-Dist: snowflake-connector-python; extra == "snowflake"
45
+ Requires-Dist: pyyaml; extra == "snowflake"
47
46
  Provides-Extra: sagemaker
48
47
  Requires-Dist: boto3<2.0,>=1.30.0; extra == "sagemaker"
49
48
  Requires-Dist: mypy-boto3-sagemaker-runtime<2.0,>=1.34.0; extra == "sagemaker"
49
+ Provides-Extra: test-sagemaker
50
+ Requires-Dist: sagemaker<3.0; extra == "test-sagemaker"
50
51
  Dynamic: license-file
51
52
  Dynamic: requires-dist
52
53
 
@@ -1,13 +1,13 @@
1
- kumoai/__init__.py,sha256=qu-qohU2cQlManX1aZIlzA3ivKl52m-cSQBPSW8urUU,10837
1
+ kumoai/__init__.py,sha256=aDhb7KGetDnOz54u1Fd45zfM2N8oAha6XT2CvJqOvgc,11146
2
2
  kumoai/_logging.py,sha256=qL4JbMQwKXri2f-SEJoFB8TY5ALG12S-nobGTNWxW-A,915
3
3
  kumoai/_singleton.py,sha256=i2BHWKpccNh5SJGDyU0IXsnYzJAYr8Xb0wz4c6LRbpo,861
4
- kumoai/_version.py,sha256=qbrPpAMKZjsA3NIbSYLDMXMkJni4kHPs7GmD50DZGyQ,39
4
+ kumoai/_version.py,sha256=1P4-kFh0Np-63qTmspyV3KEawGdIIMPl6ZARL87NM3s,39
5
5
  kumoai/databricks.py,sha256=ahwJz6DWLXMkndT0XwEDBxF-hoqhidFR8wBUQ4TLZ68,490
6
6
  kumoai/exceptions.py,sha256=7TMs0SC8xrU009_Pgd4QXtSF9lxJq8MtRbeX9pcQUy4,859
7
7
  kumoai/formatting.py,sha256=o3uCnLwXPhe1KI5WV9sBgRrcU7ed4rgu_pf89GL9Nc0,983
8
8
  kumoai/futures.py,sha256=J8rtZMEYFzdn5xF_x-LAiKJz3KGL6PT02f6rq_2bOJk,3836
9
9
  kumoai/jobs.py,sha256=dCi7BAdfm2tCnonYlGU4WJokJWbh3RzFfaOX2EYCIHU,2576
10
- kumoai/kumolib.cp312-win_amd64.pyd,sha256=cv_LNcIrboLDXPVIO2MiZcZSnpD9bqKqlzh9Q6UPeR4,198144
10
+ kumoai/kumolib.cp312-win_amd64.pyd,sha256=tgu5bJkhFo3Yc524MI6sVvovE0yUuQ-dEnGPbKaFBIE,198144
11
11
  kumoai/mixin.py,sha256=IaiB8SAI0VqOoMVzzIaUlqMt53-QPUK6OB0HikG-V9E,840
12
12
  kumoai/spcs.py,sha256=KWfENrwSLruprlD-QPh63uU0N6npiNrwkeKfBk3EUyQ,4260
13
13
  kumoai/artifact_export/__init__.py,sha256=UXAQI5q92ChBzWAk8o3J6pElzYHudAzFZssQXd4o7i8,247
@@ -50,37 +50,49 @@ kumoai/connector/glue_connector.py,sha256=kqT2q53Da7PeeaZrvLVzFXC186E7glh5eGitKL
50
50
  kumoai/connector/s3_connector.py,sha256=AUzENbQ20bYXh3XOXEOsWRKlaGGkm3YrW9JfBLm-LqY,10433
51
51
  kumoai/connector/snowflake_connector.py,sha256=tQzIWxC4oDGqxFt0212w5eoIPT4QBP2nuF9SdKRNwNI,9274
52
52
  kumoai/connector/source_table.py,sha256=fnqwIKY6qYo4G0EsRzchb6FgZ-dQyU6aRaD9UAxsml0,18010
53
- kumoai/connector/utils.py,sha256=SlkjPJS_wqfwFzIaQOHZtENQnbOz5sgLbvvvPDXE1ww,65786
53
+ kumoai/connector/utils.py,sha256=5K9BMdWiIP3hhdkUc6Xt1e0xv5YyziXtZ4PnBqq0Ehw,66490
54
54
  kumoai/encoder/__init__.py,sha256=8FeP6mUyCeXxr1b8kUIi5dxe5vEXQRft9tPoaV1CBqg,186
55
55
  kumoai/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
56
- kumoai/experimental/rfm/__init__.py,sha256=gpjpeN8PT3ZESi6kUaeyZqYnoJnysRVXDaY9hrycJA4,7020
56
+ kumoai/experimental/rfm/__init__.py,sha256=EFZz6IvvskmeO85Vig6p1m_6jdimS_BkeREOndHuRsc,6247
57
57
  kumoai/experimental/rfm/authenticate.py,sha256=G89_4TMeUpr5fG_0VTzMF5sdNhaciitA1oc2loTlTmo,19321
58
- kumoai/experimental/rfm/local_graph.py,sha256=nZ9hDfyWg1dHFLoTEKoLt0ZJPvf9MUA1MNyfTRzJThg,30886
59
- kumoai/experimental/rfm/local_graph_sampler.py,sha256=3JNpktW__nwxVKZxP4cQBgsIin7J_LNXYS7YlV36xbU,6854
60
- kumoai/experimental/rfm/local_graph_store.py,sha256=eUuIMFcdIRqN1kRxnqOdJpKEt-S_oyupAyHr7YuQoSU,14206
61
- kumoai/experimental/rfm/local_pquery_driver.py,sha256=Yd_yHIrvuDj16IC1pvsqiQvZS41vvOOCRMiuDGtN6Fk,26851
62
- kumoai/experimental/rfm/local_table.py,sha256=5H08657TIyH7n_QnpFKr2g4BtVqdXTymmrfhSGaDmkU,20150
63
- kumoai/experimental/rfm/rfm.py,sha256=MarISSPKuv6nIaGG69zFAwIagF6EA37xcSRClZrQMFc,49470
64
- kumoai/experimental/rfm/sagemaker.py,sha256=eebpZtASqiIGF2FpY53bbWLj6p-u5hkK4RLgBNAvEzg,4953
65
- kumoai/experimental/rfm/utils.py,sha256=dLx2wdyTWg7vZI_7R-I0z_lA-2aV5M8h9n3bnnLyylI,11467
66
- kumoai/experimental/rfm/infer/__init__.py,sha256=fPsdDr4D3hgC8snW0j3pAVpCyR-xrauuogMnTOMrfok,304
58
+ kumoai/experimental/rfm/graph.py,sha256=SL3-WinoLnkZC6VVjebYGLuQJJyEVFJdCm6h3FNE0e4,40816
59
+ kumoai/experimental/rfm/rfm.py,sha256=jauaV5MuuSPsWDI2VwzMxxnsXpQ6Jp3nhgPTmxUMeh8,50304
60
+ kumoai/experimental/rfm/sagemaker.py,sha256=sEJSyfEFBA3-7wKinBEzSooKHEn0BgPjrgRnPhYo79g,5120
61
+ kumoai/experimental/rfm/backend/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
62
+ kumoai/experimental/rfm/backend/local/__init__.py,sha256=8JbLaai0yhtldFcDkddphIJKMiKc0XnodvYBWkrGPXI,1056
63
+ kumoai/experimental/rfm/backend/local/graph_store.py,sha256=m1eHot4hRSl6OY1trPRDEH0U89hVoRqMfWhA9f-SyQU,12234
64
+ kumoai/experimental/rfm/backend/local/sampler.py,sha256=GRMaBtfrSTJtX_9_uwi4MXV0V8SHYAsJJ68d5qrvJJg,11051
65
+ kumoai/experimental/rfm/backend/local/table.py,sha256=1PqNOROzlnK3SaZHNcU2hyzeifs0N4wssQAS3-Z0Myc,3674
66
+ kumoai/experimental/rfm/backend/snow/__init__.py,sha256=viMeR9VWpB1kjRdSWCTNFMdM7a8Mj_Dtck1twJW8dV8,962
67
+ kumoai/experimental/rfm/backend/snow/table.py,sha256=Rf4hUPOUtsjpaIc9vBKWPZ3yz20OOg6DZqCGeih4KC8,4372
68
+ kumoai/experimental/rfm/backend/sqlite/__init__.py,sha256=xw5NNLrWSvUvRkD49X_9hZYjas5EuP1XDANPy0EEjOg,874
69
+ kumoai/experimental/rfm/backend/sqlite/table.py,sha256=mBiZC21gQwfR4demFrP37GmawMHfIm-G82mLQeBqIZo,3901
70
+ kumoai/experimental/rfm/base/__init__.py,sha256=FX1DonqdQLD-q0SBqhNIelr6HUN2LNiutviCmdikDok,282
71
+ kumoai/experimental/rfm/base/column.py,sha256=OE-PRQ8HO4uTq0e3_3eHJFfhp5nzw79zd-43g3iMh4g,2385
72
+ kumoai/experimental/rfm/base/sampler.py,sha256=277THFQGvQK_OkBBKHQ7166oGpnL_zVPFble7Ehlu3c,31805
73
+ kumoai/experimental/rfm/base/source.py,sha256=H5yN9xAwK3i_69EdqOV_x58muPGKQiI8ev5BhHQDZEo,290
74
+ kumoai/experimental/rfm/base/table.py,sha256=6GlWUz3tAu2g1QqTA5idWGmfo2KEJoNApDFZRn8e0pg,20388
75
+ kumoai/experimental/rfm/infer/__init__.py,sha256=qKg8or-SpgTApD6ePw1PJ4aUZPrOLTHLRCmBIJ92hrk,486
67
76
  kumoai/experimental/rfm/infer/categorical.py,sha256=bqmfrE5ZCBTcb35lA4SyAkCu3MgttAn29VBJYMBNhVg,893
77
+ kumoai/experimental/rfm/infer/dtype.py,sha256=Hf_drluYNuN59lTSe-8GuXalg20Pv93kCktB6Hb9f74,2686
68
78
  kumoai/experimental/rfm/infer/id.py,sha256=xaJBETLZa8ttzZCsDwFSwfyCi3VYsLc_kDWT_t_6Ih4,954
69
79
  kumoai/experimental/rfm/infer/multicategorical.py,sha256=D-1KwYRkOSkBrOJr4Xa3eTCoAF9O9hPGa7Vg67V5_HU,1150
80
+ kumoai/experimental/rfm/infer/pkey.py,sha256=Hvztcircd4iGdsnFU9Xi1kq_A5ONMnkAdnrpQT5svSs,4519
81
+ kumoai/experimental/rfm/infer/time_col.py,sha256=G98Cgz1m9G9VA-ApnCmGYnJxEFwp1jfaPf3nCMOz_N0,1882
70
82
  kumoai/experimental/rfm/infer/timestamp.py,sha256=L2VxjtYTSyUBYAo4M-L08xSQlPpqnHMAVF5_vxjh3Y0,1135
71
83
  kumoai/experimental/rfm/pquery/__init__.py,sha256=RkTn0I74uXOUuOiBpa6S-_QEYctMutkUnBEfF9ztQzI,159
72
84
  kumoai/experimental/rfm/pquery/executor.py,sha256=S8wwXbAkH-YSnmEVYB8d6wyJF4JJ003mH_0zFTvOp_I,2843
73
- kumoai/experimental/rfm/pquery/pandas_executor.py,sha256=QQpOZ_ArH3eSAkenaY3J-gW1Wn5A7f85RiqZxaO5u1Q,19019
85
+ kumoai/experimental/rfm/pquery/pandas_executor.py,sha256=h5F2hrnyqAPqaqH4RwTdgkeedfCcDTpph9sF0lwIrx4,19024
74
86
  kumoai/graph/__init__.py,sha256=QGk3OMwRzQJSGESdcc7hcQH6UDmNVJYTdqnRren4c7Q,240
75
87
  kumoai/graph/column.py,sha256=cQhioibTbIKIBZ-bf8-Bt4F4Iblhidps-CYWrkxRPnE,4295
76
88
  kumoai/graph/graph.py,sha256=Pq-dxi4MwoDtrrwm3xeyUB9Hl7ryNfHq4rMHuvyNB3c,39239
77
89
  kumoai/graph/table.py,sha256=BB-4ezyd7hrrj6QZwRBa80ySH0trwYb4fmhRn3xoK-k,34726
78
90
  kumoai/pquery/__init__.py,sha256=FF6QUTG_xrz2ic1I8NcIa8O993Ae98eZ9gkvQ4rapgo,558
79
91
  kumoai/pquery/prediction_table.py,sha256=hWG4L_ze4PLgUoxCXNKk8_nkYxVXELQs8_X8KGOE9yk,11063
80
- kumoai/pquery/predictive_query.py,sha256=GWhQpQxf6apyyu-bvE3z63mX6NLd8lKbyu_jzj7rNms,25608
92
+ kumoai/pquery/predictive_query.py,sha256=I5Ntc7YO1qEGxKrLuhAzZO3SySr8Wnjhde8eDbbB7zk,25542
81
93
  kumoai/pquery/training_table.py,sha256=L1QjaVlY4SAPD8OUmTaH6YjZzBbPOnS9mnAT69znWv0,16233
82
94
  kumoai/testing/__init__.py,sha256=XBQ_Sa3WnOYlpXZ3gUn8w6nVfZt-nfPhytfIBeiPt4w,178
83
- kumoai/testing/decorators.py,sha256=yznguzsdkL0UaZtBbnO6oaUrXisJvziaiO3dmN41UXE,1648
95
+ kumoai/testing/decorators.py,sha256=p79ZCQqPY_MHWy0_l7-xQ6wUIqFTn4AbrGWTHLvpbQY,1664
84
96
  kumoai/trainer/__init__.py,sha256=uCFXy9bw_byn_wYd3M-BTZCHTVvv4XXr8qRlh-QOvag,981
85
97
  kumoai/trainer/baseline_trainer.py,sha256=oXweh8j1sar6KhQfr3A7gmQxcDq7SG0Bx3jIenbtyC4,4117
86
98
  kumoai/trainer/config.py,sha256=7_Jv1w1mqaokCQwQdJkqCSgVpmh8GqE3fL1Ky_vvttI,100
@@ -92,8 +104,8 @@ kumoai/utils/__init__.py,sha256=wAKgmwtMIGuiauW9D_GGKH95K-24Kgwmld27mm4nsro,278
92
104
  kumoai/utils/datasets.py,sha256=UyAII-oAn7x3ombuvpbSQ41aVF9SYKBjQthTD-vcT2A,3011
93
105
  kumoai/utils/forecasting.py,sha256=ZgKeUCbWLOot0giAkoigwU5du8LkrwAicFOi5hVn6wg,7624
94
106
  kumoai/utils/progress_logger.py,sha256=MZsWgHd4UZQKCXiJZgQeW-Emi_BmzlCKPLPXOL_HqBo,5239
95
- kumoai-2.13.0.dev202511271731.dist-info/licenses/LICENSE,sha256=ZUilBDp--4vbhsEr6f_Upw9rnIx09zQ3K9fXQ0rfd6w,1111
96
- kumoai-2.13.0.dev202511271731.dist-info/METADATA,sha256=J6GLHeIxdSBnxASmR6VqdgENJuaKxk7qfc7Q1NeXf5E,2544
97
- kumoai-2.13.0.dev202511271731.dist-info/WHEEL,sha256=8UP9x9puWI0P1V_d7K2oMTBqfeLNm21CTzZ_Ptr0NXU,101
98
- kumoai-2.13.0.dev202511271731.dist-info/top_level.txt,sha256=YjU6UcmomoDx30vEXLsOU784ED7VztQOsFApk1SFwvs,7
99
- kumoai-2.13.0.dev202511271731.dist-info/RECORD,,
107
+ kumoai-2.14.0.dev202512111731.dist-info/licenses/LICENSE,sha256=ZUilBDp--4vbhsEr6f_Upw9rnIx09zQ3K9fXQ0rfd6w,1111
108
+ kumoai-2.14.0.dev202512111731.dist-info/METADATA,sha256=BTYN5IeO5MXS1lF5aA-O_p7x32QpVK0kuMMp6Wy2S-A,2580
109
+ kumoai-2.14.0.dev202512111731.dist-info/WHEEL,sha256=8UP9x9puWI0P1V_d7K2oMTBqfeLNm21CTzZ_Ptr0NXU,101
110
+ kumoai-2.14.0.dev202512111731.dist-info/top_level.txt,sha256=YjU6UcmomoDx30vEXLsOU784ED7VztQOsFApk1SFwvs,7
111
+ kumoai-2.14.0.dev202512111731.dist-info/RECORD,,
@@ -1,182 +0,0 @@
1
- from typing import Dict, List, Optional, Tuple
2
-
3
- import numpy as np
4
- import pandas as pd
5
- from kumoapi.rfm.context import EdgeLayout, Link, Subgraph, Table
6
- from kumoapi.typing import Stype
7
-
8
- import kumoai.kumolib as kumolib
9
- from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
10
- from kumoai.experimental.rfm.utils import normalize_text
11
-
12
-
13
- class LocalGraphSampler:
14
- def __init__(self, graph_store: LocalGraphStore) -> None:
15
- self._graph_store = graph_store
16
- self._sampler = kumolib.NeighborSampler(
17
- self._graph_store.node_types,
18
- self._graph_store.edge_types,
19
- {
20
- '__'.join(edge_type): colptr
21
- for edge_type, colptr in self._graph_store.colptr_dict.items()
22
- },
23
- {
24
- '__'.join(edge_type): row
25
- for edge_type, row in self._graph_store.row_dict.items()
26
- },
27
- self._graph_store.time_dict,
28
- )
29
-
30
- def __call__(
31
- self,
32
- entity_table_names: Tuple[str, ...],
33
- node: np.ndarray,
34
- time: np.ndarray,
35
- num_neighbors: List[int],
36
- exclude_cols_dict: Dict[str, List[str]],
37
- ) -> Subgraph:
38
-
39
- (
40
- row_dict,
41
- col_dict,
42
- node_dict,
43
- batch_dict,
44
- num_sampled_nodes_dict,
45
- num_sampled_edges_dict,
46
- ) = self._sampler.sample(
47
- {
48
- '__'.join(edge_type): num_neighbors
49
- for edge_type in self._graph_store.edge_types
50
- },
51
- {}, # time interval based sampling
52
- entity_table_names[0],
53
- node,
54
- time // 1000**3, # nanoseconds to seconds
55
- )
56
-
57
- table_dict: Dict[str, Table] = {}
58
- for table_name, node in node_dict.items():
59
- batch = batch_dict[table_name]
60
-
61
- if len(node) == 0:
62
- continue
63
-
64
- df = self._graph_store.df_dict[table_name]
65
-
66
- num_sampled_nodes = num_sampled_nodes_dict[table_name].tolist()
67
- stype_dict = { # Exclude target columns:
68
- column_name: stype
69
- for column_name, stype in
70
- self._graph_store.stype_dict[table_name].items()
71
- if column_name not in exclude_cols_dict.get(table_name, [])
72
- }
73
- primary_key: Optional[str] = None
74
- if table_name in entity_table_names:
75
- primary_key = self._graph_store.pkey_name_dict.get(table_name)
76
-
77
- columns: List[str] = []
78
- if table_name in entity_table_names:
79
- columns += [self._graph_store.pkey_name_dict[table_name]]
80
- columns += list(stype_dict.keys())
81
-
82
- if len(columns) == 0:
83
- table_dict[table_name] = Table(
84
- df=pd.DataFrame(index=range(len(node))),
85
- row=None,
86
- batch=batch,
87
- num_sampled_nodes=num_sampled_nodes,
88
- stype_dict=stype_dict,
89
- primary_key=primary_key,
90
- )
91
- continue
92
-
93
- row: Optional[np.ndarray] = None
94
- if table_name in self._graph_store.end_time_column_dict:
95
- # Set end time to NaT for all values greater than anchor time:
96
- df = df.iloc[node].reset_index(drop=True)
97
- col_name = self._graph_store.end_time_column_dict[table_name]
98
- ser = df[col_name]
99
- value = ser.astype('datetime64[ns]').astype(int).to_numpy()
100
- mask = value > time[batch]
101
- df.loc[mask, col_name] = pd.NaT
102
- else:
103
- # Only store unique rows in `df` above a certain threshold:
104
- unique_node, inverse = np.unique(node, return_inverse=True)
105
- if len(node) > 1.05 * len(unique_node):
106
- df = df.iloc[unique_node].reset_index(drop=True)
107
- row = inverse
108
- else:
109
- df = df.iloc[node].reset_index(drop=True)
110
-
111
- # Filter data frame to minimal set of columns:
112
- df = df[columns]
113
-
114
- # Normalize text (if not already pre-processed):
115
- for column_name, stype in stype_dict.items():
116
- if stype == Stype.text:
117
- df[column_name] = normalize_text(df[column_name])
118
-
119
- table_dict[table_name] = Table(
120
- df=df,
121
- row=row,
122
- batch=batch,
123
- num_sampled_nodes=num_sampled_nodes,
124
- stype_dict=stype_dict,
125
- primary_key=primary_key,
126
- )
127
-
128
- link_dict: Dict[Tuple[str, str, str], Link] = {}
129
- for edge_type in self._graph_store.edge_types:
130
- edge_type_str = '__'.join(edge_type)
131
-
132
- row = row_dict[edge_type_str]
133
- col = col_dict[edge_type_str]
134
-
135
- if len(row) == 0:
136
- continue
137
-
138
- # Do not store reverse edge type if it is a replica:
139
- rev_edge_type = Subgraph.rev_edge_type(edge_type)
140
- rev_edge_type_str = '__'.join(rev_edge_type)
141
- if (rev_edge_type in link_dict
142
- and np.array_equal(row, col_dict[rev_edge_type_str])
143
- and np.array_equal(col, row_dict[rev_edge_type_str])):
144
- link = Link(
145
- layout=EdgeLayout.REV,
146
- row=None,
147
- col=None,
148
- num_sampled_edges=(
149
- num_sampled_edges_dict[edge_type_str].tolist()),
150
- )
151
- link_dict[edge_type] = link
152
- continue
153
-
154
- layout = EdgeLayout.COO
155
- if np.array_equal(row, np.arange(len(row))):
156
- row = None
157
- if np.array_equal(col, np.arange(len(col))):
158
- col = None
159
-
160
- # Store in compressed representation if more efficient:
161
- num_cols = table_dict[edge_type[2]].num_rows
162
- if col is not None and len(col) > num_cols + 1:
163
- layout = EdgeLayout.CSC
164
- colcount = np.bincount(col, minlength=num_cols)
165
- col = np.empty(num_cols + 1, dtype=col.dtype)
166
- col[0] = 0
167
- np.cumsum(colcount, out=col[1:])
168
-
169
- link = Link(
170
- layout=layout,
171
- row=row,
172
- col=col,
173
- num_sampled_edges=(
174
- num_sampled_edges_dict[edge_type_str].tolist()),
175
- )
176
- link_dict[edge_type] = link
177
-
178
- return Subgraph(
179
- anchor_time=time,
180
- table_dict=table_dict,
181
- link_dict=link_dict,
182
- )