relationalai 1.0.0a3__py3-none-any.whl → 1.0.0a5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. relationalai/config/config.py +47 -21
  2. relationalai/config/connections/__init__.py +5 -2
  3. relationalai/config/connections/duckdb.py +2 -2
  4. relationalai/config/connections/local.py +31 -0
  5. relationalai/config/connections/snowflake.py +0 -1
  6. relationalai/config/external/raiconfig_converter.py +235 -0
  7. relationalai/config/external/raiconfig_models.py +202 -0
  8. relationalai/config/external/utils.py +31 -0
  9. relationalai/config/shims.py +1 -0
  10. relationalai/semantics/__init__.py +10 -8
  11. relationalai/semantics/backends/sql/sql_compiler.py +1 -4
  12. relationalai/semantics/experimental/__init__.py +0 -0
  13. relationalai/semantics/experimental/builder.py +295 -0
  14. relationalai/semantics/experimental/builtins.py +154 -0
  15. relationalai/semantics/frontend/base.py +67 -42
  16. relationalai/semantics/frontend/core.py +34 -6
  17. relationalai/semantics/frontend/front_compiler.py +209 -37
  18. relationalai/semantics/frontend/pprint.py +6 -2
  19. relationalai/semantics/metamodel/__init__.py +7 -0
  20. relationalai/semantics/metamodel/metamodel.py +2 -0
  21. relationalai/semantics/metamodel/metamodel_analyzer.py +58 -16
  22. relationalai/semantics/metamodel/pprint.py +6 -1
  23. relationalai/semantics/metamodel/rewriter.py +11 -7
  24. relationalai/semantics/metamodel/typer.py +116 -41
  25. relationalai/semantics/reasoners/__init__.py +11 -0
  26. relationalai/semantics/reasoners/graph/__init__.py +35 -0
  27. relationalai/semantics/reasoners/graph/core.py +9028 -0
  28. relationalai/semantics/std/__init__.py +30 -10
  29. relationalai/semantics/std/aggregates.py +641 -12
  30. relationalai/semantics/std/common.py +146 -13
  31. relationalai/semantics/std/constraints.py +71 -1
  32. relationalai/semantics/std/datetime.py +904 -21
  33. relationalai/semantics/std/decimals.py +143 -2
  34. relationalai/semantics/std/floats.py +57 -4
  35. relationalai/semantics/std/integers.py +98 -4
  36. relationalai/semantics/std/math.py +857 -35
  37. relationalai/semantics/std/numbers.py +216 -20
  38. relationalai/semantics/std/re.py +213 -5
  39. relationalai/semantics/std/strings.py +437 -44
  40. relationalai/shims/executor.py +60 -52
  41. relationalai/shims/fixtures.py +85 -0
  42. relationalai/shims/helpers.py +26 -2
  43. relationalai/shims/hoister.py +28 -9
  44. relationalai/shims/mm2v0.py +204 -173
  45. relationalai/tools/cli/cli.py +192 -10
  46. relationalai/tools/cli/components/progress_reader.py +1 -1
  47. relationalai/tools/cli/docs.py +394 -0
  48. relationalai/tools/debugger.py +11 -4
  49. relationalai/tools/qb_debugger.py +435 -0
  50. relationalai/tools/typer_debugger.py +1 -2
  51. relationalai/util/dataclasses.py +3 -5
  52. relationalai/util/docutils.py +1 -2
  53. relationalai/util/error.py +2 -5
  54. relationalai/util/python.py +23 -0
  55. relationalai/util/runtime.py +1 -2
  56. relationalai/util/schema.py +2 -4
  57. relationalai/util/structures.py +4 -2
  58. relationalai/util/tracing.py +8 -2
  59. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/METADATA +8 -5
  60. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/RECORD +118 -95
  61. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/WHEEL +1 -1
  62. v0/relationalai/__init__.py +1 -1
  63. v0/relationalai/clients/client.py +52 -18
  64. v0/relationalai/clients/exec_txn_poller.py +122 -0
  65. v0/relationalai/clients/local.py +23 -8
  66. v0/relationalai/clients/resources/azure/azure.py +36 -11
  67. v0/relationalai/clients/resources/snowflake/__init__.py +4 -4
  68. v0/relationalai/clients/resources/snowflake/cli_resources.py +12 -1
  69. v0/relationalai/clients/resources/snowflake/direct_access_resources.py +124 -100
  70. v0/relationalai/clients/resources/snowflake/engine_service.py +381 -0
  71. v0/relationalai/clients/resources/snowflake/engine_state_handlers.py +35 -29
  72. v0/relationalai/clients/resources/snowflake/error_handlers.py +43 -2
  73. v0/relationalai/clients/resources/snowflake/snowflake.py +277 -179
  74. v0/relationalai/clients/resources/snowflake/use_index_poller.py +8 -0
  75. v0/relationalai/clients/types.py +5 -0
  76. v0/relationalai/errors.py +19 -1
  77. v0/relationalai/semantics/lqp/algorithms.py +173 -0
  78. v0/relationalai/semantics/lqp/builtins.py +199 -2
  79. v0/relationalai/semantics/lqp/executor.py +68 -37
  80. v0/relationalai/semantics/lqp/ir.py +28 -2
  81. v0/relationalai/semantics/lqp/model2lqp.py +215 -45
  82. v0/relationalai/semantics/lqp/passes.py +13 -658
  83. v0/relationalai/semantics/lqp/rewrite/__init__.py +12 -0
  84. v0/relationalai/semantics/lqp/rewrite/algorithm.py +385 -0
  85. v0/relationalai/semantics/lqp/rewrite/constants_to_vars.py +70 -0
  86. v0/relationalai/semantics/lqp/rewrite/deduplicate_vars.py +104 -0
  87. v0/relationalai/semantics/lqp/rewrite/eliminate_data.py +108 -0
  88. v0/relationalai/semantics/lqp/rewrite/extract_keys.py +25 -3
  89. v0/relationalai/semantics/lqp/rewrite/period_math.py +77 -0
  90. v0/relationalai/semantics/lqp/rewrite/quantify_vars.py +65 -31
  91. v0/relationalai/semantics/lqp/rewrite/unify_definitions.py +317 -0
  92. v0/relationalai/semantics/lqp/utils.py +11 -1
  93. v0/relationalai/semantics/lqp/validators.py +14 -1
  94. v0/relationalai/semantics/metamodel/builtins.py +2 -1
  95. v0/relationalai/semantics/metamodel/compiler.py +2 -1
  96. v0/relationalai/semantics/metamodel/dependency.py +12 -3
  97. v0/relationalai/semantics/metamodel/executor.py +11 -1
  98. v0/relationalai/semantics/metamodel/factory.py +2 -2
  99. v0/relationalai/semantics/metamodel/helpers.py +7 -0
  100. v0/relationalai/semantics/metamodel/ir.py +3 -2
  101. v0/relationalai/semantics/metamodel/rewrite/dnf_union_splitter.py +30 -20
  102. v0/relationalai/semantics/metamodel/rewrite/flatten.py +50 -13
  103. v0/relationalai/semantics/metamodel/rewrite/format_outputs.py +9 -3
  104. v0/relationalai/semantics/metamodel/typer/checker.py +6 -4
  105. v0/relationalai/semantics/metamodel/typer/typer.py +4 -3
  106. v0/relationalai/semantics/metamodel/visitor.py +4 -3
  107. v0/relationalai/semantics/reasoners/optimization/solvers_dev.py +1 -1
  108. v0/relationalai/semantics/reasoners/optimization/solvers_pb.py +336 -86
  109. v0/relationalai/semantics/rel/compiler.py +2 -1
  110. v0/relationalai/semantics/rel/executor.py +3 -2
  111. v0/relationalai/semantics/tests/lqp/__init__.py +0 -0
  112. v0/relationalai/semantics/tests/lqp/algorithms.py +345 -0
  113. v0/relationalai/tools/cli.py +339 -186
  114. v0/relationalai/tools/cli_controls.py +216 -67
  115. v0/relationalai/tools/cli_helpers.py +410 -6
  116. v0/relationalai/util/format.py +5 -2
  117. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/entry_points.txt +0 -0
  118. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/top_level.txt +0 -0
