kumoai 2.14.0.dev202601011731__cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kumoai might be problematic. Click here for more details.

Files changed (122) hide show
  1. kumoai/__init__.py +300 -0
  2. kumoai/_logging.py +29 -0
  3. kumoai/_singleton.py +25 -0
  4. kumoai/_version.py +1 -0
  5. kumoai/artifact_export/__init__.py +9 -0
  6. kumoai/artifact_export/config.py +209 -0
  7. kumoai/artifact_export/job.py +108 -0
  8. kumoai/client/__init__.py +5 -0
  9. kumoai/client/client.py +223 -0
  10. kumoai/client/connector.py +110 -0
  11. kumoai/client/endpoints.py +150 -0
  12. kumoai/client/graph.py +120 -0
  13. kumoai/client/jobs.py +471 -0
  14. kumoai/client/online.py +78 -0
  15. kumoai/client/pquery.py +207 -0
  16. kumoai/client/rfm.py +112 -0
  17. kumoai/client/source_table.py +53 -0
  18. kumoai/client/table.py +101 -0
  19. kumoai/client/utils.py +130 -0
  20. kumoai/codegen/__init__.py +19 -0
  21. kumoai/codegen/cli.py +100 -0
  22. kumoai/codegen/context.py +16 -0
  23. kumoai/codegen/edits.py +473 -0
  24. kumoai/codegen/exceptions.py +10 -0
  25. kumoai/codegen/generate.py +222 -0
  26. kumoai/codegen/handlers/__init__.py +4 -0
  27. kumoai/codegen/handlers/connector.py +118 -0
  28. kumoai/codegen/handlers/graph.py +71 -0
  29. kumoai/codegen/handlers/pquery.py +62 -0
  30. kumoai/codegen/handlers/table.py +109 -0
  31. kumoai/codegen/handlers/utils.py +42 -0
  32. kumoai/codegen/identity.py +114 -0
  33. kumoai/codegen/loader.py +93 -0
  34. kumoai/codegen/naming.py +94 -0
  35. kumoai/codegen/registry.py +121 -0
  36. kumoai/connector/__init__.py +31 -0
  37. kumoai/connector/base.py +153 -0
  38. kumoai/connector/bigquery_connector.py +200 -0
  39. kumoai/connector/databricks_connector.py +213 -0
  40. kumoai/connector/file_upload_connector.py +189 -0
  41. kumoai/connector/glue_connector.py +150 -0
  42. kumoai/connector/s3_connector.py +278 -0
  43. kumoai/connector/snowflake_connector.py +252 -0
  44. kumoai/connector/source_table.py +471 -0
  45. kumoai/connector/utils.py +1796 -0
  46. kumoai/databricks.py +14 -0
  47. kumoai/encoder/__init__.py +4 -0
  48. kumoai/exceptions.py +26 -0
  49. kumoai/experimental/__init__.py +0 -0
  50. kumoai/experimental/rfm/__init__.py +210 -0
  51. kumoai/experimental/rfm/authenticate.py +432 -0
  52. kumoai/experimental/rfm/backend/__init__.py +0 -0
  53. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  54. kumoai/experimental/rfm/backend/local/graph_store.py +297 -0
  55. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  56. kumoai/experimental/rfm/backend/local/table.py +113 -0
  57. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  58. kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
  59. kumoai/experimental/rfm/backend/snow/table.py +242 -0
  60. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  61. kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
  62. kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
  63. kumoai/experimental/rfm/base/__init__.py +30 -0
  64. kumoai/experimental/rfm/base/column.py +152 -0
  65. kumoai/experimental/rfm/base/expression.py +44 -0
  66. kumoai/experimental/rfm/base/sampler.py +761 -0
  67. kumoai/experimental/rfm/base/source.py +19 -0
  68. kumoai/experimental/rfm/base/sql_sampler.py +143 -0
  69. kumoai/experimental/rfm/base/table.py +736 -0
  70. kumoai/experimental/rfm/graph.py +1237 -0
  71. kumoai/experimental/rfm/infer/__init__.py +19 -0
  72. kumoai/experimental/rfm/infer/categorical.py +40 -0
  73. kumoai/experimental/rfm/infer/dtype.py +82 -0
  74. kumoai/experimental/rfm/infer/id.py +46 -0
  75. kumoai/experimental/rfm/infer/multicategorical.py +48 -0
  76. kumoai/experimental/rfm/infer/pkey.py +128 -0
  77. kumoai/experimental/rfm/infer/stype.py +35 -0
  78. kumoai/experimental/rfm/infer/time_col.py +61 -0
  79. kumoai/experimental/rfm/infer/timestamp.py +41 -0
  80. kumoai/experimental/rfm/pquery/__init__.py +7 -0
  81. kumoai/experimental/rfm/pquery/executor.py +102 -0
  82. kumoai/experimental/rfm/pquery/pandas_executor.py +530 -0
  83. kumoai/experimental/rfm/relbench.py +76 -0
  84. kumoai/experimental/rfm/rfm.py +1184 -0
  85. kumoai/experimental/rfm/sagemaker.py +138 -0
  86. kumoai/experimental/rfm/task_table.py +231 -0
  87. kumoai/formatting.py +30 -0
  88. kumoai/futures.py +99 -0
  89. kumoai/graph/__init__.py +12 -0
  90. kumoai/graph/column.py +106 -0
  91. kumoai/graph/graph.py +948 -0
  92. kumoai/graph/table.py +838 -0
  93. kumoai/jobs.py +80 -0
  94. kumoai/kumolib.cpython-310-x86_64-linux-gnu.so +0 -0
  95. kumoai/mixin.py +28 -0
  96. kumoai/pquery/__init__.py +25 -0
  97. kumoai/pquery/prediction_table.py +287 -0
  98. kumoai/pquery/predictive_query.py +641 -0
  99. kumoai/pquery/training_table.py +424 -0
  100. kumoai/spcs.py +121 -0
  101. kumoai/testing/__init__.py +8 -0
  102. kumoai/testing/decorators.py +57 -0
  103. kumoai/testing/snow.py +50 -0
  104. kumoai/trainer/__init__.py +42 -0
  105. kumoai/trainer/baseline_trainer.py +93 -0
  106. kumoai/trainer/config.py +2 -0
  107. kumoai/trainer/distilled_trainer.py +175 -0
  108. kumoai/trainer/job.py +1192 -0
  109. kumoai/trainer/online_serving.py +258 -0
  110. kumoai/trainer/trainer.py +475 -0
  111. kumoai/trainer/util.py +103 -0
  112. kumoai/utils/__init__.py +11 -0
  113. kumoai/utils/datasets.py +83 -0
  114. kumoai/utils/display.py +51 -0
  115. kumoai/utils/forecasting.py +209 -0
  116. kumoai/utils/progress_logger.py +343 -0
  117. kumoai/utils/sql.py +3 -0
  118. kumoai-2.14.0.dev202601011731.dist-info/METADATA +71 -0
  119. kumoai-2.14.0.dev202601011731.dist-info/RECORD +122 -0
  120. kumoai-2.14.0.dev202601011731.dist-info/WHEEL +6 -0
  121. kumoai-2.14.0.dev202601011731.dist-info/licenses/LICENSE +9 -0
  122. kumoai-2.14.0.dev202601011731.dist-info/top_level.txt +1 -0
