kumoai 2.13.0.dev202511211730__py3-none-any.whl → 2.14.0.dev202512141732__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. kumoai/__init__.py +12 -0
  2. kumoai/_version.py +1 -1
  3. kumoai/client/pquery.py +6 -2
  4. kumoai/connector/utils.py +23 -2
  5. kumoai/experimental/rfm/__init__.py +20 -45
  6. kumoai/experimental/rfm/backend/__init__.py +0 -0
  7. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  8. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +37 -90
  9. kumoai/experimental/rfm/backend/local/sampler.py +313 -0
  10. kumoai/experimental/rfm/backend/local/table.py +119 -0
  11. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  12. kumoai/experimental/rfm/backend/snow/sampler.py +119 -0
  13. kumoai/experimental/rfm/backend/snow/table.py +135 -0
  14. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  15. kumoai/experimental/rfm/backend/sqlite/sampler.py +112 -0
  16. kumoai/experimental/rfm/backend/sqlite/table.py +115 -0
  17. kumoai/experimental/rfm/base/__init__.py +23 -0
  18. kumoai/experimental/rfm/base/column.py +66 -0
  19. kumoai/experimental/rfm/base/sampler.py +773 -0
  20. kumoai/experimental/rfm/base/source.py +19 -0
  21. kumoai/experimental/rfm/{local_table.py → base/table.py} +152 -141
  22. kumoai/experimental/rfm/{local_graph.py → graph.py} +352 -80
  23. kumoai/experimental/rfm/infer/__init__.py +6 -0
  24. kumoai/experimental/rfm/infer/dtype.py +79 -0
  25. kumoai/experimental/rfm/infer/pkey.py +126 -0
  26. kumoai/experimental/rfm/infer/time_col.py +62 -0
  27. kumoai/experimental/rfm/pquery/pandas_executor.py +1 -1
  28. kumoai/experimental/rfm/rfm.py +224 -167
  29. kumoai/experimental/rfm/sagemaker.py +11 -3
  30. kumoai/pquery/predictive_query.py +10 -6
  31. kumoai/testing/decorators.py +1 -1
  32. kumoai/testing/snow.py +50 -0
  33. kumoai/utils/__init__.py +2 -0
  34. kumoai/utils/sql.py +3 -0
  35. {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/METADATA +9 -8
  36. {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/RECORD +39 -23
  37. kumoai/experimental/rfm/local_graph_sampler.py +0 -182
  38. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  39. kumoai/experimental/rfm/utils.py +0 -344
  40. {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/WHEEL +0 -0
  41. {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/licenses/LICENSE +0 -0
  42. {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/top_level.txt +0 -0
@@ -2,15 +2,22 @@ import base64
2
2
  import json
3
3
  from typing import Any, Dict, List, Tuple
4
4
 
5
- import boto3
6
5
  import requests
7
- from mypy_boto3_sagemaker_runtime.client import SageMakerRuntimeClient
8
- from mypy_boto3_sagemaker_runtime.type_defs import InvokeEndpointOutputTypeDef
9
6
 
10
7
  from kumoai.client import KumoClient
11
8
  from kumoai.client.endpoints import Endpoint, HTTPMethod
12
9
  from kumoai.exceptions import HTTPException
13
10
 
11
+ try:
12
+ # isort: off
13
+ from mypy_boto3_sagemaker_runtime.client import SageMakerRuntimeClient
14
+ from mypy_boto3_sagemaker_runtime.type_defs import (
15
+ InvokeEndpointOutputTypeDef, )
16
+ # isort: on
17
+ except ImportError:
18
+ SageMakerRuntimeClient = Any
19
+ InvokeEndpointOutputTypeDef = Any
20
+
14
21
 
15
22
  class SageMakerResponseAdapter(requests.Response):
16
23
  def __init__(self, sm_response: InvokeEndpointOutputTypeDef):
@@ -34,6 +41,7 @@ class SageMakerResponseAdapter(requests.Response):
34
41
 
35
42
  class KumoClient_SageMakerAdapter(KumoClient):
36
43
  def __init__(self, region: str, endpoint_name: str):
44
+ import boto3
37
45
  self._client: SageMakerRuntimeClient = boto3.client(
38
46
  service_name="sagemaker-runtime", region_name=region)
39
47
  self._endpoint_name = endpoint_name
@@ -370,9 +370,11 @@ class PredictiveQuery:
370
370
  train_table_job_api = global_state.client.generate_train_table_job_api
371
371
  job_id: GenerateTrainTableJobID = train_table_job_api.create(
372
372
  GenerateTrainTableRequest(
373
- dict(custom_tags), pq_id, plan,
374
- graph_snapshot_id=self.graph.snapshot(
375
- non_blocking=non_blocking)))
373
+ dict(custom_tags),
374
+ pq_id,
375
+ plan,
376
+ None,
377
+ ))
376
378
 
377
379
  self._train_table = TrainingTableJob(job_id=job_id)
378
380
  if non_blocking:
@@ -451,9 +453,11 @@ class PredictiveQuery:
451
453
  bp_table_api = global_state.client.generate_prediction_table_job_api
452
454
  job_id: GeneratePredictionTableJobID = bp_table_api.create(
453
455
  GeneratePredictionTableRequest(
454
- dict(custom_tags), pq_id, plan,
455
- graph_snapshot_id=self.graph.snapshot(
456
- non_blocking=non_blocking)))
456
+ dict(custom_tags),
457
+ pq_id,
458
+ plan,
459
+ None,
460
+ ))
457
461
 
458
462
  self._prediction_table = PredictionTableJob(job_id=job_id)
459
463
  if non_blocking:
@@ -25,7 +25,7 @@ def onlyFullTest(func: Callable) -> Callable:
25
25
  def has_package(package: str) -> bool:
26
26
  r"""Returns ``True`` in case ``package`` is installed."""
27
27
  req = Requirement(package)
28
- if importlib.util.find_spec(req.name) is None:
28
+ if importlib.util.find_spec(req.name) is None: # type: ignore
29
29
  return False
30
30
 
31
31
  try:
kumoai/testing/snow.py ADDED
@@ -0,0 +1,50 @@
1
+ import json
2
+ import os
3
+
4
+ from kumoai.experimental.rfm.backend.snow import Connection
5
+ from kumoai.experimental.rfm.backend.snow import connect as _connect
6
+
7
+
8
+ def connect(
9
+ region: str,
10
+ id: str,
11
+ account: str,
12
+ user: str,
13
+ warehouse: str,
14
+ database: str | None = None,
15
+ schema: str | None = None,
16
+ ) -> Connection:
17
+
18
+ kwargs = dict(password=os.getenv('SNOWFLAKE_PASSWORD'))
19
+ if kwargs['password'] is None:
20
+ import boto3
21
+ from cryptography.hazmat.primitives import serialization
22
+
23
+ client = boto3.client(
24
+ service_name='secretsmanager',
25
+ region_name=region,
26
+ )
27
+ secret_id = (f'arn:aws:secretsmanager:{region}:{id}:secret:'
28
+ f'{account}.snowflakecomputing.com')
29
+ response = client.get_secret_value(SecretId=secret_id)['SecretString']
30
+ secret = json.loads(response)
31
+
32
+ private_key = serialization.load_pem_private_key(
33
+ secret['kumo_user_secretkey'].encode(),
34
+ password=None,
35
+ )
36
+ kwargs['private_key'] = private_key.private_bytes(
37
+ encoding=serialization.Encoding.DER,
38
+ format=serialization.PrivateFormat.PKCS8,
39
+ encryption_algorithm=serialization.NoEncryption(),
40
+ )
41
+
42
+ return _connect(
43
+ account=account,
44
+ user=user,
45
+ warehouse='WH_XS',
46
+ database='KUMO',
47
+ schema=schema,
48
+ session_parameters=dict(CLIENT_TELEMETRY_ENABLED=False),
49
+ **kwargs,
50
+ )
kumoai/utils/__init__.py CHANGED
@@ -1,8 +1,10 @@
1
+ from .sql import quote_ident
1
2
  from .progress_logger import ProgressLogger, InteractiveProgressLogger
2
3
  from .forecasting import ForecastVisualizer
3
4
  from .datasets import from_relbench
4
5
 
5
6
  __all__ = [
7
+ 'quote_ident',
6
8
  'ProgressLogger',
7
9
  'InteractiveProgressLogger',
8
10
  'ForecastVisualizer',
kumoai/utils/sql.py ADDED
@@ -0,0 +1,3 @@
1
def quote_ident(name: str) -> str:
    r"""Wraps ``name`` in double quotes so it can be used as a SQL identifier.

    Any double quote inside ``name`` is escaped by doubling it, so ``a"b``
    becomes ``"a""b"``.
    """
    escaped = name.replace('"', '""')
    return f'"{escaped}"'
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: kumoai
3
- Version: 2.13.0.dev202511211730
3
+ Version: 2.14.0.dev202512141732
4
4
  Summary: AI on the Modern Data Stack
5
5
  Author-email: "Kumo.AI" <hello@kumo.ai>
6
6
  License-Expression: MIT
@@ -23,13 +23,11 @@ Requires-Dist: requests>=2.28.2
23
23
  Requires-Dist: urllib3
24
24
  Requires-Dist: plotly
25
25
  Requires-Dist: typing_extensions>=4.5.0
26
- Requires-Dist: kumo-api==0.46.0
26
+ Requires-Dist: kumo-api==0.49.0
27
27
  Requires-Dist: tqdm>=4.66.0
28
28
  Requires-Dist: aiohttp>=3.10.0
29
29
  Requires-Dist: pydantic>=1.10.21
30
30
  Requires-Dist: rich>=9.0.0
31
- Requires-Dist: mypy-boto3-sagemaker-runtime
32
- Requires-Dist: boto3
33
31
  Provides-Extra: doc
34
32
  Requires-Dist: sphinx; extra == "doc"
35
33
  Requires-Dist: sphinx-book-theme; extra == "doc"
@@ -40,13 +38,16 @@ Provides-Extra: test
40
38
  Requires-Dist: pytest; extra == "test"
41
39
  Requires-Dist: pytest-mock; extra == "test"
42
40
  Requires-Dist: requests-mock; extra == "test"
43
- Provides-Extra: test-sagemaker
44
- Requires-Dist: sagemaker; extra == "test-sagemaker"
45
- Requires-Dist: pandas==2.1.4; extra == "test-sagemaker"
46
- Requires-Dist: pyarrow==12.0.1; extra == "test-sagemaker"
41
+ Provides-Extra: sqlite
42
+ Requires-Dist: adbc_driver_sqlite; extra == "sqlite"
43
+ Provides-Extra: snowflake
44
+ Requires-Dist: snowflake-connector-python; extra == "snowflake"
45
+ Requires-Dist: pyyaml; extra == "snowflake"
47
46
  Provides-Extra: sagemaker
48
47
  Requires-Dist: boto3<2.0,>=1.30.0; extra == "sagemaker"
49
48
  Requires-Dist: mypy-boto3-sagemaker-runtime<2.0,>=1.34.0; extra == "sagemaker"
49
+ Provides-Extra: test-sagemaker
50
+ Requires-Dist: sagemaker<3.0; extra == "test-sagemaker"
50
51
  Dynamic: license-file
51
52
  Dynamic: requires-dist
52
53
 
@@ -1,7 +1,7 @@
1
- kumoai/__init__.py,sha256=L3yOOtpSdwe3PYQlJBLkiQd3Ypp8iB5ChXkzprk3Si4,10546
1
+ kumoai/__init__.py,sha256=Nn9YH_x9kAeEFn8RWbP95slZow0qFnakPZZ1WADe1hY,10843
2
2
  kumoai/_logging.py,sha256=U2_5ROdyk92P4xO4H2WJV8EC7dr6YxmmnM-b7QX9M7I,886
3
3
  kumoai/_singleton.py,sha256=UTwrbDkoZSGB8ZelorvprPDDv9uZkUi1q_SrmsyngpQ,836
4
- kumoai/_version.py,sha256=SG9nFVn5zQGjYC2bJHf14ClhCtvj0GzhN9puMBHh-sE,39
4
+ kumoai/_version.py,sha256=jDuUtpQbekehTWtqQXeU4rYieU0OA-7URATbFO1iyvo,39
5
5
  kumoai/databricks.py,sha256=e6E4lOFvZHXFwh4CO1kXU1zzDU3AapLQYMxjiHPC-HQ,476
6
6
  kumoai/exceptions.py,sha256=b-_sdbAKOg50uaJZ65GmBLdTo4HANdjl8_R0sJpwaN0,833
7
7
  kumoai/formatting.py,sha256=jA_rLDCGKZI8WWCha-vtuLenVKTZvli99Tqpurz1H84,953
@@ -19,7 +19,7 @@ kumoai/client/endpoints.py,sha256=iF2ZD25AJCIVbmBJ8tTZ8y1Ch0m6nTp18ydN7h4WiTk,53
19
19
  kumoai/client/graph.py,sha256=zvLEDExLT_RVbUMHqVl0m6tO6s2gXmYSoWmPF6YMlnA,3831
20
20
  kumoai/client/jobs.py,sha256=iu_Wrta6BQMlV6ZtzSnmhjwNPKDMQDXOsqVVIyWodqw,17074
21
21
  kumoai/client/online.py,sha256=pkBBh_DEC3GAnPcNw6bopNRlGe7EUbIFe7_seQqZRaw,2720
22
- kumoai/client/pquery.py,sha256=R2hc-M8vPoyIDH0ywLwFVxCznVAqpZz3w2HszjdNW-o,6891
22
+ kumoai/client/pquery.py,sha256=IQ8As-OOJOkuMoMosphOsA5hxQYLCbzOQJO7RezK8uY,7091
23
23
  kumoai/client/rfm.py,sha256=NxKk8mH2A-B58rSXhDWaph4KeiSyJYDq-RO-vAHh7es,3726
24
24
  kumoai/client/source_table.py,sha256=VCsCcM7KYcnjGP7HLTb-AOSEGEVsJTWjk8bMg1JdgPU,2101
25
25
  kumoai/client/table.py,sha256=cQG-RPm-e91idEgse1IPJDvBmzddIDGDkuyrR1rq4wU,3235
@@ -49,37 +49,52 @@ kumoai/connector/glue_connector.py,sha256=HivT0QYQ8-XeB4QLgWvghiqXuq7jyBK9G2R1py
49
49
  kumoai/connector/s3_connector.py,sha256=3kbv-h7DwD8O260Q0h1GPm5wwQpLt-Tb3d_CBSaie44,10155
50
50
  kumoai/connector/snowflake_connector.py,sha256=K0s-H9tW3rve8g2x1PbyxvzSpkROfGQZz-Qa4PoT4UE,9022
51
51
  kumoai/connector/source_table.py,sha256=QLT8bEYaxeMwy-b168url0VfnkTrs5K6VKLbxTI4hEY,17539
52
- kumoai/connector/utils.py,sha256=PUjunLpfqMZsrPDo2EmnyJRBl_mt-E6ugv2kNkf5Rn8,64011
52
+ kumoai/connector/utils.py,sha256=wlqQxMmPvnFNoCcczGkKYjSu05h8OhWh4fhTzQm_2bQ,64694
53
53
  kumoai/encoder/__init__.py,sha256=VPGs4miBC_WfwWeOXeHhFomOUocERFavhKf5fqITcds,182
54
54
  kumoai/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
55
- kumoai/experimental/rfm/__init__.py,sha256=wKfMKTxfuJNH1GCWGZ7-288HXil0tsCuXqg-BAFctZE,6812
55
+ kumoai/experimental/rfm/__init__.py,sha256=slliYcrh80xPtQQ_nnsp3ny9IbmHCyirmdZUfKTdME4,6064
56
56
  kumoai/experimental/rfm/authenticate.py,sha256=FiuHMvP7V3zBZUlHMDMbNLhc-UgDZgz4hjVSTuQ7DRw,18888
57
- kumoai/experimental/rfm/local_graph.py,sha256=2iJDlsGVzqCe1bD_puXWlhwGkn7YnQyJ4p4C-fwCZNE,30076
58
- kumoai/experimental/rfm/local_graph_sampler.py,sha256=5DbhL9h0usFKSJfnx7HjLMPcG54qwJ48M2tmONqxXyY,6672
59
- kumoai/experimental/rfm/local_graph_store.py,sha256=8BqonuaMftAAsjgZpB369i5AeNd1PkisMbbEqc0cKBo,13847
60
- kumoai/experimental/rfm/local_pquery_driver.py,sha256=aO7Jfwx9gxGKYvpqxZx1LLWdI1MhuZQOPtAITxoOQO0,26162
61
- kumoai/experimental/rfm/local_table.py,sha256=r8xZ33Mjs6JD8ud6h23tZ99Dag2DvZ4h6tWjmGrKQg4,19605
62
- kumoai/experimental/rfm/rfm.py,sha256=8SvGWfMuRYJgiz5OTplu7m47mDrHAjQ2mRZtRASnSCk,48136
63
- kumoai/experimental/rfm/sagemaker.py,sha256=e0rRQ28WcgAk_ALqUyU20d193c8_68rCkSengZIHu3Y,4823
64
- kumoai/experimental/rfm/utils.py,sha256=3IiBvT_aLBkkcJh3H11_50yt_XlEzHR0cm9Kprrtl8k,11123
65
- kumoai/experimental/rfm/infer/__init__.py,sha256=xQ8_SuejIzXyn2J7bIKX3pXumFtRuEfBtE5oEDUDJjI,293
57
+ kumoai/experimental/rfm/graph.py,sha256=awVJSk4cWRMacS5CJvJtR8TR56FEbrJPcQCukNydQOc,40392
58
+ kumoai/experimental/rfm/rfm.py,sha256=j4LfCbmHPs9RrwRNjTEE-sPUC6RcHNcymKBtG_Rt-M4,49670
59
+ kumoai/experimental/rfm/sagemaker.py,sha256=_hTrFg4qfXe7uzwqSEG_wze-IFkwn7qde9OpUodCpbc,4982
60
+ kumoai/experimental/rfm/backend/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
61
+ kumoai/experimental/rfm/backend/local/__init__.py,sha256=2s9sSA-E-8pfkkzCH4XPuaSxSznEURMfMgwEIfYYPsg,1014
62
+ kumoai/experimental/rfm/backend/local/graph_store.py,sha256=5cHuExHljU_Z56KV3s-PwzeiLuPKgh2mCxcjTMmPZ8E,11928
63
+ kumoai/experimental/rfm/backend/local/sampler.py,sha256=wJXQdiOOCIjqpNkDoXgPV0cnsKjeJ2Zi77-pJ8zAEM4,10738
64
+ kumoai/experimental/rfm/backend/local/table.py,sha256=-R_9nncosByAfSMfUt6HgCUNoW_MLGJW3F5SnAd4Ru0,3744
65
+ kumoai/experimental/rfm/backend/snow/__init__.py,sha256=BYfsiuJ4Ee30GjG9EuUtitMHXnRfvVKi85zNlIwldV4,993
66
+ kumoai/experimental/rfm/backend/snow/sampler.py,sha256=d5hyKyiNxnjmINGxM7K5nCVnN7ZaknM1QdHLt8EBt5c,4287
67
+ kumoai/experimental/rfm/backend/snow/table.py,sha256=bnBGgkXmgLcSZlr0t4pYo9aSk7DO5PuSDM06HBgA5BE,4841
68
+ kumoai/experimental/rfm/backend/sqlite/__init__.py,sha256=cA-PZL1oTaLxthZbfLSudexImtF6jRsGkdjSp-66dCM,914
69
+ kumoai/experimental/rfm/backend/sqlite/sampler.py,sha256=n3etgKH7viiIgMXHFLM3KQvKvCpylXRTw0WBNoSKd7s,4057
70
+ kumoai/experimental/rfm/backend/sqlite/table.py,sha256=9F2_6iqgM3h9jkIcecHnEwQq-UG3RBmR9Y9XPcFLR04,4161
71
+ kumoai/experimental/rfm/base/__init__.py,sha256=IdQwFFEeXZnTBkj1I65qQDL6QlcbmjVoTP6-quCtmFc,481
72
+ kumoai/experimental/rfm/base/column.py,sha256=izCJmufJcd1RSi-ptFMfrue-JYag38MJxizka7ya0-A,2319
73
+ kumoai/experimental/rfm/base/sampler.py,sha256=aCD98t0CUhAvGXEFv24Vq2g4otuclpKkkyL1rMR_mFg,31449
74
+ kumoai/experimental/rfm/base/source.py,sha256=RqlI_kBoRV0ADb8KdEKn15RNHMdFUzEVzb57lIoyBM4,294
75
+ kumoai/experimental/rfm/base/table.py,sha256=neGldEZaweoJ8VRgnEnaSpAISSkSTkgXxItuuywBM4E,20010
76
+ kumoai/experimental/rfm/infer/__init__.py,sha256=krdMFN8iKZlSFOl-M5MW1KuSviQV3H1E18jj2uB8g6Q,469
66
77
  kumoai/experimental/rfm/infer/categorical.py,sha256=VwNaKwKbRYkTxEJ1R6gziffC8dGsEThcDEfbi-KqW5c,853
78
+ kumoai/experimental/rfm/infer/dtype.py,sha256=ZZ6ztqJnTR1CaC2z5Uhf0o0rSdNThnss5tem5JNQkck,2607
67
79
  kumoai/experimental/rfm/infer/id.py,sha256=ZIO0DWIoiEoS_8MVc5lkqBfkTWWQ0yGCgjkwLdaYa_Q,908
68
80
  kumoai/experimental/rfm/infer/multicategorical.py,sha256=0-cLpDnGryhr76QhZNO-klKokJ6MUSfxXcGdQ61oykY,1102
81
+ kumoai/experimental/rfm/infer/pkey.py,sha256=ubNqW1LIjLKiXbjXELAY3g6n2f3u2Eis_uC2DEiXFiU,4393
82
+ kumoai/experimental/rfm/infer/time_col.py,sha256=7R5Itl8RRBOr61qLpRTanIqrUVZFZcAXzDA9lCw4nx4,1820
69
83
  kumoai/experimental/rfm/infer/timestamp.py,sha256=vM9--7eStzaGG13Y-oLYlpNJyhL6f9dp17HDXwtl_DM,1094
70
84
  kumoai/experimental/rfm/pquery/__init__.py,sha256=X0O3EIq5SMfBEE-ii5Cq6iDhR3s3XMXB52Cx5htoePw,152
71
85
  kumoai/experimental/rfm/pquery/executor.py,sha256=f7-pJhL0BgFU9E4o4gQpQyArOvyrZtwxFmks34-QOAE,2741
72
- kumoai/experimental/rfm/pquery/pandas_executor.py,sha256=kiBJq7uVGbasG7TiqsubEl6ey3UYzZiM4bwxILqp_54,18487
86
+ kumoai/experimental/rfm/pquery/pandas_executor.py,sha256=wYI9a3smClR2pQGwsYRdmpOm0PlUsbtyW9wpAVpCEe4,18492
73
87
  kumoai/graph/__init__.py,sha256=n8X4X8luox4hPBHTRC9R-3JzvYYMoR8n7lF1H4w4Hzc,228
74
88
  kumoai/graph/column.py,sha256=t7wBmcx0VYKXjIoESU9Nq-AisiJOdlqd80t8zby1R8Y,4189
75
89
  kumoai/graph/graph.py,sha256=iyp4klPIMn2ttuEqMJvsrxKb_tmz_DTnvziIhCegduM,38291
76
90
  kumoai/graph/table.py,sha256=nZqYX8xlyAz6kVtlE2vf9BAIOCoWeFNIfbGbReDCb7k,33888
77
91
  kumoai/pquery/__init__.py,sha256=uTXr7t1eXcVfM-ETaM_1ImfEqhrmaj8BjiIvy1YZTL8,533
78
92
  kumoai/pquery/prediction_table.py,sha256=QPDH22X1UB0NIufY7qGuV2XW7brG3Pv--FbjNezzM2g,10776
79
- kumoai/pquery/predictive_query.py,sha256=oUqwdOWLLkPM-G4PhpUk_6mwSJGBtaD3t37Wp5Oow8M,24971
93
+ kumoai/pquery/predictive_query.py,sha256=UXn1s8ztubYZMNGl4ijaeidMiGlFveb1TGw9qI5-TAo,24901
80
94
  kumoai/pquery/training_table.py,sha256=elmPDZx11kPiC_dkOhJcBUGtHKgL32GCBvZ9k6U0pMg,15809
81
95
  kumoai/testing/__init__.py,sha256=goHIIo3JE7uHV7njo4_aTd89mVVR74BEAZ2uyBaOR0w,170
82
- kumoai/testing/decorators.py,sha256=RiFrJcP-ym-mB1BYSGC26bBiryxoR9-GwL1G4EHc2sc,1591
96
+ kumoai/testing/decorators.py,sha256=83tMifuPTpUqX7zHxMttkj1TDdB62EBtAP-Fjj72Zdo,1607
97
+ kumoai/testing/snow.py,sha256=ubx3yJP0UHxsNiar1-jNdv8ZfszKc8Js3_Gg70uf008,1487
83
98
  kumoai/trainer/__init__.py,sha256=zUdFl-f-sBWmm2x8R-rdVzPBeU2FaMzUY5mkcgoTa1k,939
84
99
  kumoai/trainer/baseline_trainer.py,sha256=LlfViNOmswNv4c6zJJLsyv0pC2mM2WKMGYx06ogtEVc,4024
85
100
  kumoai/trainer/config.py,sha256=-2RfK10AsVVThSyfWtlyfH4Fc4EwTdu0V3yrDRtIOjk,98
@@ -87,12 +102,13 @@ kumoai/trainer/job.py,sha256=Wk69nzFhbvuA3nEvtCstI04z5CxkgvQ6tHnGchE0Lkg,44938
87
102
  kumoai/trainer/online_serving.py,sha256=9cddb5paeZaCgbUeceQdAOxysCtV5XP-KcsgFz_XR5w,9566
88
103
  kumoai/trainer/trainer.py,sha256=hBXO7gwpo3t59zKFTeIkK65B8QRmWCwO33sbDuEAPlY,20133
89
104
  kumoai/trainer/util.py,sha256=bDPGkMF9KOy4HgtA-OwhXP17z9cbrfMnZGtyGuUq_Eo,4062
90
- kumoai/utils/__init__.py,sha256=wGDC_31XJ-7ipm6eawjLAJaP4EfmtNOH8BHzaetQ9Ko,268
105
+ kumoai/utils/__init__.py,sha256=cF5ACzp1X61sqhlCHc6biQk6fc4gW_oyhGsBrjx-SoM,316
91
106
  kumoai/utils/datasets.py,sha256=ptKIUoBONVD55pTVNdRCkQT3NWdN_r9UAUu4xewPa3U,2928
92
107
  kumoai/utils/forecasting.py,sha256=-nDS6ucKNfQhTQOfebjefj0wwWH3-KYNslIomxwwMBM,7415
93
108
  kumoai/utils/progress_logger.py,sha256=pngEGzMHkiOUKOa6fbzxCEc2xlA4SJKV4TDTVVoqObM,5062
94
- kumoai-2.13.0.dev202511211730.dist-info/licenses/LICENSE,sha256=TbWlyqRmhq9PEzCaTI0H0nWLQCCOywQM8wYH8MbjfLo,1102
95
- kumoai-2.13.0.dev202511211730.dist-info/METADATA,sha256=S4KjjwNYqtLxfkX6vqdiOoo-iUPcfoXnKGJif23K9jU,2475
96
- kumoai-2.13.0.dev202511211730.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
97
- kumoai-2.13.0.dev202511211730.dist-info/top_level.txt,sha256=YjU6UcmomoDx30vEXLsOU784ED7VztQOsFApk1SFwvs,7
98
- kumoai-2.13.0.dev202511211730.dist-info/RECORD,,
109
+ kumoai/utils/sql.py,sha256=f6lR6rBEW7Dtk0NdM26dOZXUHDizEHb1WPlBCJrwoq0,118
110
+ kumoai-2.14.0.dev202512141732.dist-info/licenses/LICENSE,sha256=TbWlyqRmhq9PEzCaTI0H0nWLQCCOywQM8wYH8MbjfLo,1102
111
+ kumoai-2.14.0.dev202512141732.dist-info/METADATA,sha256=afHVHJGv5rxq96TO9d3VqyuN_Nhvf-dFGx3nDJo41lk,2510
112
+ kumoai-2.14.0.dev202512141732.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
113
+ kumoai-2.14.0.dev202512141732.dist-info/top_level.txt,sha256=YjU6UcmomoDx30vEXLsOU784ED7VztQOsFApk1SFwvs,7
114
+ kumoai-2.14.0.dev202512141732.dist-info/RECORD,,
@@ -1,182 +0,0 @@
1
- from typing import Dict, List, Optional, Tuple
2
-
3
- import numpy as np
4
- import pandas as pd
5
- from kumoapi.rfm.context import EdgeLayout, Link, Subgraph, Table
6
- from kumoapi.typing import Stype
7
-
8
- import kumoai.kumolib as kumolib
9
- from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
10
- from kumoai.experimental.rfm.utils import normalize_text
11
-
12
-
13
- class LocalGraphSampler:
14
- def __init__(self, graph_store: LocalGraphStore) -> None:
15
- self._graph_store = graph_store
16
- self._sampler = kumolib.NeighborSampler(
17
- self._graph_store.node_types,
18
- self._graph_store.edge_types,
19
- {
20
- '__'.join(edge_type): colptr
21
- for edge_type, colptr in self._graph_store.colptr_dict.items()
22
- },
23
- {
24
- '__'.join(edge_type): row
25
- for edge_type, row in self._graph_store.row_dict.items()
26
- },
27
- self._graph_store.time_dict,
28
- )
29
-
30
- def __call__(
31
- self,
32
- entity_table_names: Tuple[str, ...],
33
- node: np.ndarray,
34
- time: np.ndarray,
35
- num_neighbors: List[int],
36
- exclude_cols_dict: Dict[str, List[str]],
37
- ) -> Subgraph:
38
-
39
- (
40
- row_dict,
41
- col_dict,
42
- node_dict,
43
- batch_dict,
44
- num_sampled_nodes_dict,
45
- num_sampled_edges_dict,
46
- ) = self._sampler.sample(
47
- {
48
- '__'.join(edge_type): num_neighbors
49
- for edge_type in self._graph_store.edge_types
50
- },
51
- {}, # time interval based sampling
52
- entity_table_names[0],
53
- node,
54
- time // 1000**3, # nanoseconds to seconds
55
- )
56
-
57
- table_dict: Dict[str, Table] = {}
58
- for table_name, node in node_dict.items():
59
- batch = batch_dict[table_name]
60
-
61
- if len(node) == 0:
62
- continue
63
-
64
- df = self._graph_store.df_dict[table_name]
65
-
66
- num_sampled_nodes = num_sampled_nodes_dict[table_name].tolist()
67
- stype_dict = { # Exclude target columns:
68
- column_name: stype
69
- for column_name, stype in
70
- self._graph_store.stype_dict[table_name].items()
71
- if column_name not in exclude_cols_dict.get(table_name, [])
72
- }
73
- primary_key: Optional[str] = None
74
- if table_name in entity_table_names:
75
- primary_key = self._graph_store.pkey_name_dict.get(table_name)
76
-
77
- columns: List[str] = []
78
- if table_name in entity_table_names:
79
- columns += [self._graph_store.pkey_name_dict[table_name]]
80
- columns += list(stype_dict.keys())
81
-
82
- if len(columns) == 0:
83
- table_dict[table_name] = Table(
84
- df=pd.DataFrame(index=range(len(node))),
85
- row=None,
86
- batch=batch,
87
- num_sampled_nodes=num_sampled_nodes,
88
- stype_dict=stype_dict,
89
- primary_key=primary_key,
90
- )
91
- continue
92
-
93
- row: Optional[np.ndarray] = None
94
- if table_name in self._graph_store.end_time_column_dict:
95
- # Set end time to NaT for all values greater than anchor time:
96
- df = df.iloc[node].reset_index(drop=True)
97
- col_name = self._graph_store.end_time_column_dict[table_name]
98
- ser = df[col_name]
99
- value = ser.astype('datetime64[ns]').astype(int).to_numpy()
100
- mask = value > time[batch]
101
- df.loc[mask, col_name] = pd.NaT
102
- else:
103
- # Only store unique rows in `df` above a certain threshold:
104
- unique_node, inverse = np.unique(node, return_inverse=True)
105
- if len(node) > 1.05 * len(unique_node):
106
- df = df.iloc[unique_node].reset_index(drop=True)
107
- row = inverse
108
- else:
109
- df = df.iloc[node].reset_index(drop=True)
110
-
111
- # Filter data frame to minimal set of columns:
112
- df = df[columns]
113
-
114
- # Normalize text (if not already pre-processed):
115
- for column_name, stype in stype_dict.items():
116
- if stype == Stype.text:
117
- df[column_name] = normalize_text(df[column_name])
118
-
119
- table_dict[table_name] = Table(
120
- df=df,
121
- row=row,
122
- batch=batch,
123
- num_sampled_nodes=num_sampled_nodes,
124
- stype_dict=stype_dict,
125
- primary_key=primary_key,
126
- )
127
-
128
- link_dict: Dict[Tuple[str, str, str], Link] = {}
129
- for edge_type in self._graph_store.edge_types:
130
- edge_type_str = '__'.join(edge_type)
131
-
132
- row = row_dict[edge_type_str]
133
- col = col_dict[edge_type_str]
134
-
135
- if len(row) == 0:
136
- continue
137
-
138
- # Do not store reverse edge type if it is a replica:
139
- rev_edge_type = Subgraph.rev_edge_type(edge_type)
140
- rev_edge_type_str = '__'.join(rev_edge_type)
141
- if (rev_edge_type in link_dict
142
- and np.array_equal(row, col_dict[rev_edge_type_str])
143
- and np.array_equal(col, row_dict[rev_edge_type_str])):
144
- link = Link(
145
- layout=EdgeLayout.REV,
146
- row=None,
147
- col=None,
148
- num_sampled_edges=(
149
- num_sampled_edges_dict[edge_type_str].tolist()),
150
- )
151
- link_dict[edge_type] = link
152
- continue
153
-
154
- layout = EdgeLayout.COO
155
- if np.array_equal(row, np.arange(len(row))):
156
- row = None
157
- if np.array_equal(col, np.arange(len(col))):
158
- col = None
159
-
160
- # Store in compressed representation if more efficient:
161
- num_cols = table_dict[edge_type[2]].num_rows
162
- if col is not None and len(col) > num_cols + 1:
163
- layout = EdgeLayout.CSC
164
- colcount = np.bincount(col, minlength=num_cols)
165
- col = np.empty(num_cols + 1, dtype=col.dtype)
166
- col[0] = 0
167
- np.cumsum(colcount, out=col[1:])
168
-
169
- link = Link(
170
- layout=layout,
171
- row=row,
172
- col=col,
173
- num_sampled_edges=(
174
- num_sampled_edges_dict[edge_type_str].tolist()),
175
- )
176
- link_dict[edge_type] = link
177
-
178
- return Subgraph(
179
- anchor_time=time,
180
- table_dict=table_dict,
181
- link_dict=link_dict,
182
- )