@@ -6,32 +6,28 @@ Main configuration class for PyRel (YAML-based).
6
6
  from __future__ import annotations
7
7
 
8
8
  from abc import ABC
9
- from typing import Any, TypeVar, overload, Literal, TYPE_CHECKING
9
+ from typing import Any, TypeVar, overload, Literal, TYPE_CHECKING, cast
10
10
  from pydantic import Field, model_validator
11
11
  from pydantic_settings import SettingsConfigDict
12
-
13
- try:
14
- from confocal import BaseConfig
15
- except ImportError:
16
- # Confocal not yet published to PyPI - use pydantic_settings as fallback
17
- # Config system not actively used yet, this is just to prevent import errors
18
- from pydantic_settings import BaseSettings as BaseConfig # type: ignore[misc, assignment]
12
+ from confocal import BaseConfig
19
13
 
20
14
  if TYPE_CHECKING:
21
15
  import snowflake.snowpark
22
- import duckdb
16
+ import duckdb # type: ignore[import-not-found]
17
+ import requests
23
18
 
24
- from .connections import ConnectionConfig, BaseConnection, SnowflakeConnection, DuckDBConnection
19
+ from .connections import ConnectionConfig, BaseConnection, SnowflakeConnection, DuckDBConnection, LocalConnection
25
20
  from .config_fields import EngineConfig, DataConfig, CompilerConfig, ModelConfig, ReasonerConfig, DebugConfig
