kumoai 2.14.0.dev202601011731__cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kumoai might be problematic. Click here for more details.
- kumoai/__init__.py +300 -0
- kumoai/_logging.py +29 -0
- kumoai/_singleton.py +25 -0
- kumoai/_version.py +1 -0
- kumoai/artifact_export/__init__.py +9 -0
- kumoai/artifact_export/config.py +209 -0
- kumoai/artifact_export/job.py +108 -0
- kumoai/client/__init__.py +5 -0
- kumoai/client/client.py +223 -0
- kumoai/client/connector.py +110 -0
- kumoai/client/endpoints.py +150 -0
- kumoai/client/graph.py +120 -0
- kumoai/client/jobs.py +471 -0
- kumoai/client/online.py +78 -0
- kumoai/client/pquery.py +207 -0
- kumoai/client/rfm.py +112 -0
- kumoai/client/source_table.py +53 -0
- kumoai/client/table.py +101 -0
- kumoai/client/utils.py +130 -0
- kumoai/codegen/__init__.py +19 -0
- kumoai/codegen/cli.py +100 -0
- kumoai/codegen/context.py +16 -0
- kumoai/codegen/edits.py +473 -0
- kumoai/codegen/exceptions.py +10 -0
- kumoai/codegen/generate.py +222 -0
- kumoai/codegen/handlers/__init__.py +4 -0
- kumoai/codegen/handlers/connector.py +118 -0
- kumoai/codegen/handlers/graph.py +71 -0
- kumoai/codegen/handlers/pquery.py +62 -0
- kumoai/codegen/handlers/table.py +109 -0
- kumoai/codegen/handlers/utils.py +42 -0
- kumoai/codegen/identity.py +114 -0
- kumoai/codegen/loader.py +93 -0
- kumoai/codegen/naming.py +94 -0
- kumoai/codegen/registry.py +121 -0
- kumoai/connector/__init__.py +31 -0
- kumoai/connector/base.py +153 -0
- kumoai/connector/bigquery_connector.py +200 -0
- kumoai/connector/databricks_connector.py +213 -0
- kumoai/connector/file_upload_connector.py +189 -0
- kumoai/connector/glue_connector.py +150 -0
- kumoai/connector/s3_connector.py +278 -0
- kumoai/connector/snowflake_connector.py +252 -0
- kumoai/connector/source_table.py +471 -0
- kumoai/connector/utils.py +1796 -0
- kumoai/databricks.py +14 -0
- kumoai/encoder/__init__.py +4 -0
- kumoai/exceptions.py +26 -0
- kumoai/experimental/__init__.py +0 -0
- kumoai/experimental/rfm/__init__.py +210 -0
- kumoai/experimental/rfm/authenticate.py +432 -0
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/backend/local/graph_store.py +297 -0
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +113 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
- kumoai/experimental/rfm/backend/snow/table.py +242 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
- kumoai/experimental/rfm/base/__init__.py +30 -0
- kumoai/experimental/rfm/base/column.py +152 -0
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +761 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/base/sql_sampler.py +143 -0
- kumoai/experimental/rfm/base/table.py +736 -0
- kumoai/experimental/rfm/graph.py +1237 -0
- kumoai/experimental/rfm/infer/__init__.py +19 -0
- kumoai/experimental/rfm/infer/categorical.py +40 -0
- kumoai/experimental/rfm/infer/dtype.py +82 -0
- kumoai/experimental/rfm/infer/id.py +46 -0
- kumoai/experimental/rfm/infer/multicategorical.py +48 -0
- kumoai/experimental/rfm/infer/pkey.py +128 -0
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +61 -0
- kumoai/experimental/rfm/infer/timestamp.py +41 -0
- kumoai/experimental/rfm/pquery/__init__.py +7 -0
- kumoai/experimental/rfm/pquery/executor.py +102 -0
- kumoai/experimental/rfm/pquery/pandas_executor.py +530 -0
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +1184 -0
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/experimental/rfm/task_table.py +231 -0
- kumoai/formatting.py +30 -0
- kumoai/futures.py +99 -0
- kumoai/graph/__init__.py +12 -0
- kumoai/graph/column.py +106 -0
- kumoai/graph/graph.py +948 -0
- kumoai/graph/table.py +838 -0
- kumoai/jobs.py +80 -0
- kumoai/kumolib.cpython-310-x86_64-linux-gnu.so +0 -0
- kumoai/mixin.py +28 -0
- kumoai/pquery/__init__.py +25 -0
- kumoai/pquery/prediction_table.py +287 -0
- kumoai/pquery/predictive_query.py +641 -0
- kumoai/pquery/training_table.py +424 -0
- kumoai/spcs.py +121 -0
- kumoai/testing/__init__.py +8 -0
- kumoai/testing/decorators.py +57 -0
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/__init__.py +42 -0
- kumoai/trainer/baseline_trainer.py +93 -0
- kumoai/trainer/config.py +2 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/trainer/job.py +1192 -0
- kumoai/trainer/online_serving.py +258 -0
- kumoai/trainer/trainer.py +475 -0
- kumoai/trainer/util.py +103 -0
- kumoai/utils/__init__.py +11 -0
- kumoai/utils/datasets.py +83 -0
- kumoai/utils/display.py +51 -0
- kumoai/utils/forecasting.py +209 -0
- kumoai/utils/progress_logger.py +343 -0
- kumoai/utils/sql.py +3 -0
- kumoai-2.14.0.dev202601011731.dist-info/METADATA +71 -0
- kumoai-2.14.0.dev202601011731.dist-info/RECORD +122 -0
- kumoai-2.14.0.dev202601011731.dist-info/WHEEL +6 -0
- kumoai-2.14.0.dev202601011731.dist-info/licenses/LICENSE +9 -0
- kumoai-2.14.0.dev202601011731.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
from collections import OrderedDict
|
|
6
|
+
from typing import Any, Optional
|
|
7
|
+
|
|
8
|
+
from kumoai.codegen.context import CodegenContext
|
|
9
|
+
from kumoai.codegen.exceptions import (
|
|
10
|
+
CyclicDependencyError,
|
|
11
|
+
UnsupportedEntityError,
|
|
12
|
+
)
|
|
13
|
+
from kumoai.codegen.identity import get_config_id
|
|
14
|
+
from kumoai.codegen.loader import load_from_id
|
|
15
|
+
from kumoai.codegen.naming import NameManager
|
|
16
|
+
from kumoai.codegen.registry import (
|
|
17
|
+
REG,
|
|
18
|
+
Handler,
|
|
19
|
+
execute_in_env,
|
|
20
|
+
get_from_env,
|
|
21
|
+
init_execution_env,
|
|
22
|
+
register_shared_parents,
|
|
23
|
+
store_object_var,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _get_handler(obj_type: type) -> Handler:
    """Look up the code-generation handler registered for ``obj_type``.

    Args:
        obj_type: The class of the entity that code is generated for.

    Returns:
        The entry stored in ``REG`` under ``obj_type``.

    Raises:
        UnsupportedEntityError: If ``obj_type`` is not a key of ``REG``.
    """
    if obj_type in REG:
        return REG[obj_type]

    raise UnsupportedEntityError(
        f"No handler registered for type: {obj_type.__name__}")
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def get_kumo_id(obj: object) -> str:
    """Return an identifier for ``obj`` taken from its attributes.

    The attributes ``id``, ``name`` and ``source_name`` are probed in that
    order and the value of the first one present is returned as-is.

    Args:
        obj: Any object.

    Returns:
        The value of the first matching attribute, or ``''`` when the object
        has none of them.
    """
    for attr_name in ('id', 'name', 'source_name'):
        if hasattr(obj, attr_name):
            return getattr(obj, attr_name)
    return ''
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _generate(
    obj: object,
    name_manager: NameManager,
    config_to_var: dict[str, str],
    stack: set[int],
    codegen_ctx: CodegenContext,
    context: Optional[dict[str, Any]] = None,
    id_to_var: Optional[dict[int, str]] = None,
) -> tuple[list[str], list[str]]:
    """Generate code for an object and its parents.

    Parents reported by the object's handler are generated first
    (depth-first), then the object's own creation lines, then any edit lines
    reported by the handler's ``detect_edits`` hook. Creation and edit lines
    are executed in the codegen environment as soon as they are produced.

    Args:
        obj: The object to generate code for.
        name_manager: The name manager to use for variable names.
        config_to_var: A dictionary mapping config IDs to variable names.
            Updated in place.
        stack: A set of ``id()`` values of the objects currently being
            generated, used to detect cycles. It is restored on exit, also
            when an exception propagates.
        codegen_ctx: The codegen context.
        context: A dictionary of context information handed to the handler.
            Its ``'target_id'`` entry is overwritten with the identifier of
            ``obj`` before the handler emits its lines.
        id_to_var: A dictionary mapping object IDs to variable names.
            Updated in place.

    Returns:
        A tuple containing a list of imports and a list of lines of code.
        Both lists are empty when ``obj`` (or an object with the same config
        ID) has already been generated.

    Raises:
        CyclicDependencyError: If ``obj`` is reached again while it is still
            being generated.
        UnsupportedEntityError: If no handler is registered for the type of
            ``obj``.
    """
    if id_to_var is None:
        id_to_var = {}

    obj_id = id(obj)
    config_id = get_config_id(obj)

    # Check for configuration-based deduplication first
    if config_id in config_to_var:
        # Reuse existing variable for this configuration
        id_to_var[obj_id] = config_to_var[config_id]
        return [], []

    # Check for exact object reuse (faster path)
    if obj_id in id_to_var:
        return [], []

    # Cycle detection using real object IDs
    if obj_id in stack:
        raise CyclicDependencyError(
            f"Cyclic dependency detected for object ID: {obj_id}")

    stack.add(obj_id)
    try:
        handler = _get_handler(type(obj))
        all_imports: list[str] = []
        all_lines: list[str] = []

        for parent in handler.parents(obj, codegen_ctx):
            parent_imports, parent_lines = _generate(parent, name_manager,
                                                     config_to_var, stack,
                                                     codegen_ctx, context,
                                                     id_to_var)
            all_imports.extend(parent_imports)
            all_lines.extend(parent_lines)

        # Register shared parents if handler supports it
        register_shared_parents(codegen_ctx, obj, handler)

        var_name = name_manager.assign_entity_variable(obj)
        # Store both config-based and id-based mappings
        config_to_var[config_id] = var_name
        id_to_var[obj_id] = var_name

        # Store in codegen context for handlers to access
        store_object_var(codegen_ctx, obj, var_name)

        # Computed once and used both for the emitted script and for the
        # immediate execution below.
        own_imports = handler.required_imports(obj)
        all_imports.extend(own_imports)

        context = context or {}
        context['target_id'] = get_kumo_id(obj)
        creation_lines = handler.emit_lines(obj, var_name, context,
                                            codegen_ctx)

        # Execute creation lines immediately in context environment
        execute_in_env(codegen_ctx, creation_lines, own_imports)

        all_lines.extend(creation_lines)

        if handler.detect_edits:
            # Get baseline object from context environment
            baseline_obj = get_from_env(codegen_ctx, var_name)
            if baseline_obj is not None:
                edits = handler.detect_edits(obj, baseline_obj, name_manager)
                for edit in edits:
                    edit_lines = edit.emit_lines(var_name)
                    all_lines.extend(edit_lines)
                    all_imports.extend(edit.required_imports)
                    # Execute edit lines immediately in context environment
                    execute_in_env(codegen_ctx, edit_lines,
                                   edit.required_imports)
    finally:
        # Always unwind the cycle-detection stack so that a failing handler
        # does not leave the caller's set polluted.
        stack.discard(obj_id)

    return all_imports, all_lines
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def _load_entity_from_spec(input_spec: dict[str, Any]) -> object:
|
|
141
|
+
"""Load entity from input specification."""
|
|
142
|
+
if "id" in input_spec:
|
|
143
|
+
entity_class = input_spec.get("entity_class")
|
|
144
|
+
return load_from_id(input_spec["id"], entity_class)
|
|
145
|
+
elif "json" in input_spec:
|
|
146
|
+
raise NotImplementedError("JSON loading not yet implemented")
|
|
147
|
+
elif "object" in input_spec:
|
|
148
|
+
return input_spec["object"]
|
|
149
|
+
else:
|
|
150
|
+
raise ValueError("input_spec must contain 'id', 'json', or 'object'")
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def _write_script(code: str, output_path: str) -> None:
|
|
154
|
+
"""Write generated code to file."""
|
|
155
|
+
with open(output_path, "w") as f:
|
|
156
|
+
f.write(code)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def _assemble_code(imports: list[str], lines: list[str]) -> str:
    """Join header, imports and body lines into the final script text.

    Args:
        imports: Import statements; duplicates are dropped while the order
            of first occurrence is kept.
        lines: The generated code lines.

    Returns:
        The script as a single newline-terminated string: a fixed header
        (version comment, ``kumoai``/``os`` imports and a ``kumo.init`` call
        reading its settings from environment variables), the de-duplicated
        imports, an empty line and the code lines.
    """
    from kumoai import __version__

    init_call = ('kumo.init(url=os.getenv("KUMO_API_ENDPOINT"), '
                 'api_key=os.getenv("KUMO_API_KEY"))')
    preamble = [
        f"# Generated with Kumo SDK version: {__version__}",
        "import kumoai as kumo",
        "import os",
        "",
        init_call,
        "",
    ]

    # Keys of an ordered mapping: unique, in first-seen order.
    deduped = list(OrderedDict.fromkeys(imports))
    script = preamble + deduped + [""] + lines
    return "\n".join(script) + "\n"
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def _init_kumo() -> None:
    """Initialize the Kumo SDK for this Python session from the environment.

    ``kumo.init`` is called with the values of ``KUMO_API_ENDPOINT`` and
    ``KUMO_API_KEY``. If either variable is unset, a warning is logged and
    nothing is initialized.
    """
    import kumoai as kumo

    endpoint = os.getenv("KUMO_API_ENDPOINT")
    if endpoint is None:
        logger.warning("KUMO_API_ENDPOINT env variable is not set, "
                       "assuming kumo.init has already been called")
        return

    api_key = os.getenv("KUMO_API_KEY")
    if api_key is None:
        logger.warning("KUMO_API_KEY env variable is not set, "
                       "assuming kumo.init has already been called")
        return

    kumo.init(url=endpoint, api_key=api_key)
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def generate_code(input_spec: dict[str, Any],
                  output_path: Optional[str] = None) -> str:
    """Generate Python SDK code from Kumo entity specification.

    Args:
        input_spec: Specification of the entity; see
            ``_load_entity_from_spec`` for the accepted keys.
        output_path: If truthy, the generated script is also written to
            this path.

    Returns:
        The generated script text.
    """
    # Fresh codegen context, with its execution environment, per session.
    codegen_ctx = CodegenContext()
    init_execution_env(codegen_ctx)

    _init_kumo()
    entity = _load_entity_from_spec(input_spec)

    # Record how the entity was specified so handlers can adapt their output.
    if "id" in input_spec:
        context = {"input_method": "id", "target_id": input_spec["id"]}
    elif "json" in input_spec:
        context = {"input_method": "json"}
    else:
        context = {"input_method": "object"}

    imports, lines = _generate(
        entity,
        NameManager(),
        config_to_var={},
        stack=set(),
        codegen_ctx=codegen_ctx,
        context=context,
        id_to_var={},
    )

    code = _assemble_code(imports, lines)
    if output_path:
        _write_script(code, output_path)
    return code
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Dict, Type
|
|
4
|
+
|
|
5
|
+
from kumoai.codegen.context import CodegenContext
|
|
6
|
+
from kumoai.codegen.handlers.utils import _get_canonical_import_path
|
|
7
|
+
from kumoai.codegen.registry import Handler
|
|
8
|
+
from kumoai.connector import (
|
|
9
|
+
BigQueryConnector,
|
|
10
|
+
DatabricksConnector,
|
|
11
|
+
FileUploadConnector,
|
|
12
|
+
S3Connector,
|
|
13
|
+
SnowflakeConnector,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _get_by_name_handler_factory(cls: type) -> Handler:
    """Factory for creating a Handler for any connector that uses the
    get_by_name pattern.

    Args:
        cls: The connector class the handler is created for.

    Returns:
        A ``Handler`` without parents and without edit detection whose
        emitted code is ``<var> = <cls>.get_by_name(<name literal>)``.
    """
    def get_imports(obj: object, bound_cls: type = cls) -> list[str]:
        """Return the import statement for the connector class."""
        canonical_module = _get_canonical_import_path(bound_cls)
        return [f"from {canonical_module} import {bound_cls.__name__}"]

    def get_lines(
        obj: object,
        var_name: str,
        context: dict,
        codegen_ctx: CodegenContext,
    ) -> list[str]:
        """Return the line that looks the connector up by its name."""
        assert isinstance(obj, cls)
        assert hasattr(obj, "name")
        obj_name = getattr(obj, "name")
        # repr() produces a valid string literal even if the name contains
        # quotes, backslashes or newlines; for plain names the output is the
        # same single-quoted text as before.
        name_literal = repr(f"{obj_name}")
        return [f"{var_name} = {cls.__name__}.get_by_name({name_literal})"]

    return Handler(
        parents=lambda e, ctx: [],
        required_imports=get_imports,
        emit_lines=get_lines,
        detect_edits=None,
    )
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def get_handlers() -> Dict[Type, Handler]:
    """Returns a dictionary of handlers for all connector types.

    String values taken from the connectors (names, root directories, file
    types) are written into the generated code through ``repr`` so that the
    result stays valid Python even if a value contains quotes or
    backslashes.
    """
    handlers: Dict[Type, Handler] = {}

    def _literal(value: object) -> str:
        """Format ``value`` like an f-string would and quote the result."""
        return repr(f"{value}")

    # S3Connector gets special handling to support both patterns

    def _s3_connector_parents(obj: object,
                              codegen_ctx: CodegenContext) -> list[object]:
        """S3 connectors have no parents."""
        return []

    def _s3_connector_imports(obj: object) -> list[str]:
        """Return the import statement for ``S3Connector``."""
        canonical_module = _get_canonical_import_path(S3Connector)
        return [f"from {canonical_module} import S3Connector"]

    def _s3_connector_emit_lines(
        obj: object,
        var_name: str,
        context: dict,
        codegen_ctx: CodegenContext,
    ) -> list[str]:
        """Recreate the connector from ``root_dir`` if set, else by name."""
        assert isinstance(obj, S3Connector)
        assert hasattr(obj, "name")

        # Check if connector has root_dir attribute and it's not None
        if hasattr(obj, "root_dir") and obj.root_dir is not None:
            root_dir = getattr(obj, "root_dir")
            return [f"{var_name} = S3Connector(root_dir={_literal(root_dir)})"]

        obj_name = getattr(obj, "name")
        return [f"{var_name} = S3Connector.get_by_name({_literal(obj_name)})"]

    handlers[S3Connector] = Handler(
        parents=_s3_connector_parents,
        required_imports=_s3_connector_imports,
        emit_lines=_s3_connector_emit_lines,
        detect_edits=None,
    )

    # Other persistent connectors use the standard get_by_name pattern
    other_persistent_connectors = [
        BigQueryConnector,
        SnowflakeConnector,
        DatabricksConnector,
    ]
    for connector_cls in other_persistent_connectors:
        handlers[connector_cls] = _get_by_name_handler_factory(connector_cls)

    def _file_upload_parents(obj: object,
                             codegen_ctx: CodegenContext) -> list[object]:
        """File upload connectors have no parents."""
        return []

    def _file_upload_imports(obj: object) -> list[str]:
        """Return the import statement for the concrete class of ``obj``."""
        canonical_module = _get_canonical_import_path(type(obj))
        return [f"from {canonical_module} import {type(obj).__name__}"]

    def _file_upload_emit_lines(
        obj: object,
        var_name: str,
        context: dict,
        codegen_ctx: CodegenContext,
    ) -> list[str]:
        """Recreate the connector from its ``file_type``."""
        assert isinstance(obj, FileUploadConnector)
        return [
            f"{var_name} = {type(obj).__name__}"
            f"(file_type={_literal(obj.file_type)})"
        ]

    handlers[FileUploadConnector] = Handler(
        parents=_file_upload_parents,
        required_imports=_file_upload_imports,
        emit_lines=_file_upload_emit_lines,
        detect_edits=None,
    )

    return handlers
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Dict, Type
|
|
3
|
+
|
|
4
|
+
from kumoai.codegen.context import CodegenContext
|
|
5
|
+
from kumoai.codegen.handlers.utils import _get_canonical_import_path
|
|
6
|
+
from kumoai.codegen.registry import Handler, get_object_var
|
|
7
|
+
from kumoai.graph import Edge, Graph
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def get_handlers() -> Dict[Type, Handler]:
    """Returns a dictionary of handlers for Graph types."""
    handlers: Dict[Type, Handler] = {}

    def _graph_parents(obj: object,
                       codegen_ctx: CodegenContext) -> list[object]:
        """Graph depends on its tables.

        Only the tables are returned as parents; the edges are written out
        inline by ``_graph_emit_lines``.
        """
        assert isinstance(obj, Graph)
        return list(obj.tables.values())

    def _graph_imports(obj: object) -> list[str]:
        """Get import statements needed for Graph (``Graph`` and ``Edge``)."""
        imports = []
        for obj_type in (Graph, Edge):
            canonical_module = _get_canonical_import_path(obj_type)
            imports.append(
                f"from {canonical_module} import {obj_type.__name__}")
        return imports

    def _graph_emit_lines(
        obj: object,
        var_name: str,
        context: dict,
        codegen_ctx: CodegenContext,
    ) -> list[str]:
        """Generate code lines to recreate a Graph using Graph().

        The emitted code builds a ``<var>_tables`` dict from the variables
        already assigned to the graph's tables, a ``<var>_edges`` value from
        the text form of ``obj.edges``, constructs the graph and validates
        it. Two comment lines hint at ``Graph.load`` / ``save``.

        Raises:
            KeyError: If ``context`` has no ``'target_id'`` entry.
        """
        assert isinstance(obj, Graph)

        tables_vars = {
            table_name: get_object_var(codegen_ctx, table)
            for table_name, table in obj.tables.items()
        }
        tables_format_inner = ", ".join(f"{key!r}: {value}"
                                        for key, value in tables_vars.items())
        all_tables_var_name = f"{var_name}_tables"
        all_edges_var_name = f"{var_name}_edges"

        # Quote the id the same way in both hints. repr() keeps the hint on
        # one comment line even if the id contains quotes or newlines.
        target_literal = repr(f"{context['target_id']}")

        return [
            (f"# Note: This could also be done with Graph.load("
             f"{target_literal}) for simpler code"),
            f"{all_tables_var_name} = {{{tables_format_inner}}}",
            f"{all_edges_var_name} = {obj.edges}",
            (f"{var_name} = Graph(tables={all_tables_var_name}, "
             f"edges={all_edges_var_name})"),
            f"{var_name}.validate()",
            (f"# Optionally, you can save the graph to backend using "
             f"{var_name}.save({target_literal}) and then load it"),
        ]

    handlers[Graph] = Handler(
        parents=_graph_parents,
        required_imports=_graph_imports,
        emit_lines=_graph_emit_lines,
        detect_edits=None,
    )

    return handlers
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Dict, Type
|
|
3
|
+
|
|
4
|
+
from kumoai.codegen.context import CodegenContext
|
|
5
|
+
from kumoai.codegen.handlers.utils import _get_canonical_import_path
|
|
6
|
+
from kumoai.codegen.registry import Handler, get_object_var
|
|
7
|
+
from kumoai.pquery import PredictiveQuery
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def get_handlers() -> Dict[Type, Handler]:
    """Returns a dictionary of handlers for PredictiveQuery types."""
    handlers: Dict[Type, Handler] = {}

    def _pquery_parents(obj: object,
                        codegen_ctx: CodegenContext) -> list[object]:
        """PredictiveQuery depends on its graph."""
        assert isinstance(obj, PredictiveQuery)
        return [obj.graph]

    def _pquery_imports(obj: object) -> list[str]:
        """Get import statements needed for PredictiveQuery."""
        canonical_module = _get_canonical_import_path(PredictiveQuery)
        return [
            f"from {canonical_module} import {PredictiveQuery.__name__}"
        ]

    def _format_query(query: str) -> str:
        """Render ``query`` as a Python string literal.

        Multi-line queries are written as a readable triple-quoted string
        when that reproduces the text exactly. Otherwise ``repr`` is used,
        which is always a valid literal.
        """
        triple_quote_safe = ('"""' not in query and '\\' not in query
                             and '\r' not in query
                             and not query.endswith('"'))
        if '\n' in query and triple_quote_safe:
            return f'"""{query}"""'
        return repr(query)

    def _pquery_emit_lines(
        obj: object,
        var_name: str,
        context: dict,
        codegen_ctx: CodegenContext,
    ) -> list[str]:
        """Generate code lines to recreate a PredictiveQuery.

        The emitted code constructs the query from the variable already
        assigned to its graph and then validates it.
        """
        assert isinstance(obj, PredictiveQuery)

        graph_var = get_object_var(codegen_ctx, obj.graph)
        formatted_query = _format_query(obj.query)

        return [
            (f"{var_name} = PredictiveQuery(graph={graph_var},"
             f" query={formatted_query},)"),
            f"{var_name}.validate()",
        ]

    handlers[PredictiveQuery] = Handler(
        parents=_pquery_parents,
        required_imports=_pquery_imports,
        emit_lines=_pquery_emit_lines,
        detect_edits=None,
    )

    return handlers
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Dict, Sequence, Type
|
|
3
|
+
|
|
4
|
+
from kumoai.codegen.context import CodegenContext
|
|
5
|
+
from kumoai.codegen.edits import (
|
|
6
|
+
UniversalReplacementEdit,
|
|
7
|
+
detect_edits_recursive,
|
|
8
|
+
)
|
|
9
|
+
from kumoai.codegen.handlers.utils import _get_canonical_import_path
|
|
10
|
+
from kumoai.codegen.naming import NameManager
|
|
11
|
+
from kumoai.codegen.registry import Handler, get_object_var
|
|
12
|
+
from kumoai.graph.table import Table
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def get_handlers() -> Dict[Type, Handler]:
    """Returns a dictionary of handlers for Table types.

    Table, column and id strings are written into the generated code
    through ``repr`` so that the result stays valid Python even if a value
    contains quotes or backslashes.
    """
    handlers: Dict[Type, Handler] = {}

    def _literal(value: object) -> str:
        """Format ``value`` like an f-string would and quote the result."""
        return repr(f"{value}")

    def _table_parents(obj: object,
                       codegen_ctx: CodegenContext) -> list[object]:
        """Table depends on its source_table's connector."""
        assert isinstance(obj, Table)
        return [obj.source_table.connector]

    def _table_imports(obj: object) -> list[str]:
        """Get import statements needed for Table."""
        canonical_module = _get_canonical_import_path(Table)
        return [f"from {canonical_module} import Table"]

    def _build_args(obj: Table, connector_var: str,
                    table_name: str) -> list[str]:
        """Build arguments for Table.from_source_table() call.

        ``source_table`` is always present. ``primary_key``, ``time_column``
        and ``end_time_column`` are added when set on ``obj``, and
        ``column_names`` when the table has a different number of columns
        than its source table.
        """
        args = [f"source_table={connector_var}[{_literal(table_name)}]"]

        if obj.primary_key is not None:
            args.append(f"primary_key={_literal(obj.primary_key.name)}")

        if obj.time_column is not None:
            args.append(f"time_column={_literal(obj.time_column.name)}")

        if obj.end_time_column is not None:
            args.append(
                f"end_time_column={_literal(obj.end_time_column.name)}")

        # Unique column names in first-seen order.
        source_col_names = list(
            dict.fromkeys(col.name for col in obj.source_table.columns))
        table_col_names = list(dict.fromkeys(col.name for col in obj.columns))

        if len(source_col_names) != len(table_col_names):
            assert len(source_col_names) > len(table_col_names)
            args.append(f"column_names={table_col_names}")

        return args

    def _table_emit_lines(
        obj: object,
        var_name: str,
        context: dict,
        codegen_ctx: CodegenContext,
    ) -> list[str]:
        """Generate code lines to recreate a Table using from_source_table.

        When the entity was specified by id, a comment hinting at
        ``Table.load`` precedes the constructor call.
        """
        assert isinstance(obj, Table)

        connector_var = get_object_var(codegen_ctx, obj.source_table.connector)
        table_name = obj.source_table.name

        args = _build_args(obj, connector_var, table_name)

        lines = []

        if context.get("input_method") == "id" and context.get("target_id"):
            note = (f"# Note: This could also be done with Table.load("
                    f"{_literal(context['target_id'])}) for simpler code")
            lines.append(note)

        if len(args) == 1:
            lines.append(f"{var_name} = Table.from_source_table({args[0]})")
        else:
            # NOTE(review): the continuation whitespace is reproduced as it
            # appears in the reviewed source; confirm the intended indent.
            args_str = ",\n ".join(args)
            lines.append(
                f"{var_name} = Table.from_source_table(\n {args_str},\n)")

        return lines

    def _table_detect_edits(
            target: object, baseline: object,
            name_manager: NameManager) -> Sequence[UniversalReplacementEdit]:
        """Detect edits needed to make baseline match target table.

        Detection is best-effort: any exception is logged with its
        traceback and an empty list is returned.
        """
        assert isinstance(target, Table)
        assert isinstance(baseline, Table)

        try:
            result = detect_edits_recursive(target, baseline, "", name_manager)
            logger.debug(f"Found for table {len(result.edits)} edits with "
                         f"{len(result.imports)} imports")
            return result.edits
        except Exception as e:
            # Still logged at ERROR level, now including the traceback.
            logger.exception(f"Error during table edit detection: {e}")
            return []

    handlers[Table] = Handler(
        parents=_table_parents,
        required_imports=_table_imports,
        emit_lines=_table_emit_lines,
        detect_edits=_table_detect_edits,
    )

    return handlers
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import importlib
|
|
4
|
+
from typing import Type
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def _get_canonical_import_path(cls: Type) -> str:
|
|
8
|
+
"""Dynamically finds the shortest, most canonical import path.
|
|
9
|
+
|
|
10
|
+
For example, given the S3Connector class, it would prefer
|
|
11
|
+
'from kumoai import S3Connector' over
|
|
12
|
+
'from kumoai.connector import S3Connector' over
|
|
13
|
+
'from kumoai.connector.s3_connector import S3Connector'.
|
|
14
|
+
"""
|
|
15
|
+
base_module_path = cls.__module__
|
|
16
|
+
class_name = cls.__name__
|
|
17
|
+
|
|
18
|
+
parts = base_module_path.split(".")
|
|
19
|
+
|
|
20
|
+
# The longest path is the one where the class is defined
|
|
21
|
+
canonical_path = base_module_path
|
|
22
|
+
|
|
23
|
+
# Walk up the module hierarchy to find a shorter path.
|
|
24
|
+
# e.g., from 'kumoai.connector.s3_connector' to 'kumoai.connector'
|
|
25
|
+
for i in range(len(parts) - 1, 0, -1):
|
|
26
|
+
parent_module_path = ".".join(parts[:i])
|
|
27
|
+
try:
|
|
28
|
+
parent_module = importlib.import_module(parent_module_path)
|
|
29
|
+
if hasattr(parent_module, class_name):
|
|
30
|
+
if getattr(parent_module, class_name) is cls:
|
|
31
|
+
canonical_path = parent_module_path
|
|
32
|
+
else:
|
|
33
|
+
# A different object has the same name
|
|
34
|
+
break
|
|
35
|
+
else:
|
|
36
|
+
# The class isn't in this parent module
|
|
37
|
+
break
|
|
38
|
+
except ImportError:
|
|
39
|
+
# If we can't import a parent, we can't go higher.
|
|
40
|
+
break
|
|
41
|
+
|
|
42
|
+
return canonical_path
|