kumoai 2.13.0.dev202512081731__cp313-cp313-macosx_11_0_arm64.whl → 2.14.0.dev202512151351__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. kumoai/_version.py +1 -1
  2. kumoai/client/pquery.py +6 -2
  3. kumoai/experimental/rfm/backend/local/graph_store.py +19 -62
  4. kumoai/experimental/rfm/backend/local/sampler.py +213 -14
  5. kumoai/experimental/rfm/backend/local/table.py +12 -2
  6. kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
  7. kumoai/experimental/rfm/backend/snow/sampler.py +264 -0
  8. kumoai/experimental/rfm/backend/snow/table.py +35 -17
  9. kumoai/experimental/rfm/backend/sqlite/__init__.py +2 -0
  10. kumoai/experimental/rfm/backend/sqlite/sampler.py +354 -0
  11. kumoai/experimental/rfm/backend/sqlite/table.py +36 -11
  12. kumoai/experimental/rfm/base/__init__.py +17 -6
  13. kumoai/experimental/rfm/base/sampler.py +438 -38
  14. kumoai/experimental/rfm/base/source.py +1 -0
  15. kumoai/experimental/rfm/base/sql_sampler.py +56 -0
  16. kumoai/experimental/rfm/base/table.py +12 -1
  17. kumoai/experimental/rfm/graph.py +26 -9
  18. kumoai/experimental/rfm/pquery/pandas_executor.py +1 -1
  19. kumoai/experimental/rfm/rfm.py +214 -151
  20. kumoai/pquery/predictive_query.py +10 -6
  21. kumoai/testing/snow.py +50 -0
  22. kumoai/utils/__init__.py +2 -0
  23. kumoai/utils/sql.py +3 -0
  24. {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/METADATA +2 -2
  25. {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/RECORD +28 -25
  26. kumoai/experimental/rfm/local_graph_sampler.py +0 -223
  27. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  28. {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/WHEEL +0 -0
  29. {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/licenses/LICENSE +0 -0
  30. {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/top_level.txt +0 -0
kumoai/testing/snow.py ADDED
@@ -0,0 +1,50 @@
1
+ import json
2
+ import os
3
+
4
+ from kumoai.experimental.rfm.backend.snow import Connection
5
+ from kumoai.experimental.rfm.backend.snow import connect as _connect
6
+
7
+
8
def connect(
    region: str,
    id: str,
    account: str,
    user: str,
    warehouse: str,
    database: str | None = None,
    schema: str | None = None,
) -> Connection:
    r"""Opens a Snowflake connection for use in tests.

    If the ``SNOWFLAKE_PASSWORD`` environment variable is set, it is used as
    the password. Otherwise, the user's PEM-encoded private key is read from
    AWS Secrets Manager and passed on in DER/PKCS#8 form. The secret is
    ``<account>.snowflakecomputing.com`` in AWS account ``id`` and
    ``region``, and the key is read from its ``kumo_user_secretkey`` JSON
    field.

    Args:
        region: The AWS region of the Secrets Manager secret.
        id: The AWS account ID that owns the secret.
        account: The Snowflake account identifier.
        user: The Snowflake user name.
        warehouse: The Snowflake warehouse to run queries on.
        database: The default database. Falls back to ``'KUMO'`` when not
            given.
        schema: The default schema.

    Returns:
        The opened :class:`Connection`.
    """
    kwargs = dict(password=os.getenv('SNOWFLAKE_PASSWORD'))
    if kwargs['password'] is None:
        # Only needed for key-pair authentication, hence imported lazily:
        import boto3
        from cryptography.hazmat.primitives import serialization

        client = boto3.client(
            service_name='secretsmanager',
            region_name=region,
        )
        secret_id = (f'arn:aws:secretsmanager:{region}:{id}:secret:'
                     f'{account}.snowflakecomputing.com')
        response = client.get_secret_value(SecretId=secret_id)['SecretString']
        secret = json.loads(response)

        private_key = serialization.load_pem_private_key(
            secret['kumo_user_secretkey'].encode(),
            password=None,
        )
        kwargs['private_key'] = private_key.private_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )

    return _connect(
        account=account,
        user=user,
        # Honor the caller's warehouse/database. 'KUMO' stays the default
        # database for callers that do not pass one:
        warehouse=warehouse,
        database='KUMO' if database is None else database,
        schema=schema,
        session_parameters=dict(CLIENT_TELEMETRY_ENABLED=False),
        **kwargs,
    )
kumoai/utils/__init__.py CHANGED
@@ -1,8 +1,10 @@
1
+ from .sql import quote_ident
1
2
  from .progress_logger import ProgressLogger, InteractiveProgressLogger
2
3
  from .forecasting import ForecastVisualizer
3
4
  from .datasets import from_relbench
4
5
 
5
6
  __all__ = [
7
+ 'quote_ident',
6
8
  'ProgressLogger',
7
9
  'InteractiveProgressLogger',
8
10
  'ForecastVisualizer',
kumoai/utils/sql.py ADDED
@@ -0,0 +1,3 @@
1
def quote_ident(name: str) -> str:
    r"""Quotes a SQL identifier.

    Wraps ``name`` in double quotes and doubles every embedded double quote,
    following the standard SQL convention for delimited identifiers.
    """
    escaped = name.replace('"', '""')
    return f'"{escaped}"'
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: kumoai
3
- Version: 2.13.0.dev202512081731
3
+ Version: 2.14.0.dev202512151351
4
4
  Summary: AI on the Modern Data Stack
5
5
  Author-email: "Kumo.AI" <hello@kumo.ai>
6
6
  License-Expression: MIT
@@ -23,7 +23,7 @@ Requires-Dist: requests>=2.28.2
23
23
  Requires-Dist: urllib3
24
24
  Requires-Dist: plotly
25
25
  Requires-Dist: typing_extensions>=4.5.0
26
- Requires-Dist: kumo-api==0.48.0
26
+ Requires-Dist: kumo-api==0.49.0
27
27
  Requires-Dist: tqdm>=4.66.0
28
28
  Requires-Dist: aiohttp>=3.10.0
29
29
  Requires-Dist: pydantic>=1.10.21
@@ -1,7 +1,7 @@
1
1
  kumoai/kumolib.cpython-313-darwin.so,sha256=waBv-DiZ3WcasxiCQ-OM9EbSTgTtCfBTZIibXAK-JiQ,232816
2
2
  kumoai/_logging.py,sha256=U2_5ROdyk92P4xO4H2WJV8EC7dr6YxmmnM-b7QX9M7I,886
3
3
  kumoai/mixin.py,sha256=MP413xzuCqWhxAPUHmloLA3j4ZyF1tEtfi516b_hOXQ,812
4
- kumoai/_version.py,sha256=W0EIBX5oPkQ0eXYnfNBgKMhonz56bp9ySi_IPtjQoCA,39
4
+ kumoai/_version.py,sha256=bOoqsL1s-b_8ovgo8rxykwFD8SfZQM7pj8skKFtYz5U,39
5
5
  kumoai/__init__.py,sha256=Nn9YH_x9kAeEFn8RWbP95slZow0qFnakPZZ1WADe1hY,10843
6
6
  kumoai/formatting.py,sha256=jA_rLDCGKZI8WWCha-vtuLenVKTZvli99Tqpurz1H84,953
7
7
  kumoai/futures.py,sha256=oJFIfdCM_3nWIqQteBKYMY4fPhoYlYWE_JA2o6tx-ng,3737
@@ -11,24 +11,24 @@ kumoai/databricks.py,sha256=e6E4lOFvZHXFwh4CO1kXU1zzDU3AapLQYMxjiHPC-HQ,476
11
11
  kumoai/spcs.py,sha256=N31d7rLa-bgYh8e2J4YzX1ScxGLqiVXrqJnCl1y4Mts,4139
12
12
  kumoai/_singleton.py,sha256=UTwrbDkoZSGB8ZelorvprPDDv9uZkUi1q_SrmsyngpQ,836
13
13
  kumoai/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
- kumoai/experimental/rfm/local_graph_sampler.py,sha256=32ZCNirPyCqCD8IccaXmRt0EJk1p54mWXpJ33NotAqE,7883
15
- kumoai/experimental/rfm/local_pquery_driver.py,sha256=dhOS1L9aboya86EL4AFYc8bQkimbOchSLfe_jn2qGh4,26158
16
- kumoai/experimental/rfm/graph.py,sha256=76hlQyaEYqBYNIF3jslIqRRuAPNtXvc1kR6InwyHH-M,39751
14
+ kumoai/experimental/rfm/graph.py,sha256=awVJSk4cWRMacS5CJvJtR8TR56FEbrJPcQCukNydQOc,40392
17
15
  kumoai/experimental/rfm/__init__.py,sha256=slliYcrh80xPtQQ_nnsp3ny9IbmHCyirmdZUfKTdME4,6064
18
16
  kumoai/experimental/rfm/sagemaker.py,sha256=_hTrFg4qfXe7uzwqSEG_wze-IFkwn7qde9OpUodCpbc,4982
19
- kumoai/experimental/rfm/rfm.py,sha256=BSgxeM0xW2mt74jq4Ah4hl85RxT6337NoDQP7f7iXvY,47699
17
+ kumoai/experimental/rfm/rfm.py,sha256=YyUzoyu7STVnmGnKWTgAPVB4GlMng_n1PDXv45o9oJM,49976
20
18
  kumoai/experimental/rfm/authenticate.py,sha256=FiuHMvP7V3zBZUlHMDMbNLhc-UgDZgz4hjVSTuQ7DRw,18888
21
19
  kumoai/experimental/rfm/backend/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
- kumoai/experimental/rfm/backend/sqlite/__init__.py,sha256=jYmZDNAVsojuPO1Q5idFmG5N0aCB8BDyrpAoS31n9bc,844
23
- kumoai/experimental/rfm/backend/sqlite/table.py,sha256=kcYpWaZKFez2Tru6Sdz-Ywk8jP8VpLnjmCIQQtRFGnU,3800
20
+ kumoai/experimental/rfm/backend/sqlite/__init__.py,sha256=cA-PZL1oTaLxthZbfLSudexImtF6jRsGkdjSp-66dCM,914
21
+ kumoai/experimental/rfm/backend/sqlite/table.py,sha256=fSn-CsOKK584Qgn8a-k8eymykoOrIX5w0CCzOBNW0Zk,4677
22
+ kumoai/experimental/rfm/backend/sqlite/sampler.py,sha256=SsQCJB04DyNoz7Vyy6oF4sfaqCZt5aHpE1Kxf1qEGco,14467
24
23
  kumoai/experimental/rfm/backend/local/__init__.py,sha256=2s9sSA-E-8pfkkzCH4XPuaSxSznEURMfMgwEIfYYPsg,1014
25
- kumoai/experimental/rfm/backend/local/table.py,sha256=Ahob9HidpU6z_M41rK5FATa3d7CL2UzZl8pGVyrzLNc,3565
26
- kumoai/experimental/rfm/backend/local/graph_store.py,sha256=RpfJldemOG-4RzGSIS9EcytHbvC4gYm-Ps3a-4qfptk,13297
27
- kumoai/experimental/rfm/backend/local/sampler.py,sha256=L1S2qxvkS_O8wy4K-czTxojPmklRrReTR8P3-e_8-hM,3823
28
- kumoai/experimental/rfm/backend/snow/__init__.py,sha256=B-tG-p8WA-mBuwvK1f0S2gdRPEGwApdxlnyeVSnY2xg,927
29
- kumoai/experimental/rfm/backend/snow/table.py,sha256=sHagXhW7RifzOiB4yjxV_9FtR0KUFVIw1mYwZe4bpMg,4255
24
+ kumoai/experimental/rfm/backend/local/table.py,sha256=-R_9nncosByAfSMfUt6HgCUNoW_MLGJW3F5SnAd4Ru0,3744
25
+ kumoai/experimental/rfm/backend/local/graph_store.py,sha256=5cHuExHljU_Z56KV3s-PwzeiLuPKgh2mCxcjTMmPZ8E,11928
26
+ kumoai/experimental/rfm/backend/local/sampler.py,sha256=85HoHCDiFOiuD_vFPZRx9JCyQUlLsqgsuB3NAw50wNw,10836
27
+ kumoai/experimental/rfm/backend/snow/__init__.py,sha256=BYfsiuJ4Ee30GjG9EuUtitMHXnRfvVKi85zNlIwldV4,993
28
+ kumoai/experimental/rfm/backend/snow/table.py,sha256=--kC_jh4kbXDulvWWwqERjZOKUQrAP506ksKtzQY7qk,4841
29
+ kumoai/experimental/rfm/backend/snow/sampler.py,sha256=vp_Z2Ov6IUaT9x5_tuPGFaz3XX_jQPLQZhkKX9AS5cI,10341
30
30
  kumoai/experimental/rfm/pquery/__init__.py,sha256=X0O3EIq5SMfBEE-ii5Cq6iDhR3s3XMXB52Cx5htoePw,152
31
- kumoai/experimental/rfm/pquery/pandas_executor.py,sha256=kiBJq7uVGbasG7TiqsubEl6ey3UYzZiM4bwxILqp_54,18487
31
+ kumoai/experimental/rfm/pquery/pandas_executor.py,sha256=wYI9a3smClR2pQGwsYRdmpOm0PlUsbtyW9wpAVpCEe4,18492
32
32
  kumoai/experimental/rfm/pquery/executor.py,sha256=f7-pJhL0BgFU9E4o4gQpQyArOvyrZtwxFmks34-QOAE,2741
33
33
  kumoai/experimental/rfm/infer/multicategorical.py,sha256=0-cLpDnGryhr76QhZNO-klKokJ6MUSfxXcGdQ61oykY,1102
34
34
  kumoai/experimental/rfm/infer/categorical.py,sha256=VwNaKwKbRYkTxEJ1R6gziffC8dGsEThcDEfbi-KqW5c,853
@@ -38,10 +38,11 @@ kumoai/experimental/rfm/infer/id.py,sha256=ZIO0DWIoiEoS_8MVc5lkqBfkTWWQ0yGCgjkwL
38
38
  kumoai/experimental/rfm/infer/dtype.py,sha256=ZZ6ztqJnTR1CaC2z5Uhf0o0rSdNThnss5tem5JNQkck,2607
39
39
  kumoai/experimental/rfm/infer/__init__.py,sha256=krdMFN8iKZlSFOl-M5MW1KuSviQV3H1E18jj2uB8g6Q,469
40
40
  kumoai/experimental/rfm/infer/timestamp.py,sha256=vM9--7eStzaGG13Y-oLYlpNJyhL6f9dp17HDXwtl_DM,1094
41
- kumoai/experimental/rfm/base/__init__.py,sha256=3haYsIYypeL-U-9RuOOPnRdWaRlh-g_yE4ACJ2KLjOY,335
42
- kumoai/experimental/rfm/base/table.py,sha256=yaY7Auvq2KblXOid3-a_Pw6RgnPK5Y1zGAY2xi1D2gg,19843
43
- kumoai/experimental/rfm/base/sampler.py,sha256=b45kllqSm-lpXbP9XbrGQPMx_hEIfesJILViAanh6rk,13456
44
- kumoai/experimental/rfm/base/source.py,sha256=8_waFQVsctryHkm9BwmFZ9-vw5cXAXfjk7KDmcl_kic,272
41
+ kumoai/experimental/rfm/base/sql_sampler.py,sha256=ibLn1pT2zLhs1VpK4PUf9E89aUO5q9iT1S2jmGYkKP4,1644
42
+ kumoai/experimental/rfm/base/__init__.py,sha256=8nCg154X94HTLVOATcO54tX3axFm8QlZG9T1M3ZasnI,549
43
+ kumoai/experimental/rfm/base/table.py,sha256=neGldEZaweoJ8VRgnEnaSpAISSkSTkgXxItuuywBM4E,20010
44
+ kumoai/experimental/rfm/base/sampler.py,sha256=aCD98t0CUhAvGXEFv24Vq2g4otuclpKkkyL1rMR_mFg,31449
45
+ kumoai/experimental/rfm/base/source.py,sha256=RqlI_kBoRV0ADb8KdEKn15RNHMdFUzEVzb57lIoyBM4,294
45
46
  kumoai/experimental/rfm/base/column.py,sha256=izCJmufJcd1RSi-ptFMfrue-JYag38MJxizka7ya0-A,2319
46
47
  kumoai/encoder/__init__.py,sha256=VPGs4miBC_WfwWeOXeHhFomOUocERFavhKf5fqITcds,182
47
48
  kumoai/graph/graph.py,sha256=iyp4klPIMn2ttuEqMJvsrxKb_tmz_DTnvziIhCegduM,38291
@@ -52,8 +53,9 @@ kumoai/artifact_export/config.py,sha256=jOPDduduxv0uuB-7xVlDiZglfpmFF5lzQhhH1SMk
52
53
  kumoai/artifact_export/job.py,sha256=GEisSwvcjK_35RgOfsLXGgxMTXIWm765B_BW_Kgs-V0,3275
53
54
  kumoai/artifact_export/__init__.py,sha256=BsfDrc3mCHpO9-BqvqKm8qrXDIwfdaoH5UIoG4eQkc4,238
54
55
  kumoai/utils/datasets.py,sha256=ptKIUoBONVD55pTVNdRCkQT3NWdN_r9UAUu4xewPa3U,2928
55
- kumoai/utils/__init__.py,sha256=wGDC_31XJ-7ipm6eawjLAJaP4EfmtNOH8BHzaetQ9Ko,268
56
+ kumoai/utils/__init__.py,sha256=cF5ACzp1X61sqhlCHc6biQk6fc4gW_oyhGsBrjx-SoM,316
56
57
  kumoai/utils/progress_logger.py,sha256=pngEGzMHkiOUKOa6fbzxCEc2xlA4SJKV4TDTVVoqObM,5062
58
+ kumoai/utils/sql.py,sha256=f6lR6rBEW7Dtk0NdM26dOZXUHDizEHb1WPlBCJrwoq0,118
57
59
  kumoai/utils/forecasting.py,sha256=-nDS6ucKNfQhTQOfebjefj0wwWH3-KYNslIomxwwMBM,7415
58
60
  kumoai/codegen/generate.py,sha256=SvfWWa71xSAOjH9645yQvgoEM-o4BYjupM_EpUxqB_E,7331
59
61
  kumoai/codegen/naming.py,sha256=_XVQGxHfuub4bhvyuBKjltD5Lm_oPpibvP_LZteCGk0,3021
@@ -71,6 +73,7 @@ kumoai/codegen/handlers/__init__.py,sha256=k8TB_Kn-1BycBBi51kqFS2fZHCpCPgR9-3J9g
71
73
  kumoai/codegen/handlers/utils.py,sha256=58b2GCgaTBUp2aId7BLMXMV0ENrusbNbfw7mlyXAXPE,1447
72
74
  kumoai/codegen/handlers/connector.py,sha256=afGf_GreyQ9y6qF3QTgSiM416qtUcP298SatNqUFhvQ,3828
73
75
  kumoai/codegen/handlers/table.py,sha256=POHpA-GFYFGTSuerGmtigYablk-Wq1L3EBvsOI-iFMQ,3956
76
+ kumoai/testing/snow.py,sha256=ubx3yJP0UHxsNiar1-jNdv8ZfszKc8Js3_Gg70uf008,1487
74
77
  kumoai/testing/__init__.py,sha256=goHIIo3JE7uHV7njo4_aTd89mVVR74BEAZ2uyBaOR0w,170
75
78
  kumoai/testing/decorators.py,sha256=83tMifuPTpUqX7zHxMttkj1TDdB62EBtAP-Fjj72Zdo,1607
76
79
  kumoai/connector/glue_connector.py,sha256=HivT0QYQ8-XeB4QLgWvghiqXuq7jyBK9G2R1py_NnE4,4697
@@ -84,10 +87,10 @@ kumoai/connector/utils.py,sha256=wlqQxMmPvnFNoCcczGkKYjSu05h8OhWh4fhTzQm_2bQ,646
84
87
  kumoai/connector/s3_connector.py,sha256=3kbv-h7DwD8O260Q0h1GPm5wwQpLt-Tb3d_CBSaie44,10155
85
88
  kumoai/connector/base.py,sha256=cujXSZF3zAfuxNuEw54DSL1T7XCuR4t0shSMDuPUagQ,5291
86
89
  kumoai/pquery/__init__.py,sha256=uTXr7t1eXcVfM-ETaM_1ImfEqhrmaj8BjiIvy1YZTL8,533
87
- kumoai/pquery/predictive_query.py,sha256=oUqwdOWLLkPM-G4PhpUk_6mwSJGBtaD3t37Wp5Oow8M,24971
90
+ kumoai/pquery/predictive_query.py,sha256=UXn1s8ztubYZMNGl4ijaeidMiGlFveb1TGw9qI5-TAo,24901
88
91
  kumoai/pquery/prediction_table.py,sha256=QPDH22X1UB0NIufY7qGuV2XW7brG3Pv--FbjNezzM2g,10776
89
92
  kumoai/pquery/training_table.py,sha256=elmPDZx11kPiC_dkOhJcBUGtHKgL32GCBvZ9k6U0pMg,15809
90
- kumoai/client/pquery.py,sha256=R2hc-M8vPoyIDH0ywLwFVxCznVAqpZz3w2HszjdNW-o,6891
93
+ kumoai/client/pquery.py,sha256=IQ8As-OOJOkuMoMosphOsA5hxQYLCbzOQJO7RezK8uY,7091
91
94
  kumoai/client/client.py,sha256=Jda8V9yiu3LbhxlcgRWPeYi7eF6jzCKcq8-B_vEd1ik,8514
92
95
  kumoai/client/graph.py,sha256=zvLEDExLT_RVbUMHqVl0m6tO6s2gXmYSoWmPF6YMlnA,3831
93
96
  kumoai/client/online.py,sha256=pkBBh_DEC3GAnPcNw6bopNRlGe7EUbIFe7_seQqZRaw,2720
@@ -106,8 +109,8 @@ kumoai/trainer/baseline_trainer.py,sha256=LlfViNOmswNv4c6zJJLsyv0pC2mM2WKMGYx06o
106
109
  kumoai/trainer/__init__.py,sha256=zUdFl-f-sBWmm2x8R-rdVzPBeU2FaMzUY5mkcgoTa1k,939
107
110
  kumoai/trainer/online_serving.py,sha256=9cddb5paeZaCgbUeceQdAOxysCtV5XP-KcsgFz_XR5w,9566
108
111
  kumoai/trainer/trainer.py,sha256=hBXO7gwpo3t59zKFTeIkK65B8QRmWCwO33sbDuEAPlY,20133
109
- kumoai-2.13.0.dev202512081731.dist-info/RECORD,,
110
- kumoai-2.13.0.dev202512081731.dist-info/WHEEL,sha256=oqGJCpG61FZJmvyZ3C_0aCv-2mdfcY9e3fXvyUNmWfM,136
111
- kumoai-2.13.0.dev202512081731.dist-info/top_level.txt,sha256=YjU6UcmomoDx30vEXLsOU784ED7VztQOsFApk1SFwvs,7
112
- kumoai-2.13.0.dev202512081731.dist-info/METADATA,sha256=ulcPeS_yowF-CWxGh5m_20ummlecVuiXBDNMgvXH-VU,2510
113
- kumoai-2.13.0.dev202512081731.dist-info/licenses/LICENSE,sha256=TbWlyqRmhq9PEzCaTI0H0nWLQCCOywQM8wYH8MbjfLo,1102
112
+ kumoai-2.14.0.dev202512151351.dist-info/RECORD,,
113
+ kumoai-2.14.0.dev202512151351.dist-info/WHEEL,sha256=oqGJCpG61FZJmvyZ3C_0aCv-2mdfcY9e3fXvyUNmWfM,136
114
+ kumoai-2.14.0.dev202512151351.dist-info/top_level.txt,sha256=YjU6UcmomoDx30vEXLsOU784ED7VztQOsFApk1SFwvs,7
115
+ kumoai-2.14.0.dev202512151351.dist-info/METADATA,sha256=RzqzEYc4ILs4XE1EIMufUNuB-yVzeuwDctT3Qb65zdk,2510
116
+ kumoai-2.14.0.dev202512151351.dist-info/licenses/LICENSE,sha256=TbWlyqRmhq9PEzCaTI0H0nWLQCCOywQM8wYH8MbjfLo,1102
@@ -1,223 +0,0 @@
1
- import re
2
- from typing import Dict, List, Optional, Tuple
3
-
4
- import numpy as np
5
- import pandas as pd
6
- from kumoapi.rfm.context import EdgeLayout, Link, Subgraph, Table
7
- from kumoapi.typing import Stype
8
-
9
- import kumoai.kumolib as kumolib
10
- from kumoai.experimental.rfm.backend.local import LocalGraphStore
11
-
12
# Punctuation characters that are replaced by a space during normalization:
PUNCTUATION = re.compile(r"[\'\"\.,\(\)\!\?\;\:]")
# Runs of whitespace, collapsed into a single space:
MULTISPACE = re.compile(r"\s+")


def normalize_text(
    ser: pd.Series,
    max_words: Optional[int] = 50,
) -> pd.Series:
    r"""Normalizes text into a list of lower-case words.

    Args:
        ser: The :class:`pandas.Series` to normalize.
        max_words: The maximum number of words to return.
            This will auto-shrink any large text column to avoid blowing up
            context size.

    Returns:
        ``ser`` itself if it is empty or its first element is already
        list-like (i.e., already tokenized); otherwise a new series holding
        one list of words per entry. Missing values become empty lists.
    """
    # Nothing to do for empty or already tokenized input:
    if len(ser) == 0 or pd.api.types.is_list_like(ser.iloc[0]):
        return ser

    def _tokenize(text: str) -> list[str]:
        text = PUNCTUATION.sub(" ", text)
        text = re.sub(r"<br\s*/?>", " ", text)  # Handle <br /> or <br>
        text = MULTISPACE.sub(" ", text)
        tokens = text.split()
        return tokens if max_words is None else tokens[:max_words]

    cleaned = ser.fillna('').astype(str)

    if max_words is not None:
        # Assume ~6 characters per English word (5 letters + 1 space) and cut
        # the raw text first, since splitting huge texts is expensive:
        cleaned = cleaned.str[:6 * max_words]

    return cleaned.str.lower().map(_tokenize)
52
-
53
-
54
class LocalGraphSampler:
    r"""Samples subgraphs from an in-memory :class:`LocalGraphStore`.

    Wraps a native ``kumolib.NeighborSampler`` built over the store's
    ``colptr``/``row`` adjacency and ``time_dict``. It converts the raw
    sampler output into a :class:`Subgraph` of per-table :class:`Table`
    objects and per-edge-type :class:`Link` objects.
    """
    def __init__(self, graph_store: LocalGraphStore) -> None:
        self._graph_store = graph_store
        # The native sampler keys edge types by their '__'-joined string form:
        self._sampler = kumolib.NeighborSampler(
            self._graph_store.node_types,
            self._graph_store.edge_types,
            {
                '__'.join(edge_type): colptr
                for edge_type, colptr in self._graph_store.colptr_dict.items()
            },
            {
                '__'.join(edge_type): row
                for edge_type, row in self._graph_store.row_dict.items()
            },
            self._graph_store.time_dict,
        )

    def __call__(
        self,
        entity_table_names: Tuple[str, ...],
        node: np.ndarray,
        time: np.ndarray,
        num_neighbors: List[int],
        exclude_cols_dict: Dict[str, List[str]],
    ) -> Subgraph:
        r"""Samples a subgraph around the given seed nodes.

        Args:
            entity_table_names: Entity table names. Sampling starts from the
                first one, and only these tables keep their primary key
                column in the output.
            node: Seed row indices into the first entity table.
            time: Anchor timestamps in nanoseconds. They are converted to
                seconds before being handed to the native sampler. They are
                presumably one per seed, since they are indexed by each
                sampled row's ``batch`` entry — confirm with callers.
            num_neighbors: Neighbor fan-out list, applied identically to
                every edge type (presumably one entry per hop).
            exclude_cols_dict: Per-table column names to drop from the
                output (e.g., target columns).

        Returns:
            A :class:`Subgraph` with the anchor times, the sampled tables
            and the sampled links.
        """

        (
            row_dict,
            col_dict,
            node_dict,
            batch_dict,
            num_sampled_nodes_dict,
            num_sampled_edges_dict,
        ) = self._sampler.sample(
            # The same fan-out list is used for every edge type:
            {
                '__'.join(edge_type): num_neighbors
                for edge_type in self._graph_store.edge_types
            },
            {},  # time interval based sampling
            entity_table_names[0],
            node,
            time // 1000**3,  # nanoseconds to seconds
        )

        table_dict: Dict[str, Table] = {}
        # NOTE: `node` shadows the seed argument from here on; it now holds
        # the sampled row indices of `table_name`.
        for table_name, node in node_dict.items():
            batch = batch_dict[table_name]

            # Tables without sampled rows are left out of `table_dict`:
            if len(node) == 0:
                continue

            df = self._graph_store.df_dict[table_name]

            num_sampled_nodes = num_sampled_nodes_dict[table_name].tolist()
            stype_dict = {  # Exclude target columns:
                column_name: stype
                for column_name, stype in
                self._graph_store.stype_dict[table_name].items()
                if column_name not in exclude_cols_dict.get(table_name, [])
            }
            # Only entity tables expose their primary key:
            primary_key: Optional[str] = None
            if table_name in entity_table_names:
                primary_key = self._graph_store.pkey_name_dict.get(table_name)

            columns: List[str] = []
            if table_name in entity_table_names:
                columns += [self._graph_store.pkey_name_dict[table_name]]
            columns += list(stype_dict.keys())

            # No columns to materialize: emit a column-less frame with one
            # row per sampled node.
            if len(columns) == 0:
                table_dict[table_name] = Table(
                    df=pd.DataFrame(index=range(len(node))),
                    row=None,
                    batch=batch,
                    num_sampled_nodes=num_sampled_nodes,
                    stype_dict=stype_dict,
                    primary_key=primary_key,
                )
                continue

            row: Optional[np.ndarray] = None
            if table_name in self._graph_store.end_time_column_dict:
                # Set end time to NaT for all values greater than anchor time:
                # (Rows are kept one per sampled node here, presumably because
                # the mask depends on each row's own anchor time.)
                df = df.iloc[node].reset_index(drop=True)
                col_name = self._graph_store.end_time_column_dict[table_name]
                ser = df[col_name]
                value = ser.astype('datetime64[ns]').astype(int).to_numpy()
                mask = value > time[batch]
                df.loc[mask, col_name] = pd.NaT
            else:
                # Only store unique rows in `df` above a certain threshold:
                # when sampled rows outnumber unique rows by more than 5%,
                # keep unique rows only and let `row` map each sampled node
                # to its position in the deduplicated `df`.
                unique_node, inverse = np.unique(node, return_inverse=True)
                if len(node) > 1.05 * len(unique_node):
                    df = df.iloc[unique_node].reset_index(drop=True)
                    row = inverse
                else:
                    df = df.iloc[node].reset_index(drop=True)

            # Filter data frame to minimal set of columns:
            df = df[columns]

            # Normalize text (if not already pre-processed):
            for column_name, stype in stype_dict.items():
                if stype == Stype.text:
                    df[column_name] = normalize_text(df[column_name])

            table_dict[table_name] = Table(
                df=df,
                row=row,
                batch=batch,
                num_sampled_nodes=num_sampled_nodes,
                stype_dict=stype_dict,
                primary_key=primary_key,
            )

        link_dict: Dict[Tuple[str, str, str], Link] = {}
        for edge_type in self._graph_store.edge_types:
            edge_type_str = '__'.join(edge_type)

            row = row_dict[edge_type_str]
            col = col_dict[edge_type_str]

            # Edge types without sampled edges are left out of `link_dict`:
            if len(row) == 0:
                continue

            # Do not store reverse edge type if it is a replica:
            # if the reverse edge type is already stored and this edge type's
            # (row, col) equals its (col, row), store a REV placeholder
            # without index arrays instead.
            rev_edge_type = Subgraph.rev_edge_type(edge_type)
            rev_edge_type_str = '__'.join(rev_edge_type)
            if (rev_edge_type in link_dict
                    and np.array_equal(row, col_dict[rev_edge_type_str])
                    and np.array_equal(col, row_dict[rev_edge_type_str])):
                link = Link(
                    layout=EdgeLayout.REV,
                    row=None,
                    col=None,
                    num_sampled_edges=(
                        num_sampled_edges_dict[edge_type_str].tolist()),
                )
                link_dict[edge_type] = link
                continue

            # Identity index arrays (0, 1, ..., n-1) are stored as `None`:
            layout = EdgeLayout.COO
            if np.array_equal(row, np.arange(len(row))):
                row = None
            if np.array_equal(col, np.arange(len(col))):
                col = None

            # Store in compressed representation if more efficient:
            # `col` is replaced by a length-(num_cols + 1) cumulative count
            # per destination row. NOTE(review): this presumably assumes the
            # edges are already grouped by destination node; confirm against
            # kumolib's output order.
            num_cols = table_dict[edge_type[2]].num_rows
            if col is not None and len(col) > num_cols + 1:
                layout = EdgeLayout.CSC
                colcount = np.bincount(col, minlength=num_cols)
                col = np.empty(num_cols + 1, dtype=col.dtype)
                col[0] = 0
                np.cumsum(colcount, out=col[1:])

            link = Link(
                layout=layout,
                row=row,
                col=col,
                num_sampled_edges=(
                    num_sampled_edges_dict[edge_type_str].tolist()),
            )
            link_dict[edge_type] = link

        return Subgraph(
            anchor_time=time,
            table_dict=table_dict,
            link_dict=link_dict,
        )