26
21
  from .external.dbt_converter import convert_dbt_to_rai
27
22
  from .external.snowflake_converter import convert_snowflake_to_rai
28
- from .external.utils import find_dbt_profiles_file, find_snowflake_config_file
23
+ from .external.raiconfig_converter import convert_raiconfig_to_rai
24
+ from .external.utils import find_dbt_profiles_file, find_snowflake_config_file, find_raiconfig_toml_file
29
25
 
30
26
  # TypeVar for generic connection retrieval
31
27
  T = TypeVar('T', bound=BaseConnection)
32
28
 
33
29
 
34
- class _Config(BaseConfig, ABC): # type: ignore for now until we publish confocal
30
+ class _Config(BaseConfig, ABC):
35
31
  """Base configuration class with common fields and methods."""
36
32
 
37
33
  active_profile: str | None = Field(
@@ -171,6 +167,8 @@ class _Config(BaseConfig, ABC): # type: ignore for now until we publish confocal
171
167
  connection_class = "SnowflakeConnection"
172
168
  elif isinstance(connection, DuckDBConnection):
173
169
  connection_class = "DuckDBConnection"
170
+ elif isinstance(connection, LocalConnection):
171
+ connection_class = "LocalConnection"
174
172
  else:
175
173
  connection_class = actual_type
176
174
 
@@ -179,18 +177,15 @@ class _Config(BaseConfig, ABC): # type: ignore for now until we publish confocal
179
177
  f"Got: {connection_class} (actual: {actual_type})"
180
178
  )
181
179
 
182
- return connection
183
-
184
- @overload
185
- def get_session(self, connection_type: type[SnowflakeConnection]) -> snowflake.snowpark.Session: ...
180
+ return cast(T, connection)
186
181
 
187
182
  @overload
188
- def get_session(self, connection_type: type[DuckDBConnection]) -> duckdb.DuckDBPyConnection: ...
183
+ def get_session(self) -> snowflake.snowpark.Session | duckdb.DuckDBPyConnection | requests.Session: ...
189
184
 
190
185
  @overload
191
- def get_session(self, connection_type: None = None) -> snowflake.snowpark.Session | duckdb.DuckDBPyConnection: ...
186
+ def get_session(self, connection_type: type[SnowflakeConnection] | type[DuckDBConnection] | type[LocalConnection]) -> snowflake.snowpark.Session | duckdb.DuckDBPyConnection | requests.Session: ...
192
187
 
193
- def get_session(self, connection_type: type[SnowflakeConnection] | type[DuckDBConnection] | None = None) -> snowflake.snowpark.Session | duckdb.DuckDBPyConnection:
188
+ def get_session(self, connection_type: type[SnowflakeConnection] | type[DuckDBConnection] | type[LocalConnection] | None = None) -> snowflake.snowpark.Session | duckdb.DuckDBPyConnection | requests.Session:
194
189
  if connection_type is None:
195
190
  connection = self.get_default_connection()
196
191
  return connection.get_session()
@@ -250,6 +245,35 @@ class ConfigFromDBT(_Config):
250
245
  return convert_dbt_to_rai(data)
251
246
 
252
247
 
248
+ class ConfigFromRAIConfigToml(_Config):
249
+ """Config loaded from deprecated raiconfig.toml (for backwards compatibility)"""
250
+
251
+ source: Literal["raiconfig_toml"] = Field(default="raiconfig_toml", exclude=True)
252
+
253
+ model_config = SettingsConfigDict(
254
+ toml_file=find_raiconfig_toml_file(),
255
+ extra="ignore",
256
+ nested_model_default_partial_update=True,
257
+ )
258
+
259
+ @model_validator(mode='before')
260
+ @classmethod
261
+ def convert_raiconfig_structure(cls, data: Any):
262
+ """Convert deprecated raiconfig.toml format to new format."""
263
+ if not isinstance(data, dict):
264
+ return data
265
+
266
+ # Add deprecation warning
267
+ import warnings
268
+ warnings.warn(
269
+ "raiconfig.toml is deprecated. Please migrate to raiconfig.yaml.",
270
+ DeprecationWarning,
271
+ stacklevel=2
272
+ )
273
+
274
+ return convert_raiconfig_to_rai(data)
275
+
276
+
253
277
  # =============================================================================
254
278
  # Config Factory - Tries each source in order
255
279
  # =============================================================================
