relationalai 1.0.0a3__py3-none-any.whl → 1.0.0a5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- relationalai/config/config.py +47 -21
- relationalai/config/connections/__init__.py +5 -2
- relationalai/config/connections/duckdb.py +2 -2
- relationalai/config/connections/local.py +31 -0
- relationalai/config/connections/snowflake.py +0 -1
- relationalai/config/external/raiconfig_converter.py +235 -0
- relationalai/config/external/raiconfig_models.py +202 -0
- relationalai/config/external/utils.py +31 -0
- relationalai/config/shims.py +1 -0
- relationalai/semantics/__init__.py +10 -8
- relationalai/semantics/backends/sql/sql_compiler.py +1 -4
- relationalai/semantics/experimental/__init__.py +0 -0
- relationalai/semantics/experimental/builder.py +295 -0
- relationalai/semantics/experimental/builtins.py +154 -0
- relationalai/semantics/frontend/base.py +67 -42
- relationalai/semantics/frontend/core.py +34 -6
- relationalai/semantics/frontend/front_compiler.py +209 -37
- relationalai/semantics/frontend/pprint.py +6 -2
- relationalai/semantics/metamodel/__init__.py +7 -0
- relationalai/semantics/metamodel/metamodel.py +2 -0
- relationalai/semantics/metamodel/metamodel_analyzer.py +58 -16
- relationalai/semantics/metamodel/pprint.py +6 -1
- relationalai/semantics/metamodel/rewriter.py +11 -7
- relationalai/semantics/metamodel/typer.py +116 -41
- relationalai/semantics/reasoners/__init__.py +11 -0
- relationalai/semantics/reasoners/graph/__init__.py +35 -0
- relationalai/semantics/reasoners/graph/core.py +9028 -0
- relationalai/semantics/std/__init__.py +30 -10
- relationalai/semantics/std/aggregates.py +641 -12
- relationalai/semantics/std/common.py +146 -13
- relationalai/semantics/std/constraints.py +71 -1
- relationalai/semantics/std/datetime.py +904 -21
- relationalai/semantics/std/decimals.py +143 -2
- relationalai/semantics/std/floats.py +57 -4
- relationalai/semantics/std/integers.py +98 -4
- relationalai/semantics/std/math.py +857 -35
- relationalai/semantics/std/numbers.py +216 -20
- relationalai/semantics/std/re.py +213 -5
- relationalai/semantics/std/strings.py +437 -44
- relationalai/shims/executor.py +60 -52
- relationalai/shims/fixtures.py +85 -0
- relationalai/shims/helpers.py +26 -2
- relationalai/shims/hoister.py +28 -9
- relationalai/shims/mm2v0.py +204 -173
- relationalai/tools/cli/cli.py +192 -10
- relationalai/tools/cli/components/progress_reader.py +1 -1
- relationalai/tools/cli/docs.py +394 -0
- relationalai/tools/debugger.py +11 -4
- relationalai/tools/qb_debugger.py +435 -0
- relationalai/tools/typer_debugger.py +1 -2
- relationalai/util/dataclasses.py +3 -5
- relationalai/util/docutils.py +1 -2
- relationalai/util/error.py +2 -5
- relationalai/util/python.py +23 -0
- relationalai/util/runtime.py +1 -2
- relationalai/util/schema.py +2 -4
- relationalai/util/structures.py +4 -2
- relationalai/util/tracing.py +8 -2
- {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/METADATA +8 -5
- {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/RECORD +118 -95
- {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/WHEEL +1 -1
- v0/relationalai/__init__.py +1 -1
- v0/relationalai/clients/client.py +52 -18
- v0/relationalai/clients/exec_txn_poller.py +122 -0
- v0/relationalai/clients/local.py +23 -8
- v0/relationalai/clients/resources/azure/azure.py +36 -11
- v0/relationalai/clients/resources/snowflake/__init__.py +4 -4
- v0/relationalai/clients/resources/snowflake/cli_resources.py +12 -1
- v0/relationalai/clients/resources/snowflake/direct_access_resources.py +124 -100
- v0/relationalai/clients/resources/snowflake/engine_service.py +381 -0
- v0/relationalai/clients/resources/snowflake/engine_state_handlers.py +35 -29
- v0/relationalai/clients/resources/snowflake/error_handlers.py +43 -2
- v0/relationalai/clients/resources/snowflake/snowflake.py +277 -179
- v0/relationalai/clients/resources/snowflake/use_index_poller.py +8 -0
- v0/relationalai/clients/types.py +5 -0
- v0/relationalai/errors.py +19 -1
- v0/relationalai/semantics/lqp/algorithms.py +173 -0
- v0/relationalai/semantics/lqp/builtins.py +199 -2
- v0/relationalai/semantics/lqp/executor.py +68 -37
- v0/relationalai/semantics/lqp/ir.py +28 -2
- v0/relationalai/semantics/lqp/model2lqp.py +215 -45
- v0/relationalai/semantics/lqp/passes.py +13 -658
- v0/relationalai/semantics/lqp/rewrite/__init__.py +12 -0
- v0/relationalai/semantics/lqp/rewrite/algorithm.py +385 -0
- v0/relationalai/semantics/lqp/rewrite/constants_to_vars.py +70 -0
- v0/relationalai/semantics/lqp/rewrite/deduplicate_vars.py +104 -0
- v0/relationalai/semantics/lqp/rewrite/eliminate_data.py +108 -0
- v0/relationalai/semantics/lqp/rewrite/extract_keys.py +25 -3
- v0/relationalai/semantics/lqp/rewrite/period_math.py +77 -0
- v0/relationalai/semantics/lqp/rewrite/quantify_vars.py +65 -31
- v0/relationalai/semantics/lqp/rewrite/unify_definitions.py +317 -0
- v0/relationalai/semantics/lqp/utils.py +11 -1
- v0/relationalai/semantics/lqp/validators.py +14 -1
- v0/relationalai/semantics/metamodel/builtins.py +2 -1
- v0/relationalai/semantics/metamodel/compiler.py +2 -1
- v0/relationalai/semantics/metamodel/dependency.py +12 -3
- v0/relationalai/semantics/metamodel/executor.py +11 -1
- v0/relationalai/semantics/metamodel/factory.py +2 -2
- v0/relationalai/semantics/metamodel/helpers.py +7 -0
- v0/relationalai/semantics/metamodel/ir.py +3 -2
- v0/relationalai/semantics/metamodel/rewrite/dnf_union_splitter.py +30 -20
- v0/relationalai/semantics/metamodel/rewrite/flatten.py +50 -13
- v0/relationalai/semantics/metamodel/rewrite/format_outputs.py +9 -3
- v0/relationalai/semantics/metamodel/typer/checker.py +6 -4
- v0/relationalai/semantics/metamodel/typer/typer.py +4 -3
- v0/relationalai/semantics/metamodel/visitor.py +4 -3
- v0/relationalai/semantics/reasoners/optimization/solvers_dev.py +1 -1
- v0/relationalai/semantics/reasoners/optimization/solvers_pb.py +336 -86
- v0/relationalai/semantics/rel/compiler.py +2 -1
- v0/relationalai/semantics/rel/executor.py +3 -2
- v0/relationalai/semantics/tests/lqp/__init__.py +0 -0
- v0/relationalai/semantics/tests/lqp/algorithms.py +345 -0
- v0/relationalai/tools/cli.py +339 -186
- v0/relationalai/tools/cli_controls.py +216 -67
- v0/relationalai/tools/cli_helpers.py +410 -6
- v0/relationalai/util/format.py +5 -2
- {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/entry_points.txt +0 -0
- {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/top_level.txt +0 -0
relationalai/config/config.py
CHANGED
|
@@ -6,32 +6,28 @@ Main configuration class for PyRel (YAML-based).
|
|
|
6
6
|
from __future__ import annotations
|
|
7
7
|
|
|
8
8
|
from abc import ABC
|
|
9
|
-
from typing import Any, TypeVar, overload, Literal, TYPE_CHECKING
|
|
9
|
+
from typing import Any, TypeVar, overload, Literal, TYPE_CHECKING, cast
|
|
10
10
|
from pydantic import Field, model_validator
|
|
11
11
|
from pydantic_settings import SettingsConfigDict
|
|
12
|
-
|
|
13
|
-
try:
|
|
14
|
-
from confocal import BaseConfig
|
|
15
|
-
except ImportError:
|
|
16
|
-
# Confocal not yet published to PyPI - use pydantic_settings as fallback
|
|
17
|
-
# Config system not actively used yet, this is just to prevent import errors
|
|
18
|
-
from pydantic_settings import BaseSettings as BaseConfig # type: ignore[misc, assignment]
|
|
12
|
+
from confocal import BaseConfig
|
|
19
13
|
|
|
20
14
|
if TYPE_CHECKING:
|
|
21
15
|
import snowflake.snowpark
|
|
22
|
-
import duckdb
|
|
16
|
+
import duckdb # type: ignore[import-not-found]
|
|
17
|
+
import requests
|
|
23
18
|
|
|
24
|
-
from .connections import ConnectionConfig, BaseConnection, SnowflakeConnection, DuckDBConnection
|
|
19
|
+
from .connections import ConnectionConfig, BaseConnection, SnowflakeConnection, DuckDBConnection, LocalConnection
|
|
25
20
|
from .config_fields import EngineConfig, DataConfig, CompilerConfig, ModelConfig, ReasonerConfig, DebugConfig
|
|
26
21
|
from .external.dbt_converter import convert_dbt_to_rai
|
|
27
22
|
from .external.snowflake_converter import convert_snowflake_to_rai
|
|
28
|
-
from .external.
|
|
23
|
+
from .external.raiconfig_converter import convert_raiconfig_to_rai
|
|
24
|
+
from .external.utils import find_dbt_profiles_file, find_snowflake_config_file, find_raiconfig_toml_file
|
|
29
25
|
|
|
30
26
|
# TypeVar for generic connection retrieval
|
|
31
27
|
T = TypeVar('T', bound=BaseConnection)
|
|
32
28
|
|
|
33
29
|
|
|
34
|
-
class _Config(BaseConfig, ABC):
|
|
30
|
+
class _Config(BaseConfig, ABC):
|
|
35
31
|
"""Base configuration class with common fields and methods."""
|
|
36
32
|
|
|
37
33
|
active_profile: str | None = Field(
|
|
@@ -171,6 +167,8 @@ class _Config(BaseConfig, ABC): # type: ignore for now until we publish confocal
|
|
|
171
167
|
connection_class = "SnowflakeConnection"
|
|
172
168
|
elif isinstance(connection, DuckDBConnection):
|
|
173
169
|
connection_class = "DuckDBConnection"
|
|
170
|
+
elif isinstance(connection, LocalConnection):
|
|
171
|
+
connection_class = "LocalConnection"
|
|
174
172
|
else:
|
|
175
173
|
connection_class = actual_type
|
|
176
174
|
|
|
@@ -179,18 +177,15 @@ class _Config(BaseConfig, ABC): # type: ignore for now until we publish confocal
|
|
|
179
177
|
f"Got: {connection_class} (actual: {actual_type})"
|
|
180
178
|
)
|
|
181
179
|
|
|
182
|
-
return connection
|
|
183
|
-
|
|
184
|
-
@overload
|
|
185
|
-
def get_session(self, connection_type: type[SnowflakeConnection]) -> snowflake.snowpark.Session: ...
|
|
180
|
+
return cast(T, connection)
|
|
186
181
|
|
|
187
182
|
@overload
|
|
188
|
-
def get_session(self
|
|
183
|
+
def get_session(self) -> snowflake.snowpark.Session | duckdb.DuckDBPyConnection | requests.Session: ...
|
|
189
184
|
|
|
190
185
|
@overload
|
|
191
|
-
def get_session(self, connection_type:
|
|
186
|
+
def get_session(self, connection_type: type[SnowflakeConnection] | type[DuckDBConnection] | type[LocalConnection]) -> snowflake.snowpark.Session | duckdb.DuckDBPyConnection | requests.Session: ...
|
|
192
187
|
|
|
193
|
-
def get_session(self, connection_type: type[SnowflakeConnection] | type[DuckDBConnection] | None = None) -> snowflake.snowpark.Session | duckdb.DuckDBPyConnection:
|
|
188
|
+
def get_session(self, connection_type: type[SnowflakeConnection] | type[DuckDBConnection] | type[LocalConnection] | None = None) -> snowflake.snowpark.Session | duckdb.DuckDBPyConnection | requests.Session:
|
|
194
189
|
if connection_type is None:
|
|
195
190
|
connection = self.get_default_connection()
|
|
196
191
|
return connection.get_session()
|
|
@@ -250,6 +245,35 @@ class ConfigFromDBT(_Config):
|
|
|
250
245
|
return convert_dbt_to_rai(data)
|
|
251
246
|
|
|
252
247
|
|
|
248
|
+
class ConfigFromRAIConfigToml(_Config):
    """Configuration sourced from the deprecated ``raiconfig.toml`` file.

    Exists for backwards compatibility: a ``before`` validator rewrites the
    raw TOML payload into the current config layout.
    """

    # Tag identifying where this config came from; excluded from dumps.
    source: Literal["raiconfig_toml"] = Field(default="raiconfig_toml", exclude=True)

    # NOTE: find_raiconfig_toml_file() runs once, when this class body executes.
    model_config = SettingsConfigDict(
        toml_file=find_raiconfig_toml_file(),
        extra="ignore",
        nested_model_default_partial_update=True,
    )

    @model_validator(mode='before')
    @classmethod
    def convert_raiconfig_structure(cls, data: Any):
        """Rewrite a deprecated raiconfig.toml payload into the new format.

        Dict input triggers a ``DeprecationWarning`` and is passed through
        ``convert_raiconfig_to_rai``; any other input is returned unchanged.
        """
        if isinstance(data, dict):
            import warnings

            warnings.warn(
                "raiconfig.toml is deprecated. Please migrate to raiconfig.yaml.",
                DeprecationWarning,
                stacklevel=2,
            )
            return convert_raiconfig_to_rai(data)

        return data
|
|
275
|
+
|
|
276
|
+
|
|
253
277
|
# =============================================================================
|
|
254
278
|
# Config Factory - Tries each source in order
|
|
255
279
|
# =============================================================================
|
|
@@ -258,11 +282,13 @@ def Config(**data) -> _Config:
|
|
|
258
282
|
"""
|
|
259
283
|
Create Config by trying multiple sources in priority order:
|
|
260
284
|
1. RAIConfig (raiconfig.yaml) - or direct data if provided
|
|
261
|
-
2.
|
|
262
|
-
3.
|
|
285
|
+
2. ConfigFromRAIConfigToml (raiconfig.toml - DEPRECATED) - only if no data provided
|
|
286
|
+
3. ConfigFromSnowflake (config.toml) - only if no data provided
|
|
287
|
+
4. ConfigFromDBT (profiles.yml) - only if no data provided
|
|
263
288
|
"""
|
|
264
289
|
sources = [
|
|
265
290
|
("RAIConfig (raiconfig.yaml)", RAIConfig, True), # passes **data
|
|
291
|
+
("ConfigFromRAIConfigToml (raiconfig.toml - DEPRECATED)", ConfigFromRAIConfigToml, False),
|
|
266
292
|
("ConfigFromSnowflake (config.toml)", ConfigFromSnowflake, False),
|
|
267
293
|
("ConfigFromDBT (profiles.yml)", ConfigFromDBT, False),
|
|
268
294
|
]
|
|
@@ -5,8 +5,9 @@ This module exports:
|
|
|
5
5
|
- BaseConnection: Base class for all connections
|
|
6
6
|
- Snowflake authenticators: UsernamePasswordAuth, UsernamePasswordMFAAuth, etc.
|
|
7
7
|
- DuckDBConnection: DuckDB connection
|
|
8
|
+
- LocalConnection: Local RAI server connection
|
|
8
9
|
- SnowflakeConnection: Discriminated union of Snowflake authenticators
|
|
9
|
-
- ConnectionConfig: Top-level discriminated union (Snowflake | DuckDB)
|
|
10
|
+
- ConnectionConfig: Top-level discriminated union (Snowflake | DuckDB | Local)
|
|
10
11
|
"""
|
|
11
12
|
|
|
12
13
|
from __future__ import annotations
|
|
@@ -26,9 +27,10 @@ from .snowflake import (
|
|
|
26
27
|
SnowflakeAuthenticator,
|
|
27
28
|
)
|
|
28
29
|
from .duckdb import DuckDBConnection
|
|
30
|
+
from .local import LocalConnection
|
|
29
31
|
|
|
30
32
|
ConnectionConfig = Annotated[
|
|
31
|
-
Union[SnowflakeAuthenticator, DuckDBConnection],
|
|
33
|
+
Union[SnowflakeAuthenticator, DuckDBConnection, LocalConnection],
|
|
32
34
|
Field(discriminator="type")
|
|
33
35
|
]
|
|
34
36
|
|
|
@@ -42,5 +44,6 @@ __all__ = [
|
|
|
42
44
|
"ProgrammaticAccessTokenAuth",
|
|
43
45
|
"SnowflakeConnection",
|
|
44
46
|
"DuckDBConnection",
|
|
47
|
+
"LocalConnection",
|
|
45
48
|
"ConnectionConfig",
|
|
46
49
|
]
|
|
@@ -4,7 +4,7 @@ from typing import Literal, Any, TYPE_CHECKING
|
|
|
4
4
|
from pydantic import Field
|
|
5
5
|
|
|
6
6
|
if TYPE_CHECKING:
|
|
7
|
-
import duckdb
|
|
7
|
+
import duckdb # type: ignore[import-not-found]
|
|
8
8
|
|
|
9
9
|
from .base import BaseConnection
|
|
10
10
|
|
|
@@ -17,7 +17,7 @@ class DuckDBConnection(BaseConnection):
|
|
|
17
17
|
config: dict[str, Any] | None = Field(None, description="Additional DuckDB configuration options")
|
|
18
18
|
|
|
19
19
|
def get_session(self) -> duckdb.DuckDBPyConnection:
|
|
20
|
-
import duckdb
|
|
20
|
+
import duckdb # type: ignore[import-not-found]
|
|
21
21
|
|
|
22
22
|
if self._cached_session is None:
|
|
23
23
|
self._cached_session = duckdb.connect(
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""Local server connection configuration."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Literal, TYPE_CHECKING
|
|
6
|
+
from pydantic import Field
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
import requests
|
|
10
|
+
|
|
11
|
+
from .base import BaseConnection
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class LocalConnection(BaseConnection):
    """Connection settings for a locally running RAI server."""

    # Discriminator value for this connection kind.
    type: Literal["local"] = "local"

    host: str = Field(default="localhost", description="Local server host")
    port: int = Field(default=8010, description="Local server port")
    engine: str | None = Field(default=None, description="Engine name")

    def get_session(self) -> requests.Session:
        """Return the HTTP session for the local server, creating it on first use.

        The session is taken from a v0 ``LocalClient`` built with this
        connection's host and port, then cached on ``_cached_session``.
        """
        # Open question carried over from the original code:
        # should LocalClient be migrated to PyRel V1?
        from v0.relationalai.clients.local import LocalClient  # type: ignore[import-not-found]

        session = self._cached_session
        if session is None:
            session = LocalClient(host=self.host, port=self.port).http_session
            self._cached_session = session

        return session
|
|
@@ -0,0 +1,235 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Converter for deprecated raiconfig.toml to RAI Config format.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from .raiconfig_models import RAIConfigFile, RAIConfigSnowflakeProfile, RAIConfigLocalProfile
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Central field mapping configuration
|
|
13
|
+
# Format: "old_field_name": ("new_section", "new_field_name")
|
|
14
|
+
# where section=None means top-level field
|
|
15
|
+
FIELD_MAPPINGS = {
    # Fields that remain at the top level under the same name.
    **{name: (None, name) for name in (
        "use_graph_index",
        "use_direct_access",
        "enable_otel_handler",
        "use_package_manager",
        "reuse_model",
        "skip_invalid_data",
    )},

    # Engine section (two of these are renamed).
    "engine": ("engine", "name"),
    "engine_size": ("engine", "size"),
    "auto_suspend_mins": ("engine", "auto_suspend_mins"),

    # Data section.
    **{name: ("data", name) for name in (
        "wait_for_stream_sync",
        "ensure_change_tracking",
        "data_freshness_mins",
        "query_timeout_mins",
        "download_url_type",
    )},

    # Debug section; the old flat "debug" flag becomes debug.enabled.
    "debug": ("debug", "enabled"),
    "show_full_traces": ("debug", "show_full_traces"),

    # Compiler section.
    **{name: ("compiler", name) for name in (
        "use_monotype_operators",
        "show_corerel_errors",
        "dry_run",
        "inspect_df",
        "use_value_types",
        "debug_hidden_keys",
        "wide_outputs",
        "strict",
        "use_inlined_intermediates",
        "inline_value_maps",
        "inline_entity_maps",
    )},

    # Model section.
    **{name: ("model", name) for name in (
        "keep",
        "isolated",
        "nowait_durable",
    )},
}

# Field names that belong to the connection dict rather than a config section.
CONNECTION_FIELDS = {
    "account",
    "warehouse",
    "user",
    "password",
    "role",
    "database",
    "schema",
    "authenticator",
    "rai_app_name",
    "passcode",
    "private_key_path",
    "private_key_passphrase",
    "token",
}

# Legacy connection field name -> current connection field name.
# The legacy names are read from model_extra in RAIConfigSnowflakeProfile.convert().
CONNECTION_FIELD_MAPPINGS = {
    "private_key_file": "private_key_path",
    "private_key_file_pwd": "private_key_passphrase",
    "token_file_path": "token",
}
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def map_fields_to_nested_structure(merged: dict[str, Any]) -> dict[str, Any]:
    """
    Re-shape a flat config dict into the nested layout described by FIELD_MAPPINGS.

    Fields that are absent from ``merged`` or set to None are dropped. Keys not
    listed in FIELD_MAPPINGS are ignored.

    Args:
        merged: Flat dictionary with all config fields

    Returns:
        Dictionary with top-level fields first, followed by one sub-dict per
        section that received at least one value
    """
    top_level = {}
    sections = {}

    for legacy_name, (section, target_name) in FIELD_MAPPINGS.items():
        value = merged.get(legacy_name)
        if value is None:
            continue

        # A None section means the field lives directly on the result.
        bucket = top_level if section is None else sections.setdefault(section, {})
        bucket[target_name] = value

    # Sections are appended after the top-level fields.
    top_level.update(sections)
    return top_level
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def merge_profile_with_toplevel(
    config_file: RAIConfigFile,
    profile: RAIConfigSnowflakeProfile | RAIConfigLocalProfile
) -> dict[str, Any]:
    """
    Combine file-level settings with one profile's settings.

    Both models are dumped with None values omitted; the file-level dump also
    leaves out ``profile`` and ``active_profile``. Where both define a key,
    the profile's value wins (the original notes this matches v0 behavior).
    """
    toplevel_fields = config_file.model_dump(
        exclude={"profile", "active_profile"},
        exclude_none=True,
    )
    profile_fields = profile.model_dump(exclude_none=True)

    # Later unpacking overrides earlier keys, so profile values take precedence.
    return {**toplevel_fields, **profile_fields}
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def convert_raiconfig_to_rai(
    rai_config: dict[str, Any],
    profile_name: str | None = None
) -> dict[str, Any]:
    """
    Convert deprecated raiconfig.toml to new RAI Config format.

    Flow:
    1. Validate with RAIConfigFile model
    2. Select active profile (parameter > active_profile > single profile)
    3. Validate the profile and extract its connection
    4. Merge profile over top-level config (profile takes precedence)
    5. Map flat fields to nested structure

    A profile with no ``platform`` key is treated as a Snowflake profile.

    Args:
        rai_config: Parsed TOML config dict
        profile_name: Profile name to use (overrides active_profile)

    Returns:
        Dict in new config format ready for RAIConfig model

    Raises:
        ValueError: If no profile can be selected, the selected profile does
            not exist, or it uses an unsupported platform.
    """
    config_file = RAIConfigFile(**rai_config)
    profiles = config_file.profile

    # Platform name -> model used to validate a profile of that platform.
    profile_models = {
        "local": RAIConfigLocalProfile,
        "snowflake": RAIConfigSnowflakeProfile,
    }

    # An explicit argument wins; otherwise use the file's active_profile.
    if not profile_name:
        profile_name = config_file.active_profile

    supported_platforms = {"snowflake", "local"}
    supported_profiles = {
        name: data for name, data in profiles.items()
        if data.get("platform", "snowflake") in supported_platforms
    }

    # With exactly one supported profile there is nothing to disambiguate.
    if not profile_name and len(supported_profiles) == 1:
        profile_name = next(iter(supported_profiles))

    if not profile_name:
        if not supported_profiles:
            raise ValueError(
                "No supported profiles found in raiconfig.toml. "
                "Supported platforms: snowflake, local. "
                "Please add a supported profile or migrate to raiconfig.yaml."
            )
        raise ValueError(
            f"No profile specified. Either set 'active_profile' in config, "
            f"provide profile_name parameter, or ensure only one supported profile exists. "
            f"Available profiles: {list(supported_profiles.keys())}"
        )

    if profile_name not in profiles:
        available = list(profiles.keys())
        raise ValueError(
            f"Profile '{profile_name}' not found. Available profiles: {available}"
        )

    # Work on a copy so the validated file model is left untouched, and write
    # the defaulted platform back into it: the profile models require an
    # explicit ``platform``, so a profile that omits it would otherwise be
    # selected as Snowflake here and then rejected during validation.
    profile_data = dict(profiles[profile_name])
    platform = profile_data.setdefault("platform", "snowflake")

    profile_model = profile_models.get(platform)
    if profile_model is None:
        raise ValueError(
            f"Profile '{profile_name}' uses platform '{platform}'. "
            f"Supported platforms: snowflake, local. "
            f"Please use a supported profile or migrate to raiconfig.yaml."
        )

    profile = profile_model(**profile_data)
    connection = profile.convert()
    connection_name = platform  # the connection is registered under its platform name

    # Profile values override file-level values.
    merged = merge_profile_with_toplevel(config_file, profile)

    result = {
        "connections": {connection_name: connection},
        "default_connection": connection_name,
    }

    # Flat legacy fields -> nested sections, driven by FIELD_MAPPINGS.
    result.update(map_fields_to_nested_structure(merged))

    # The reasoner section is already nested, so it is copied over as-is.
    if "reasoner" in merged and merged["reasoner"]:
        result["reasoner"] = merged["reasoner"]

    return result
|
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Pydantic models for deprecated raiconfig.toml structure.
|
|
3
|
+
|
|
4
|
+
These models validate the structure of deprecated raiconfig.toml files before
|
|
5
|
+
converting them to RAI Config format.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from typing import Any, Literal
|
|
11
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# Accepted authenticator spellings (lower-case) -> name used by the new config.
_AUTHENTICATOR_ALIASES = {
    "snowflake": "username_password",
    "username_password": "username_password",
    "username_password_mfa": "username_password_mfa",
    "externalbrowser": "externalbrowser",
    "jwt": "jwt",
    "snowflake_jwt": "jwt",
    "oauth": "oauth",
    "programmatic_access_token": "programmatic_access_token",
}


def normalize_authenticator(authenticator: str | None) -> str:
    """Normalize old authenticator names to new format.

    Matching ignores case and surrounding whitespace, so a padded value such
    as ``" jwt "`` resolves to ``"jwt"`` rather than falling through to the
    default.

    Args:
        authenticator: Authenticator name from the legacy config, or None.

    Returns:
        The normalized authenticator name. Missing, empty, or unrecognized
        values yield ``"username_password"``.
    """
    if not authenticator:
        return "username_password"

    return _AUTHENTICATOR_ALIASES.get(
        authenticator.strip().lower(), "username_password"
    )
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class RAIConfigSnowflakeProfile(BaseModel):
    """One Snowflake profile from a deprecated raiconfig.toml file."""

    # Unknown keys are kept (and exposed through model_extra).
    model_config = ConfigDict(extra="allow")

    # Required, with no default.
    platform: Literal["snowflake"]

    # --- Connection: required ---
    account: str
    warehouse: str

    # --- Connection: credentials (needed by some authenticators) ---
    user: str | None = None
    password: str | None = None

    # --- Connection: optional ---
    role: str | None = None
    database: str | None = None
    schema_: str | None = Field(default=None, alias="schema")
    authenticator: str | None = None
    rai_app_name: str = "RELATIONALAI"

    # --- Connection: authenticator-specific ---
    passcode: str | None = None  # MFA
    private_key_path: str | None = None  # JWT
    private_key_passphrase: str | None = None
    token: str | None = None  # OAuth/PAT

    # --- Per-profile overrides of engine/data settings ---
    engine: str | None = None
    engine_size: str | None = None
    auto_suspend_mins: int | None = None
    data_freshness_mins: int | None = None
    query_timeout_mins: int | None = None
    download_url_type: Literal["internal", "external"] | None = None
    wait_for_stream_sync: bool | None = None
    ensure_change_tracking: bool | None = None

    # --- Debug settings (profile or top-level) ---
    debug: bool | None = None
    show_full_traces: bool | None = None

    # --- Other top-level settings that a profile may carry ---
    use_graph_index: bool | None = None
    use_direct_access: bool | None = None
    skip_invalid_data: bool | None = None
    use_package_manager: bool | None = None
    enable_otel_handler: bool | None = None
    reuse_model: bool | None = None

    # --- Compiler settings (profile or top-level) ---
    use_monotype_operators: bool | None = None
    show_corerel_errors: bool | None = None
    dry_run: bool | None = None
    inspect_df: bool | None = None
    use_value_types: bool | None = None
    debug_hidden_keys: bool | None = None
    wide_outputs: bool | None = None
    strict: bool | None = None
    use_inlined_intermediates: bool | None = None
    inline_value_maps: bool | None = None
    inline_entity_maps: bool | None = None

    # --- Model settings (profile or top-level) ---
    keep: bool | None = None
    isolated: bool | None = None
    nowait_durable: bool | None = None

    def convert(self) -> dict[str, Any]:
        """Build the new-format connection dict for this profile.

        The dump leaves out ``platform`` and every key of FIELD_MAPPINGS,
        then ``type`` and a normalized ``authenticator`` are set. Legacy
        connection keys listed in CONNECTION_FIELD_MAPPINGS are copied to
        their new names unless the new name already has a value.
        """
        # Imported lazily to avoid a circular import with the converter module.
        from .raiconfig_converter import FIELD_MAPPINGS, CONNECTION_FIELD_MAPPINGS

        connection_dict = self.model_dump(
            exclude_none=True,
            by_alias=True,  # emit "schema" rather than "schema_"
            exclude=set(FIELD_MAPPINGS.keys()) | {"platform"},
        )
        connection_dict["type"] = "snowflake"
        connection_dict["authenticator"] = normalize_authenticator(self.authenticator)

        # Legacy names are not declared fields, so they only show up as extras.
        # NOTE(review): the legacy keys themselves are presumably still present
        # in connection_dict via the dump of extras - confirm that is intended.
        legacy_values = self.model_extra or {}
        for legacy_key, current_key in CONNECTION_FIELD_MAPPINGS.items():
            chosen = getattr(self, current_key, None)
            if chosen is None:
                chosen = legacy_values.get(legacy_key)
            if chosen is not None:
                connection_dict[current_key] = chosen

        return connection_dict
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class RAIConfigLocalProfile(BaseModel):
    """One local-server profile from a deprecated raiconfig.toml file."""

    # Unknown keys are kept rather than rejected.
    model_config = ConfigDict(extra="allow")

    # Required, with no default.
    platform: Literal["local"]

    # Connection settings
    host: str = "localhost"
    port: int = 8010
    engine: str | None = None

    def convert(self) -> dict[str, Any]:
        """Build the new-format connection dict for this profile.

        ``engine`` is always included, even when it is None.
        """
        return dict(
            type="local",
            host=self.host,
            port=self.port,
            engine=self.engine,
        )
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class RAIConfigFile(BaseModel):
    """Top-level shape of a deprecated raiconfig.toml file."""

    # Unknown top-level keys are kept rather than rejected.
    model_config = ConfigDict(extra="allow")

    # Name of the profile to use when none is requested explicitly.
    active_profile: str | None = None

    # --- General settings (a profile may override any of these) ---
    debug: bool = True
    show_full_traces: bool = False
    use_graph_index: bool = True
    use_direct_access: bool = False
    skip_invalid_data: bool | None = None
    use_package_manager: bool = True
    enable_otel_handler: bool = False
    reuse_model: bool = True

    # --- Engine defaults ---
    auto_suspend_mins: int | None = None
    engine_size: str | None = None

    # --- Data defaults ---
    wait_for_stream_sync: bool = True
    ensure_change_tracking: bool = False
    data_freshness_mins: int | None = None
    query_timeout_mins: int | None = None
    download_url_type: Literal["internal", "external"] | None = None

    # --- Compiler settings ---
    use_monotype_operators: bool | None = None
    show_corerel_errors: bool | None = None
    dry_run: bool | None = None
    inspect_df: bool | None = None
    use_value_types: bool | None = None
    debug_hidden_keys: bool | None = None
    wide_outputs: bool | None = None
    strict: bool | None = None
    use_inlined_intermediates: bool | None = None
    inline_value_maps: bool | None = None
    inline_entity_maps: bool | None = None

    # --- Model settings ---
    keep: bool | None = None
    isolated: bool | None = None
    nowait_durable: bool | None = None

    # --- Reasoner settings (nested [reasoner.rule] tables) ---
    reasoner: dict[str, Any] | None = None

    # Profiles are stored as raw dicts; each one is validated only when selected.
    profile: dict[str, dict[str, Any]] = Field(default_factory=dict)
|
|
@@ -17,3 +17,34 @@ def find_snowflake_config_file() -> str | None:
|
|
|
17
17
|
if os.path.exists(path):
|
|
18
18
|
return path
|
|
19
19
|
return None
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def find_raiconfig_toml_file() -> str | None:
    """
    Find raiconfig.toml file.

    Search order:
    1. Current directory upwards (project-level)
    2. ~/.rai/raiconfig.toml (user-level)

    Only regular files count as a match; a directory that happens to be named
    ``raiconfig.toml`` is skipped so it is never returned as a config file.

    Returns:
        Path to raiconfig.toml if found, None otherwise.
    """
    # Walk from the current working directory up to the filesystem root.
    directory = os.path.abspath(os.getcwd())
    while True:
        candidate = os.path.join(directory, "raiconfig.toml")
        if os.path.isfile(candidate):
            return candidate

        parent = os.path.dirname(directory)
        if parent == directory:  # dirname() stops changing once the root is reached
            break
        directory = parent

    # Fall back to the user-level location.
    user_config_path = os.path.expanduser("~/.rai/raiconfig.toml")
    if os.path.isfile(user_config_path):
        return user_config_path

    return None
|
relationalai/config/shims.py
CHANGED