@@ -0,0 +1,19 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kumoapi.typing import Dtype
4
+
5
+
6
@dataclass
class SourceColumn:
    r"""Metadata record for a single column of a source table.

    Field order is part of the positional constructor signature and must
    not be changed.
    """
    # Column name.
    name: str
    # Column data type; ``None`` is an allowed value (presumably when the
    # type could not be determined -- confirm against the producers).
    dtype: Dtype | None
    # Key/nullability flags of the column.
    is_primary_key: bool
    is_unique_key: bool
    is_nullable: bool
13
+
14
+
15
@dataclass
class SourceForeignKey:
    r"""Metadata record for a foreign key of a source table.

    Field order is part of the positional constructor signature and must
    not be changed.
    """
    # Name of the foreign key (presumably the referencing column -- confirm
    # against the producers).
    name: str
    # Name of the destination table.
    dst_table: str
    # Name of the primary key (presumably in ``dst_table`` -- confirm).
    primary_key: str
@@ -0,0 +1,143 @@
1
+ from abc import abstractmethod
2
+ from typing import TYPE_CHECKING, Literal
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ from kumoapi.typing import Dtype
7
+
8
+ from kumoai.experimental.rfm.base import (
9
+ LocalExpression,
10
+ Sampler,
11
+ SamplerOutput,
12
+ SourceColumn,
13
+ )
14
+ from kumoai.utils import ProgressLogger, quote_ident
15
+
16
+ if TYPE_CHECKING:
17
+ from kumoai.experimental.rfm import Graph
18
+
19
+
20
class SQLSampler(Sampler):
    r"""A :class:`Sampler` that pre-computes per-table SQL lookup
    dictionaries (quoted source names, source column metadata, data types,
    and column reference/projection expressions) and delegates primary key
    lookups to :meth:`_by_pkey`, which sub-classes need to implement.
    """
    def __init__(
        self,
        graph: 'Graph',
        verbose: bool | ProgressLogger = True,
    ) -> None:
        r"""Builds all per-table lookup dictionaries from :obj:`graph`.

        Args:
            graph: The graph whose tables and columns are indexed.
            verbose: Forwarded unchanged to the parent constructor.
        """
        super().__init__(graph=graph, verbose=verbose)

        tables = list(graph.tables.values())

        # Table name -> quoted source name:
        self._source_name_dict: dict[str, str] = {
            table.name: table._quoted_source_name
            for table in tables
        }

        # Table name -> column name -> source column metadata. Only columns
        # flagged as `is_source` are included:
        self._source_table_dict: dict[str, dict[str, SourceColumn]] = {
            table.name: {
                column.name: table._source_column_dict[column.name]
                for column in table.columns if column.is_source
            }
            for table in tables
        }

        # Table name -> column name -> data type:
        self._table_dtype_dict: dict[str, dict[str, Dtype]] = {
            table.name: {
                column.name: column.dtype
                for column in table.columns
            }
            for table in tables
        }

        # Table name -> column name -> SQL reference/projection expression:
        self._table_column_ref_dict: dict[str, dict[str, str]] = {}
        self._table_column_proj_dict: dict[str, dict[str, str]] = {}
        for table in tables:
            refs: dict[str, str] = {}
            projs: dict[str, str] = {}
            for column in table.columns:
                quoted_name = quote_ident(column.name)
                expr = column.expr
                if expr is None:
                    # Plain columns are referenced and projected by their
                    # quoted name:
                    refs[column.name] = quoted_name
                    projs[column.name] = quoted_name
                    continue
                assert isinstance(expr, LocalExpression)
                refs[column.name] = expr.value
                # NOTE(review): The projection interpolates the expression
                # object itself while the reference uses `expr.value` --
                # confirm that both render the same SQL.
                projs[column.name] = f'{expr} AS {quoted_name}'
            self._table_column_ref_dict[table.name] = refs
            self._table_column_proj_dict[table.name] = projs

    @property
    def source_name_dict(self) -> dict[str, str]:
        r"""The source table names for all tables in the graph."""
        return self._source_name_dict

    @property
    def source_table_dict(self) -> dict[str, dict[str, SourceColumn]]:
        r"""The source column information for all tables in the graph."""
        return self._source_table_dict

    @property
    def table_dtype_dict(self) -> dict[str, dict[str, Dtype]]:
        r"""The data types for all columns in all tables in the graph."""
        return self._table_dtype_dict

    @property
    def table_column_ref_dict(self) -> dict[str, dict[str, str]]:
        r"""The SQL reference expression for all columns in all tables in the
        graph.
        """
        return self._table_column_ref_dict

    @property
    def table_column_proj_dict(self) -> dict[str, dict[str, str]]:
        r"""The SQL projection expressions for all columns in all tables in the
        graph.
        """
        return self._table_column_proj_dict

    def _sample_subgraph(
        self,
        entity_table_name: str,
        entity_pkey: pd.Series,
        anchor_time: pd.Series | Literal['entity'],
        columns_dict: dict[str, set[str]],
        num_neighbors: list[int],
    ) -> SamplerOutput:
        r"""Fetches the entity rows by primary key and wraps them into a
        :class:`SamplerOutput` that only contains the entity table.

        Args:
            entity_table_name: The name of the entity table.
            entity_pkey: The primary keys to look up.
            anchor_time: The anchor time of each entity. If not given as a
                :class:`pandas.Series`, it is read from the time column of
                the fetched entity rows.
            columns_dict: The columns to fetch for each table. Only the
                entry of :obj:`entity_table_name` is read.
            num_neighbors: Not used by this implementation.

        Raises:
            KeyError: If not every primary key could be found.
        """
        df, batch = self._by_pkey(
            table_name=entity_table_name,
            pkey=entity_pkey,
            columns=columns_dict[entity_table_name],
        )

        num_entities = len(entity_pkey)
        if len(batch) != num_entities:
            # Positions that never show up in `batch` could not be found:
            is_missing = np.ones(num_entities, dtype=bool)
            is_missing[batch] = False
            raise KeyError(f"The primary keys "
                           f"{entity_pkey.iloc[is_missing].tolist()} do not "
                           f"exist "
                           f"in the '{entity_table_name}' table")

        # Bring the fetched rows into the order of `entity_pkey`:
        order = batch.argsort()
        batch = batch[order]
        df = df.iloc[order].reset_index(drop=True)

        if isinstance(anchor_time, pd.Series):
            time_ser = anchor_time
        else:
            time_ser = df[self.time_column_dict[entity_table_name]]

        return SamplerOutput(
            anchor_time=time_ser.astype(int).to_numpy(),
            df_dict={entity_table_name: df},
            inverse_dict={},
            batch_dict={entity_table_name: batch},
            num_sampled_nodes_dict={entity_table_name: [len(batch)]},
            row_dict={},
            col_dict={},
            num_sampled_edges_dict={},
        )

    # Abstract Methods ########################################################

    @abstractmethod
    def _by_pkey(
        self,
        table_name: str,
        pkey: pd.Series,
        columns: set[str],
    ) -> tuple[pd.DataFrame, np.ndarray]:
        r"""Looks up rows of a table by primary key.

        Args:
            table_name: The name of the table.
            pkey: The primary keys to look up.
            columns: The columns to fetch.

        Returns:
            The fetched rows and an array that :meth:`_sample_subgraph` uses
            as integer positions into :obj:`pkey`, one per fetched row.
        """