@@ -258,11 +282,13 @@ def Config(**data) -> _Config:
258
282
  """
259
283
  Create Config by trying multiple sources in priority order:
260
284
  1. RAIConfig (raiconfig.yaml) - or direct data if provided
261
- 2. ConfigFromSnowflake (config.toml) - only if no data provided
262
- 3. ConfigFromDBT (profiles.yml) - only if no data provided
285
+ 2. ConfigFromRAIConfigToml (raiconfig.toml - DEPRECATED) - only if no data provided
286
+ 3. ConfigFromSnowflake (config.toml) - only if no data provided
287
+ 4. ConfigFromDBT (profiles.yml) - only if no data provided
263
288
  """
264
289
  sources = [
265
290
  ("RAIConfig (raiconfig.yaml)", RAIConfig, True), # passes **data
291
+ ("ConfigFromRAIConfigToml (raiconfig.toml - DEPRECATED)", ConfigFromRAIConfigToml, False),
266
292
  ("ConfigFromSnowflake (config.toml)", ConfigFromSnowflake, False),
267
293
  ("ConfigFromDBT (profiles.yml)", ConfigFromDBT, False),
268
294
  ]
@@ -5,8 +5,9 @@ This module exports:
5
5
  - BaseConnection: Base class for all connections
6
6
  - Snowflake authenticators: UsernamePasswordAuth, UsernamePasswordMFAAuth, etc.
7
7
  - DuckDBConnection: DuckDB connection
8
+ - LocalConnection: Local RAI server connection
8
9
  - SnowflakeConnection: Discriminated union of Snowflake authenticators
9
- - ConnectionConfig: Top-level discriminated union (Snowflake | DuckDB)
10
+ - ConnectionConfig: Top-level discriminated union (Snowflake | DuckDB | Local)
10
11
  """
11
12
 
12
13
  from __future__ import annotations
@@ -26,9 +27,10 @@ from .snowflake import (
26
27
  SnowflakeAuthenticator,
27
28
  )
28
29
  from .duckdb import DuckDBConnection
30
+ from .local import LocalConnection
29
31
 
30
32
  ConnectionConfig = Annotated[
31
- Union[SnowflakeAuthenticator, DuckDBConnection],
33
+ Union[SnowflakeAuthenticator, DuckDBConnection, LocalConnection],
32
34
  Field(discriminator="type")
33
35
  ]
34
36
 
@@ -42,5 +44,6 @@ __all__ = [
42
44
  "ProgrammaticAccessTokenAuth",
43
45
  "SnowflakeConnection",
44
46
  "DuckDBConnection",
47
+ "LocalConnection",
45
48
  "ConnectionConfig",
46
49
  ]
@@ -4,7 +4,7 @@ from typing import Literal, Any, TYPE_CHECKING
4
4
  from pydantic import Field
5
5
 
6
6
  if TYPE_CHECKING:
7
- import duckdb
7
+ import duckdb # type: ignore[import-not-found]
8
8
 
9
9
  from .base import BaseConnection
10
10
 
@@ -17,7 +17,7 @@ class DuckDBConnection(BaseConnection):
17
17
  config: dict[str, Any] | None = Field(None, description="Additional DuckDB configuration options")
18
18
 
19
19
  def get_session(self) -> duckdb.DuckDBPyConnection:
20
- import duckdb
20
+ import duckdb # type: ignore[import-not-found]
21
21
 
22
22
  if self._cached_session is None:
23
23
  self._cached_session = duckdb.connect(
@@ -0,0 +1,31 @@
1
+ """Local server connection configuration."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Literal, TYPE_CHECKING
6
+ from pydantic import Field
7
+
8
+ if TYPE_CHECKING:
9
+ import requests
10
+
11
+ from .base import BaseConnection
12
+
13
+
14
+ class LocalConnection(BaseConnection):
15
+ """Connection to a local RAI server."""
16
+
17
+ type: Literal["local"] = "local"
18
+
19
+ host: str = Field(default="localhost", description="Local server host")
20
+ port: int = Field(default=8010, description="Local server port")
21
+ engine: str | None = Field(default=None, description="Engine name")
22
+
23
+ def get_session(self) -> requests.Session:
24
+ """Return a requests.Session instance for the local server."""
25
+ # Should we migrate to using LocalClient to PyRel V1?
26
+ from v0.relationalai.clients.local import LocalClient # type: ignore[import-not-found]
27
+
28
+ if self._cached_session is None:
29
+ self._cached_session = LocalClient(host=self.host, port=self.port).http_session
30
+
31
+ return self._cached_session
@@ -71,7 +71,6 @@ class UsernamePasswordAuth(SnowflakeConnectionBase):
71
71
  "password": cast(SecretStr, self.password).get_secret_value(),
72
72
  "account": self.account,
73
73
  "warehouse": self.warehouse,
74
- "authenticator": self.authenticator,
75
74
  }
76
75
 
77
76
 
@@ -0,0 +1,235 @@
1
+ """
2
+ Converter for deprecated raiconfig.toml to RAI Config format.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ from typing import Any
8
+
9
+ from .raiconfig_models import RAIConfigFile, RAIConfigSnowflakeProfile, RAIConfigLocalProfile
10
+
11
+
12
+ # Central field mapping configuration
13
+ # Format: "old_field_name": ("new_section", "new_field_name")
14
+ # where section=None means top-level field
15
+ FIELD_MAPPINGS = {
16
+ # Top-level fields (stay at top level)
17
+ "use_graph_index": (None, "use_graph_index"),
18
+ "use_direct_access": (None, "use_direct_access"),
19
+ "enable_otel_handler": (None, "enable_otel_handler"),
20
+ "use_package_manager": (None, "use_package_manager"),
21
+ "reuse_model": (None, "reuse_model"),
22
+ "skip_invalid_data": (None, "skip_invalid_data"),
23
+
24
+ # Engine fields
25
+ "engine": ("engine", "name"),
26
+ "engine_size": ("engine", "size"),
27
+ "auto_suspend_mins": ("engine", "auto_suspend_mins"),
28
+
29
+ # Data fields
30
+ "wait_for_stream_sync": ("data", "wait_for_stream_sync"),
31
+ "ensure_change_tracking": ("data", "ensure_change_tracking"),
32
+ "data_freshness_mins": ("data", "data_freshness_mins"),
33
+ "query_timeout_mins": ("data", "query_timeout_mins"),
34
+ "download_url_type": ("data", "download_url_type"),
35
+
36
+ # Debug fields
37
+ "debug": ("debug", "enabled"), # Note: maps to "enabled"!
38
+ "show_full_traces": ("debug", "show_full_traces"),
39
+
40
+ # Compiler fields
41
+ "use_monotype_operators": ("compiler", "use_monotype_operators"),
42
+ "show_corerel_errors": ("compiler", "show_corerel_errors"),
43
+ "dry_run": ("compiler", "dry_run"),
44
+ "inspect_df": ("compiler", "inspect_df"),
45
+ "use_value_types": ("compiler", "use_value_types"),
46
+ "debug_hidden_keys": ("compiler", "debug_hidden_keys"),
47
+ "wide_outputs": ("compiler", "wide_outputs"),
48
+ "strict": ("compiler", "strict"),
49
+ "use_inlined_intermediates": ("compiler", "use_inlined_intermediates"),
50
+ "inline_value_maps": ("compiler", "inline_value_maps"),
51
+ "inline_entity_maps": ("compiler", "inline_entity_maps"),
52
+
53
+ # Model fields
54
+ "keep": ("model", "keep"),
55
+ "isolated": ("model", "isolated"),
56
+ "nowait_durable": ("model", "nowait_durable"),
57
+ }
58
+
59
+ # Connection fields (stay in connection dict)
60
+ CONNECTION_FIELDS = {
61
+ "account", "warehouse", "user", "password", "role", "database", "schema",
62
+ "authenticator", "rai_app_name", "passcode",
63
+ "private_key_path", "private_key_passphrase", "token",
64
+ }
65
+
66
+ # Old -> New field name mappings for connection fields
67
+ # Old field names from deprecated raiconfig.toml are mapped to new names
68
+ # These are accessed via model_extra in RAIConfigSnowflakeProfile.convert()
69
+ CONNECTION_FIELD_MAPPINGS = {
70
+ "private_key_file": "private_key_path",
71
+ "private_key_file_pwd": "private_key_passphrase",
72
+ "token_file_path": "token",
73
+ }
74
+
75
+
76
+ def map_fields_to_nested_structure(merged: dict[str, Any]) -> dict[str, Any]:
77
+ """
78
+ Map flat config fields to nested structure using FIELD_MAPPINGS.
79
+
80
+ This replaces all the manual if statements!
81
+
82
+ Args:
83
+ merged: Flat dictionary with all config fields
84
+
85
+ Returns:
86
+ Nested dictionary structured according to new config format
87
+ """
88
+ result = {}
89
+ nested_sections = {} # Collect fields for each section
90
+
91
+ for old_field, (section, new_field) in FIELD_MAPPINGS.items():
92
+ if old_field not in merged or merged[old_field] is None:
93
+ continue # Skip missing or None values
94
+
95
+ value = merged[old_field]
96
+
97
+ if section is None:
98
+ # Top-level field
99
+ result[new_field] = value
100
+ else:
101
+ # Nested section field
102
+ if section not in nested_sections:
103
+ nested_sections[section] = {}
104
+ nested_sections[section][new_field] = value
105
+
106
+ # Add nested sections to result
107
+ for section, fields in nested_sections.items():
108
+ if fields: # Only add non-empty sections
109
+ result[section] = fields
110
+
111
+ return result
112
+
113
+
114
+ def merge_profile_with_toplevel(
115
+ config_file: RAIConfigFile,
116
+ profile: RAIConfigSnowflakeProfile | RAIConfigLocalProfile
117
+ ) -> dict[str, Any]:
118
+ """
119
+ Merge profile over top-level config (profile takes precedence).
120
+
121
+ This matches the v0 behavior where profile-specific settings override
122
+ global defaults.
123
+ """
124
+ # Start with top-level config
125
+ merged = config_file.model_dump(
126
+ exclude={"profile", "active_profile"},
127
+ exclude_none=True
128
+ )
129
+
130
+ # Overlay profile fields (they take precedence)
131
+ profile_dict = profile.model_dump(exclude_none=True)
132
+
133
+ # Merge profile into config (profile overrides top-level)
134
+ merged.update(profile_dict)
135
+
136
+ return merged
137
+
138
+
139
+ def convert_raiconfig_to_rai(
140
+ rai_config: dict[str, Any],
141
+ profile_name: str | None = None
142
+ ) -> dict[str, Any]:
143
+ """
144
+ Convert deprecated raiconfig.toml to new RAI Config format.
145
+
146
+ Flow:
147
+ 1. Validate with RAIConfigFile model
148
+ 2. Select active profile (parameter > active_profile > single profile)
149
+ 3. Merge profile over top-level config (profile takes precedence)
150
+ 4. Extract connection from profile
151
+ 5. Map flat fields to nested structure
152
+
153
+ Args:
154
+ rai_config: Parsed TOML config dict
155
+ profile_name: Profile name to use (overrides active_profile)
156
+
157
+ Returns:
158
+ Dict in new config format ready for RAIConfig model
159
+ """
160
+ config_file = RAIConfigFile(**rai_config)
161
+
162
+ # Profile selection logic
163
+ if not profile_name:
164
+ profile_name = config_file.active_profile
165
+
166
+ # Supported platforms
167
+ supported_platforms = {"snowflake", "local"}
168
+
169
+ # Filter to supported profiles (Snowflake and Local)
170
+ supported_profiles = {
171
+ name: data for name, data in config_file.profile.items()
172
+ if data.get("platform", "snowflake") in supported_platforms
173
+ }
174
+
175
+ if not profile_name and len(supported_profiles) == 1:
176
+ # Auto-select single supported profile
177
+ profile_name = next(iter(supported_profiles.keys()))
178
+
179
+ if not profile_name:
180
+ if not supported_profiles:
181
+ raise ValueError(
182
+ "No supported profiles found in raiconfig.toml. "
183
+ "Supported platforms: snowflake, local. "
184
+ "Please add a supported profile or migrate to raiconfig.yaml."
185
+ )
186
+ raise ValueError(
187
+ f"No profile specified. Either set 'active_profile' in config, "
188
+ f"provide profile_name parameter, or ensure only one supported profile exists. "
189
+ f"Available profiles: {list(supported_profiles.keys())}"
190
+ )
191
+
192
+ if profile_name not in config_file.profile:
193
+ available = list(config_file.profile.keys())
194
+ raise ValueError(
195
+ f"Profile '{profile_name}' not found. Available profiles: {available}"
196
+ )
197
+
198
+ # Get profile data
199
+ profile_data = config_file.profile[profile_name]
200
+ platform = profile_data.get("platform", "snowflake")
201
+
202
+ # Validate and convert based on platform
203
+ if platform == "local":
204
+ profile = RAIConfigLocalProfile(**profile_data)
205
+ connection = profile.convert()
206
+ connection_name = "local"
207
+ elif platform == "snowflake":
208
+ profile = RAIConfigSnowflakeProfile(**profile_data)
209
+ connection = profile.convert()
210
+ connection_name = "snowflake"
211
+ else:
212
+ raise ValueError(
213
+ f"Profile '{profile_name}' uses platform '{platform}'. "
214
+ f"Supported platforms: snowflake, local. "
215
+ f"Please use a supported profile or migrate to raiconfig.yaml."
216
+ )
217
+
218
+ # Merge profile with top-level config
219
+ merged = merge_profile_with_toplevel(config_file, profile)
220
+
221
+ # Build base result
222
+ result = {
223
+ "connections": {connection_name: connection},
224
+ "default_connection": connection_name,
225
+ }
226
+
227
+ # Map all config fields automatically using FIELD_MAPPINGS
228
+ mapped_fields = map_fields_to_nested_structure(merged)
229
+ result.update(mapped_fields)
230
+
231
+ # Handle reasoner separately (nested dict structure)
232
+ if "reasoner" in merged and merged["reasoner"]:
233
+ result["reasoner"] = merged["reasoner"]
234
+
235
+ return result
@@ -0,0 +1,202 @@
1
+ """
2
+ Pydantic models for deprecated raiconfig.toml structure.
3
+
4
+ These models validate the structure of deprecated raiconfig.toml files before
5
+ converting them to RAI Config format.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import Any, Literal
11
+ from pydantic import BaseModel, ConfigDict, Field
12
+
13
+
14
+ def normalize_authenticator(authenticator: str | None) -> str:
15
+ """Normalize old authenticator names to new format."""
16
+ if not authenticator:
17
+ return "username_password"
18
+
19
+ auth = authenticator.lower()
20
+
21
+ mapping = {
22
+ "snowflake": "username_password",
23
+ "username_password": "username_password",
24
+ "username_password_mfa": "username_password_mfa",
25
+ "externalbrowser": "externalbrowser",
26
+ "jwt": "jwt",
27
+ "snowflake_jwt": "jwt",
28
+ "oauth": "oauth",
29
+ "programmatic_access_token": "programmatic_access_token",
30
+ }
31
+
32
+ return mapping.get(auth, "username_password")
33
+
34
+
35
+ class RAIConfigSnowflakeProfile(BaseModel):
36
+ """Snowflake profile in deprecated raiconfig.toml."""
37
+ model_config = ConfigDict(extra="allow")
38
+
39
+ platform: Literal["snowflake"]
40
+
41
+ # Required connection fields
42
+ account: str
43
+ warehouse: str
44
+
45
+ # Auth fields (required by some authenticators)
46
+ user: str | None = None
47
+ password: str | None = None
48
+
49
+ # Optional connection fields
50
+ role: str | None = None
51
+ database: str | None = None
52
+ schema_: str | None = Field(default=None, alias="schema")
53
+ authenticator: str | None = None
54
+ rai_app_name: str = "RELATIONALAI"
55
+
56
+ # Auth-specific fields
57
+ passcode: str | None = None # MFA
58
+ private_key_path: str | None = None # JWT
59
+ private_key_passphrase: str | None = None
60
+ token: str | None = None # OAuth/PAT
61
+
62
+ # Optional override fields (can also be top-level)
63
+ engine: str | None = None
64
+ engine_size: str | None = None
65
+ auto_suspend_mins: int | None = None
66
+ data_freshness_mins: int | None = None
67
+ query_timeout_mins: int | None = None
68
+ download_url_type: Literal["internal", "external"] | None = None
69
+ wait_for_stream_sync: bool | None = None
70
+ ensure_change_tracking: bool | None = None
71
+
72
+ # Debug fields (can be in profile or top-level)
73
+ debug: bool | None = None
74
+ show_full_traces: bool | None = None
75
+
76
+ # Other top-level fields that can be in profiles
77
+ use_graph_index: bool | None = None
78
+ use_direct_access: bool | None = None
79
+ skip_invalid_data: bool | None = None
80
+ use_package_manager: bool | None = None
81
+ enable_otel_handler: bool | None = None
82
+ reuse_model: bool | None = None
83
+
84
+ # Compiler fields (can be in profile or top-level)
85
+ use_monotype_operators: bool | None = None
86
+ show_corerel_errors: bool | None = None
87
+ dry_run: bool | None = None
88
+ inspect_df: bool | None = None
89
+ use_value_types: bool | None = None
90
+ debug_hidden_keys: bool | None = None
91
+ wide_outputs: bool | None = None
92
+ strict: bool | None = None
93
+ use_inlined_intermediates: bool | None = None
94
+ inline_value_maps: bool | None = None
95
+ inline_entity_maps: bool | None = None
96
+
97
+ # Model fields (can be in profile or top-level)
98
+ keep: bool | None = None
99
+ isolated: bool | None = None
100
+ nowait_durable: bool | None = None
101
+
102
+ def convert(self) -> dict[str, Any]:
103
+ """Convert to new connection format."""
104
+ # Import here to avoid circular dependency
105
+ from .raiconfig_converter import FIELD_MAPPINGS, CONNECTION_FIELD_MAPPINGS
106
+
107
+ excluded_fields = set(FIELD_MAPPINGS.keys()) | {"platform"}
108
+
109
+ connection_dict = self.model_dump(
110
+ exclude_none=True,
111
+ by_alias=True,
112
+ exclude=excluded_fields
113
+ )
114
+
115
+ connection_dict["type"] = "snowflake"
116
+ connection_dict["authenticator"] = normalize_authenticator(self.authenticator)
117
+
118
+ # Map old field names to new field names (from model_extra, since old names aren't explicit fields)
119
+ extra = self.model_extra or {}
120
+ for old_name, new_name in CONNECTION_FIELD_MAPPINGS.items():
121
+ old_value = extra.get(old_name)
122
+ new_value = getattr(self, new_name, None)
123
+ # Use new name if set, otherwise fall back to old name
124
+ final_value = new_value if new_value is not None else old_value
125
+ if final_value is not None:
126
+ connection_dict[new_name] = final_value
127
+
128
+ return connection_dict
129
+
130
+
131
+ class RAIConfigLocalProfile(BaseModel):
132
+ """Local profile in deprecated raiconfig.toml."""
133
+ model_config = ConfigDict(extra="allow")
134
+
135
+ platform: Literal["local"]
136
+
137
+ # Connection fields
138
+ host: str = "localhost"
139
+ port: int = 8010
140
+ engine: str | None = None
141
+
142
+ def convert(self) -> dict[str, Any]:
143
+ """Convert to new connection format."""
144
+ return {
145
+ "type": "local",
146
+ "host": self.host,
147
+ "port": self.port,
148
+ "engine": self.engine,
149
+ }
150
+
151
+
152
+ class RAIConfigFile(BaseModel):
153
+ """Root model for deprecated raiconfig.toml."""
154
+ model_config = ConfigDict(extra="allow")
155
+
156
+ # Active profile selection
157
+ active_profile: str | None = None
158
+
159
+ # Top-level config fields (can be overridden by profile)
160
+ debug: bool = True
161
+ show_full_traces: bool = False
162
+ use_graph_index: bool = True
163
+ use_direct_access: bool = False
164
+ skip_invalid_data: bool | None = None
165
+ use_package_manager: bool = True
166
+ enable_otel_handler: bool = False
167
+ reuse_model: bool = True
168
+
169
+ # Engine config (top-level defaults)
170
+ auto_suspend_mins: int | None = None
171
+ engine_size: str | None = None
172
+
173
+ # Data config (top-level defaults)
174
+ wait_for_stream_sync: bool = True
175
+ ensure_change_tracking: bool = False
176
+ data_freshness_mins: int | None = None
177
+ query_timeout_mins: int | None = None
178
+ download_url_type: Literal["internal", "external"] | None = None
179
+
180
+ # Compiler config
181
+ use_monotype_operators: bool | None = None
182
+ show_corerel_errors: bool | None = None
183
+ dry_run: bool | None = None
184
+ inspect_df: bool | None = None
185
+ use_value_types: bool | None = None
186
+ debug_hidden_keys: bool | None = None
187
+ wide_outputs: bool | None = None
188
+ strict: bool | None = None
189
+ use_inlined_intermediates: bool | None = None
190
+ inline_value_maps: bool | None = None
191
+ inline_entity_maps: bool | None = None
192
+
193
+ # Model config
194
+ keep: bool | None = None
195
+ isolated: bool | None = None
196
+ nowait_durable: bool | None = None
197
+
198
+ # Reasoner config (nested [reasoner.rule] sections)
199
+ reasoner: dict[str, Any] | None = None
200
+
201
+ # Profiles (accept any structure, validate Snowflake profiles on-demand)
202
+ profile: dict[str, dict[str, Any]] = Field(default_factory=dict)
@@ -17,3 +17,34 @@ def find_snowflake_config_file() -> str | None:
17
17
  if os.path.exists(path):
18
18
  return path
19
19
  return None
20
+
21
+
22
+ def find_raiconfig_toml_file() -> str | None:
23
+ """
24
+ Find raiconfig.toml file.
25
+
26
+ Search order:
27
+ 1. Current directory upwards (project-level)
28
+ 2. ~/.rai/raiconfig.toml (user-level)
29
+
30
+ Returns:
31
+ Path to raiconfig.toml if found, None otherwise.
32
+ """
33
+ # Search upwards from current directory
34
+ current_dir = os.path.abspath(os.getcwd())
35
+ while True:
36
+ file_path = os.path.join(current_dir, "raiconfig.toml")
37
+ if os.path.exists(file_path):
38
+ return file_path
39
+
40
+ parent_dir = os.path.dirname(current_dir)
41
+ if parent_dir == current_dir: # Reached root
42
+ break
43
+ current_dir = parent_dir
44
+
45
+ # Check user home directory
46
+ user_config_path = os.path.expanduser("~/.rai/raiconfig.toml")
47
+ if os.path.exists(user_config_path):
48
+ return user_config_path
49
+
50
+ return None
@@ -1 +1,2 @@
1
1
  DRY_RUN=False
2
+ ENFORCE_TYPE_CORRECT=False