relationalai 0.12.13__py3-none-any.whl → 0.13.0.dev0__py3-none-any.whl
This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- relationalai/__init__.py +1 -209
- relationalai/config/__init__.py +56 -0
- relationalai/config/config.py +289 -0
- relationalai/config/config_fields.py +86 -0
- relationalai/config/connections/__init__.py +46 -0
- relationalai/config/connections/base.py +23 -0
- relationalai/config/connections/duckdb.py +29 -0
- relationalai/config/connections/snowflake.py +243 -0
- relationalai/config/external/__init__.py +17 -0
- relationalai/config/external/dbt_converter.py +101 -0
- relationalai/config/external/dbt_models.py +93 -0
- relationalai/config/external/snowflake_converter.py +41 -0
- relationalai/config/external/snowflake_models.py +85 -0
- relationalai/config/external/utils.py +19 -0
- relationalai/semantics/__init__.py +146 -22
- relationalai/semantics/backends/lqp/annotations.py +11 -0
- relationalai/semantics/backends/sql/sql_compiler.py +327 -0
- relationalai/semantics/frontend/base.py +1707 -0
- relationalai/semantics/frontend/core.py +179 -0
- relationalai/semantics/frontend/front_compiler.py +1313 -0
- relationalai/semantics/frontend/pprint.py +408 -0
- relationalai/semantics/metamodel/__init__.py +6 -40
- relationalai/semantics/metamodel/builtins.py +205 -769
- relationalai/semantics/metamodel/metamodel.py +437 -0
- relationalai/semantics/metamodel/metamodel_analyzer.py +519 -0
- relationalai/semantics/metamodel/pprint.py +412 -0
- relationalai/semantics/metamodel/rewriter.py +266 -0
- relationalai/semantics/metamodel/typer.py +1378 -0
- relationalai/semantics/std/__init__.py +60 -40
- relationalai/semantics/std/aggregates.py +149 -0
- relationalai/semantics/std/common.py +44 -0
- relationalai/semantics/std/constraints.py +37 -43
- relationalai/semantics/std/datetime.py +246 -135
- relationalai/semantics/std/decimals.py +45 -52
- relationalai/semantics/std/floats.py +13 -5
- relationalai/semantics/std/integers.py +26 -11
- relationalai/semantics/std/math.py +183 -112
- relationalai/semantics/std/numbers.py +86 -0
- relationalai/semantics/std/re.py +80 -62
- relationalai/semantics/std/strings.py +117 -60
- relationalai/shims/executor.py +147 -0
- relationalai/shims/helpers.py +126 -0
- relationalai/shims/hoister.py +221 -0
- relationalai/shims/mm2v0.py +1290 -0
- relationalai/tools/cli/__init__.py +6 -0
- relationalai/tools/cli/cli.py +90 -0
- relationalai/tools/cli/components/__init__.py +5 -0
- relationalai/tools/cli/components/progress_reader.py +1524 -0
- relationalai/tools/cli/components/utils.py +58 -0
- relationalai/tools/cli/config_template.py +45 -0
- relationalai/tools/cli/dev.py +19 -0
- relationalai/tools/debugger.py +289 -183
- relationalai/tools/typer_debugger.py +93 -0
- relationalai/util/dataclasses.py +43 -0
- relationalai/util/docutils.py +40 -0
- relationalai/util/error.py +199 -0
- relationalai/util/format.py +48 -106
- relationalai/util/naming.py +145 -0
- relationalai/util/python.py +35 -0
- relationalai/util/runtime.py +156 -0
- relationalai/util/schema.py +197 -0
- relationalai/util/source.py +185 -0
- relationalai/util/structures.py +163 -0
- relationalai/util/tracing.py +261 -0
- relationalai-0.13.0.dev0.dist-info/METADATA +46 -0
- relationalai-0.13.0.dev0.dist-info/RECORD +488 -0
- relationalai-0.13.0.dev0.dist-info/WHEEL +5 -0
- relationalai-0.13.0.dev0.dist-info/entry_points.txt +3 -0
- relationalai-0.13.0.dev0.dist-info/top_level.txt +2 -0
- v0/relationalai/__init__.py +216 -0
- v0/relationalai/clients/azure.py +477 -0
- v0/relationalai/clients/client.py +912 -0
- v0/relationalai/clients/config.py +673 -0
- v0/relationalai/clients/direct_access_client.py +118 -0
- v0/relationalai/clients/hash_util.py +31 -0
- v0/relationalai/clients/local.py +571 -0
- v0/relationalai/clients/profile_polling.py +73 -0
- v0/relationalai/clients/result_helpers.py +420 -0
- v0/relationalai/clients/snowflake.py +3869 -0
- v0/relationalai/clients/types.py +113 -0
- v0/relationalai/clients/use_index_poller.py +980 -0
- v0/relationalai/clients/util.py +356 -0
- v0/relationalai/debugging.py +389 -0
- v0/relationalai/dsl.py +1749 -0
- v0/relationalai/early_access/builder/__init__.py +30 -0
- v0/relationalai/early_access/builder/builder/__init__.py +35 -0
- v0/relationalai/early_access/builder/snowflake/__init__.py +12 -0
- v0/relationalai/early_access/builder/std/__init__.py +25 -0
- v0/relationalai/early_access/builder/std/decimals/__init__.py +12 -0
- v0/relationalai/early_access/builder/std/integers/__init__.py +12 -0
- v0/relationalai/early_access/builder/std/math/__init__.py +12 -0
- v0/relationalai/early_access/builder/std/strings/__init__.py +14 -0
- v0/relationalai/early_access/devtools/__init__.py +12 -0
- v0/relationalai/early_access/devtools/benchmark_lqp/__init__.py +12 -0
- v0/relationalai/early_access/devtools/extract_lqp/__init__.py +12 -0
- v0/relationalai/early_access/dsl/adapters/orm/adapter_qb.py +427 -0
- v0/relationalai/early_access/dsl/adapters/orm/parser.py +636 -0
- v0/relationalai/early_access/dsl/adapters/owl/adapter.py +176 -0
- v0/relationalai/early_access/dsl/adapters/owl/parser.py +160 -0
- v0/relationalai/early_access/dsl/bindings/common.py +402 -0
- v0/relationalai/early_access/dsl/bindings/csv.py +170 -0
- v0/relationalai/early_access/dsl/bindings/legacy/binding_models.py +143 -0
- v0/relationalai/early_access/dsl/bindings/snowflake.py +64 -0
- v0/relationalai/early_access/dsl/codegen/binder.py +411 -0
- v0/relationalai/early_access/dsl/codegen/common.py +79 -0
- v0/relationalai/early_access/dsl/codegen/helpers.py +23 -0
- v0/relationalai/early_access/dsl/codegen/relations.py +700 -0
- v0/relationalai/early_access/dsl/codegen/weaver.py +417 -0
- v0/relationalai/early_access/dsl/core/builders/__init__.py +47 -0
- v0/relationalai/early_access/dsl/core/builders/logic.py +19 -0
- v0/relationalai/early_access/dsl/core/builders/scalar_constraint.py +11 -0
- v0/relationalai/early_access/dsl/core/constraints/predicate/atomic.py +455 -0
- v0/relationalai/early_access/dsl/core/constraints/predicate/universal.py +73 -0
- v0/relationalai/early_access/dsl/core/constraints/scalar.py +310 -0
- v0/relationalai/early_access/dsl/core/context.py +13 -0
- v0/relationalai/early_access/dsl/core/cset.py +132 -0
- v0/relationalai/early_access/dsl/core/exprs/__init__.py +116 -0
- v0/relationalai/early_access/dsl/core/exprs/relational.py +18 -0
- v0/relationalai/early_access/dsl/core/exprs/scalar.py +412 -0
- v0/relationalai/early_access/dsl/core/instances.py +44 -0
- v0/relationalai/early_access/dsl/core/logic/__init__.py +193 -0
- v0/relationalai/early_access/dsl/core/logic/aggregation.py +98 -0
- v0/relationalai/early_access/dsl/core/logic/exists.py +223 -0
- v0/relationalai/early_access/dsl/core/logic/helper.py +163 -0
- v0/relationalai/early_access/dsl/core/namespaces.py +32 -0
- v0/relationalai/early_access/dsl/core/relations.py +276 -0
- v0/relationalai/early_access/dsl/core/rules.py +112 -0
- v0/relationalai/early_access/dsl/core/std/__init__.py +45 -0
- v0/relationalai/early_access/dsl/core/temporal/recall.py +6 -0
- v0/relationalai/early_access/dsl/core/types/__init__.py +270 -0
- v0/relationalai/early_access/dsl/core/types/concepts.py +128 -0
- v0/relationalai/early_access/dsl/core/types/constrained/__init__.py +267 -0
- v0/relationalai/early_access/dsl/core/types/constrained/nominal.py +143 -0
- v0/relationalai/early_access/dsl/core/types/constrained/subtype.py +124 -0
- v0/relationalai/early_access/dsl/core/types/standard.py +92 -0
- v0/relationalai/early_access/dsl/core/types/unconstrained.py +50 -0
- v0/relationalai/early_access/dsl/core/types/variables.py +203 -0
- v0/relationalai/early_access/dsl/ir/compiler.py +318 -0
- v0/relationalai/early_access/dsl/ir/executor.py +260 -0
- v0/relationalai/early_access/dsl/ontologies/constraints.py +88 -0
- v0/relationalai/early_access/dsl/ontologies/export.py +30 -0
- v0/relationalai/early_access/dsl/ontologies/models.py +453 -0
- v0/relationalai/early_access/dsl/ontologies/python_printer.py +303 -0
- v0/relationalai/early_access/dsl/ontologies/readings.py +60 -0
- v0/relationalai/early_access/dsl/ontologies/relationships.py +322 -0
- v0/relationalai/early_access/dsl/ontologies/roles.py +87 -0
- v0/relationalai/early_access/dsl/ontologies/subtyping.py +55 -0
- v0/relationalai/early_access/dsl/orm/constraints.py +438 -0
- v0/relationalai/early_access/dsl/orm/measures/dimensions.py +200 -0
- v0/relationalai/early_access/dsl/orm/measures/initializer.py +16 -0
- v0/relationalai/early_access/dsl/orm/measures/measure_rules.py +275 -0
- v0/relationalai/early_access/dsl/orm/measures/measures.py +299 -0
- v0/relationalai/early_access/dsl/orm/measures/role_exprs.py +268 -0
- v0/relationalai/early_access/dsl/orm/models.py +256 -0
- v0/relationalai/early_access/dsl/orm/object_oriented_printer.py +344 -0
- v0/relationalai/early_access/dsl/orm/printer.py +469 -0
- v0/relationalai/early_access/dsl/orm/reasoners.py +480 -0
- v0/relationalai/early_access/dsl/orm/relations.py +19 -0
- v0/relationalai/early_access/dsl/orm/relationships.py +251 -0
- v0/relationalai/early_access/dsl/orm/types.py +42 -0
- v0/relationalai/early_access/dsl/orm/utils.py +79 -0
- v0/relationalai/early_access/dsl/orm/verb.py +204 -0
- v0/relationalai/early_access/dsl/physical_metadata/tables.py +133 -0
- v0/relationalai/early_access/dsl/relations.py +170 -0
- v0/relationalai/early_access/dsl/rulesets.py +69 -0
- v0/relationalai/early_access/dsl/schemas/__init__.py +450 -0
- v0/relationalai/early_access/dsl/schemas/builder.py +48 -0
- v0/relationalai/early_access/dsl/schemas/comp_names.py +51 -0
- v0/relationalai/early_access/dsl/schemas/components.py +203 -0
- v0/relationalai/early_access/dsl/schemas/contexts.py +156 -0
- v0/relationalai/early_access/dsl/schemas/exprs.py +89 -0
- v0/relationalai/early_access/dsl/schemas/fragments.py +464 -0
- v0/relationalai/early_access/dsl/serialization.py +79 -0
- v0/relationalai/early_access/dsl/serialize/exporter.py +163 -0
- v0/relationalai/early_access/dsl/snow/api.py +104 -0
- v0/relationalai/early_access/dsl/snow/common.py +76 -0
- v0/relationalai/early_access/dsl/state_mgmt/__init__.py +129 -0
- v0/relationalai/early_access/dsl/state_mgmt/state_charts.py +125 -0
- v0/relationalai/early_access/dsl/state_mgmt/transitions.py +130 -0
- v0/relationalai/early_access/dsl/types/__init__.py +40 -0
- v0/relationalai/early_access/dsl/types/concepts.py +12 -0
- v0/relationalai/early_access/dsl/types/entities.py +135 -0
- v0/relationalai/early_access/dsl/types/values.py +17 -0
- v0/relationalai/early_access/dsl/utils.py +102 -0
- v0/relationalai/early_access/graphs/__init__.py +13 -0
- v0/relationalai/early_access/lqp/__init__.py +12 -0
- v0/relationalai/early_access/lqp/compiler/__init__.py +12 -0
- v0/relationalai/early_access/lqp/constructors/__init__.py +18 -0
- v0/relationalai/early_access/lqp/executor/__init__.py +12 -0
- v0/relationalai/early_access/lqp/ir/__init__.py +12 -0
- v0/relationalai/early_access/lqp/passes/__init__.py +12 -0
- v0/relationalai/early_access/lqp/pragmas/__init__.py +12 -0
- v0/relationalai/early_access/lqp/primitives/__init__.py +12 -0
- v0/relationalai/early_access/lqp/types/__init__.py +12 -0
- v0/relationalai/early_access/lqp/utils/__init__.py +12 -0
- v0/relationalai/early_access/lqp/validators/__init__.py +12 -0
- v0/relationalai/early_access/metamodel/__init__.py +58 -0
- v0/relationalai/early_access/metamodel/builtins/__init__.py +12 -0
- v0/relationalai/early_access/metamodel/compiler/__init__.py +12 -0
- v0/relationalai/early_access/metamodel/dependency/__init__.py +12 -0
- v0/relationalai/early_access/metamodel/factory/__init__.py +17 -0
- v0/relationalai/early_access/metamodel/helpers/__init__.py +12 -0
- v0/relationalai/early_access/metamodel/ir/__init__.py +14 -0
- v0/relationalai/early_access/metamodel/rewrite/__init__.py +7 -0
- v0/relationalai/early_access/metamodel/typer/__init__.py +3 -0
- v0/relationalai/early_access/metamodel/typer/typer/__init__.py +12 -0
- v0/relationalai/early_access/metamodel/types/__init__.py +15 -0
- v0/relationalai/early_access/metamodel/util/__init__.py +15 -0
- v0/relationalai/early_access/metamodel/visitor/__init__.py +12 -0
- v0/relationalai/early_access/rel/__init__.py +12 -0
- v0/relationalai/early_access/rel/executor/__init__.py +12 -0
- v0/relationalai/early_access/rel/rel_utils/__init__.py +12 -0
- v0/relationalai/early_access/rel/rewrite/__init__.py +7 -0
- v0/relationalai/early_access/solvers/__init__.py +19 -0
- v0/relationalai/early_access/sql/__init__.py +11 -0
- v0/relationalai/early_access/sql/executor/__init__.py +3 -0
- v0/relationalai/early_access/sql/rewrite/__init__.py +3 -0
- v0/relationalai/early_access/tests/logging/__init__.py +12 -0
- v0/relationalai/early_access/tests/test_snapshot_base/__init__.py +12 -0
- v0/relationalai/early_access/tests/utils/__init__.py +12 -0
- v0/relationalai/environments/__init__.py +35 -0
- v0/relationalai/environments/base.py +381 -0
- v0/relationalai/environments/colab.py +14 -0
- v0/relationalai/environments/generic.py +71 -0
- v0/relationalai/environments/ipython.py +68 -0
- v0/relationalai/environments/jupyter.py +9 -0
- v0/relationalai/environments/snowbook.py +169 -0
- v0/relationalai/errors.py +2455 -0
- v0/relationalai/experimental/SF.py +38 -0
- v0/relationalai/experimental/inspect.py +47 -0
- v0/relationalai/experimental/pathfinder/__init__.py +158 -0
- v0/relationalai/experimental/pathfinder/api.py +160 -0
- v0/relationalai/experimental/pathfinder/automaton.py +584 -0
- v0/relationalai/experimental/pathfinder/bridge.py +226 -0
- v0/relationalai/experimental/pathfinder/compiler.py +416 -0
- v0/relationalai/experimental/pathfinder/datalog.py +214 -0
- v0/relationalai/experimental/pathfinder/diagnostics.py +56 -0
- v0/relationalai/experimental/pathfinder/filter.py +236 -0
- v0/relationalai/experimental/pathfinder/glushkov.py +439 -0
- v0/relationalai/experimental/pathfinder/options.py +265 -0
- v0/relationalai/experimental/pathfinder/rpq.py +344 -0
- v0/relationalai/experimental/pathfinder/transition.py +200 -0
- v0/relationalai/experimental/pathfinder/utils.py +26 -0
- v0/relationalai/experimental/paths/api.py +143 -0
- v0/relationalai/experimental/paths/benchmarks/grid_graph.py +37 -0
- v0/relationalai/experimental/paths/examples/basic_example.py +40 -0
- v0/relationalai/experimental/paths/examples/minimal_engine_warmup.py +3 -0
- v0/relationalai/experimental/paths/examples/movie_example.py +77 -0
- v0/relationalai/experimental/paths/examples/paths_benchmark.py +115 -0
- v0/relationalai/experimental/paths/examples/paths_example.py +116 -0
- v0/relationalai/experimental/paths/examples/pattern_to_automaton.py +28 -0
- v0/relationalai/experimental/paths/find_paths_via_automaton.py +85 -0
- v0/relationalai/experimental/paths/graph.py +185 -0
- v0/relationalai/experimental/paths/path_algorithms/find_paths.py +280 -0
- v0/relationalai/experimental/paths/path_algorithms/one_sided_ball_repetition.py +26 -0
- v0/relationalai/experimental/paths/path_algorithms/one_sided_ball_upto.py +111 -0
- v0/relationalai/experimental/paths/path_algorithms/single.py +59 -0
- v0/relationalai/experimental/paths/path_algorithms/two_sided_balls_repetition.py +39 -0
- v0/relationalai/experimental/paths/path_algorithms/two_sided_balls_upto.py +103 -0
- v0/relationalai/experimental/paths/path_algorithms/usp-old.py +130 -0
- v0/relationalai/experimental/paths/path_algorithms/usp-tuple.py +183 -0
- v0/relationalai/experimental/paths/path_algorithms/usp.py +150 -0
- v0/relationalai/experimental/paths/product_graph.py +93 -0
- v0/relationalai/experimental/paths/rpq/automaton.py +584 -0
- v0/relationalai/experimental/paths/rpq/diagnostics.py +56 -0
- v0/relationalai/experimental/paths/rpq/rpq.py +378 -0
- v0/relationalai/experimental/paths/tests/tests_limit_sp_max_length.py +90 -0
- v0/relationalai/experimental/paths/tests/tests_limit_sp_multiple.py +119 -0
- v0/relationalai/experimental/paths/tests/tests_limit_sp_single.py +104 -0
- v0/relationalai/experimental/paths/tests/tests_limit_walks_multiple.py +113 -0
- v0/relationalai/experimental/paths/tests/tests_limit_walks_single.py +149 -0
- v0/relationalai/experimental/paths/tests/tests_one_sided_ball_repetition_multiple.py +70 -0
- v0/relationalai/experimental/paths/tests/tests_one_sided_ball_repetition_single.py +64 -0
- v0/relationalai/experimental/paths/tests/tests_one_sided_ball_upto_multiple.py +115 -0
- v0/relationalai/experimental/paths/tests/tests_one_sided_ball_upto_single.py +75 -0
- v0/relationalai/experimental/paths/tests/tests_single_paths.py +152 -0
- v0/relationalai/experimental/paths/tests/tests_single_walks.py +208 -0
- v0/relationalai/experimental/paths/tests/tests_single_walks_undirected.py +297 -0
- v0/relationalai/experimental/paths/tests/tests_two_sided_balls_repetition_multiple.py +107 -0
- v0/relationalai/experimental/paths/tests/tests_two_sided_balls_repetition_single.py +76 -0
- v0/relationalai/experimental/paths/tests/tests_two_sided_balls_upto_multiple.py +76 -0
- v0/relationalai/experimental/paths/tests/tests_two_sided_balls_upto_single.py +110 -0
- v0/relationalai/experimental/paths/tests/tests_usp_nsp_multiple.py +229 -0
- v0/relationalai/experimental/paths/tests/tests_usp_nsp_single.py +108 -0
- v0/relationalai/experimental/paths/tree_agg.py +168 -0
- v0/relationalai/experimental/paths/utilities/iterators.py +27 -0
- v0/relationalai/experimental/paths/utilities/prefix_sum.py +91 -0
- v0/relationalai/experimental/solvers.py +1087 -0
- v0/relationalai/loaders/__init__.py +0 -0
- v0/relationalai/loaders/csv.py +195 -0
- v0/relationalai/loaders/loader.py +177 -0
- v0/relationalai/loaders/types.py +23 -0
- v0/relationalai/rel_emitter.py +373 -0
- v0/relationalai/rel_utils.py +185 -0
- v0/relationalai/semantics/__init__.py +29 -0
- v0/relationalai/semantics/devtools/benchmark_lqp.py +536 -0
- v0/relationalai/semantics/devtools/compilation_manager.py +294 -0
- v0/relationalai/semantics/devtools/extract_lqp.py +110 -0
- v0/relationalai/semantics/internal/internal.py +3785 -0
- v0/relationalai/semantics/internal/snowflake.py +324 -0
- v0/relationalai/semantics/lqp/builtins.py +16 -0
- v0/relationalai/semantics/lqp/compiler.py +22 -0
- v0/relationalai/semantics/lqp/constructors.py +68 -0
- v0/relationalai/semantics/lqp/executor.py +469 -0
- v0/relationalai/semantics/lqp/intrinsics.py +24 -0
- v0/relationalai/semantics/lqp/model2lqp.py +839 -0
- v0/relationalai/semantics/lqp/passes.py +680 -0
- v0/relationalai/semantics/lqp/primitives.py +252 -0
- v0/relationalai/semantics/lqp/result_helpers.py +202 -0
- v0/relationalai/semantics/lqp/rewrite/annotate_constraints.py +57 -0
- v0/relationalai/semantics/lqp/rewrite/cdc.py +216 -0
- v0/relationalai/semantics/lqp/rewrite/extract_common.py +338 -0
- v0/relationalai/semantics/lqp/rewrite/extract_keys.py +449 -0
- v0/relationalai/semantics/lqp/rewrite/function_annotations.py +114 -0
- v0/relationalai/semantics/lqp/rewrite/functional_dependencies.py +314 -0
- v0/relationalai/semantics/lqp/rewrite/quantify_vars.py +296 -0
- v0/relationalai/semantics/lqp/rewrite/splinter.py +76 -0
- v0/relationalai/semantics/lqp/types.py +101 -0
- v0/relationalai/semantics/lqp/utils.py +160 -0
- v0/relationalai/semantics/lqp/validators.py +57 -0
- v0/relationalai/semantics/metamodel/__init__.py +40 -0
- v0/relationalai/semantics/metamodel/builtins.py +774 -0
- v0/relationalai/semantics/metamodel/compiler.py +133 -0
- v0/relationalai/semantics/metamodel/dependency.py +862 -0
- v0/relationalai/semantics/metamodel/executor.py +61 -0
- v0/relationalai/semantics/metamodel/factory.py +287 -0
- v0/relationalai/semantics/metamodel/helpers.py +361 -0
- v0/relationalai/semantics/metamodel/rewrite/discharge_constraints.py +39 -0
- v0/relationalai/semantics/metamodel/rewrite/dnf_union_splitter.py +210 -0
- v0/relationalai/semantics/metamodel/rewrite/extract_nested_logicals.py +78 -0
- v0/relationalai/semantics/metamodel/rewrite/flatten.py +549 -0
- v0/relationalai/semantics/metamodel/rewrite/format_outputs.py +165 -0
- v0/relationalai/semantics/metamodel/typer/checker.py +353 -0
- v0/relationalai/semantics/metamodel/typer/typer.py +1395 -0
- v0/relationalai/semantics/reasoners/__init__.py +10 -0
- v0/relationalai/semantics/reasoners/graph/__init__.py +37 -0
- v0/relationalai/semantics/reasoners/graph/core.py +9020 -0
- v0/relationalai/semantics/reasoners/optimization/__init__.py +68 -0
- v0/relationalai/semantics/reasoners/optimization/common.py +88 -0
- v0/relationalai/semantics/reasoners/optimization/solvers_dev.py +568 -0
- v0/relationalai/semantics/reasoners/optimization/solvers_pb.py +1163 -0
- v0/relationalai/semantics/rel/builtins.py +40 -0
- v0/relationalai/semantics/rel/compiler.py +989 -0
- v0/relationalai/semantics/rel/executor.py +359 -0
- v0/relationalai/semantics/rel/rel.py +482 -0
- v0/relationalai/semantics/rel/rel_utils.py +276 -0
- v0/relationalai/semantics/snowflake/__init__.py +3 -0
- v0/relationalai/semantics/sql/compiler.py +2503 -0
- v0/relationalai/semantics/sql/executor/duck_db.py +52 -0
- v0/relationalai/semantics/sql/executor/result_helpers.py +64 -0
- v0/relationalai/semantics/sql/executor/snowflake.py +145 -0
- v0/relationalai/semantics/sql/rewrite/denormalize.py +222 -0
- v0/relationalai/semantics/sql/rewrite/double_negation.py +49 -0
- v0/relationalai/semantics/sql/rewrite/recursive_union.py +127 -0
- v0/relationalai/semantics/sql/rewrite/sort_output_query.py +246 -0
- v0/relationalai/semantics/sql/sql.py +504 -0
- v0/relationalai/semantics/std/__init__.py +54 -0
- v0/relationalai/semantics/std/constraints.py +43 -0
- v0/relationalai/semantics/std/datetime.py +363 -0
- v0/relationalai/semantics/std/decimals.py +62 -0
- v0/relationalai/semantics/std/floats.py +7 -0
- v0/relationalai/semantics/std/integers.py +22 -0
- v0/relationalai/semantics/std/math.py +141 -0
- v0/relationalai/semantics/std/pragmas.py +11 -0
- v0/relationalai/semantics/std/re.py +83 -0
- v0/relationalai/semantics/std/std.py +14 -0
- v0/relationalai/semantics/std/strings.py +63 -0
- v0/relationalai/semantics/tests/__init__.py +0 -0
- v0/relationalai/semantics/tests/test_snapshot_abstract.py +143 -0
- v0/relationalai/semantics/tests/test_snapshot_base.py +9 -0
- v0/relationalai/semantics/tests/utils.py +46 -0
- v0/relationalai/std/__init__.py +70 -0
- v0/relationalai/tools/__init__.py +0 -0
- v0/relationalai/tools/cli.py +1940 -0
- v0/relationalai/tools/cli_controls.py +1826 -0
- v0/relationalai/tools/cli_helpers.py +390 -0
- v0/relationalai/tools/debugger.py +183 -0
- v0/relationalai/tools/debugger_client.py +109 -0
- v0/relationalai/tools/debugger_server.py +302 -0
- v0/relationalai/tools/dev.py +685 -0
- v0/relationalai/tools/qb_debugger.py +425 -0
- v0/relationalai/util/clean_up_databases.py +95 -0
- v0/relationalai/util/format.py +123 -0
- v0/relationalai/util/list_databases.py +9 -0
- v0/relationalai/util/otel_configuration.py +25 -0
- v0/relationalai/util/otel_handler.py +484 -0
- v0/relationalai/util/snowflake_handler.py +88 -0
- v0/relationalai/util/span_format_test.py +43 -0
- v0/relationalai/util/span_tracker.py +207 -0
- v0/relationalai/util/spans_file_handler.py +72 -0
- v0/relationalai/util/tracing_handler.py +34 -0
- frontend/debugger/dist/.gitignore +0 -2
- frontend/debugger/dist/assets/favicon-Dy0ZgA6N.png +0 -0
- frontend/debugger/dist/assets/index-Cssla-O7.js +0 -208
- frontend/debugger/dist/assets/index-DlHsYx1V.css +0 -9
- frontend/debugger/dist/index.html +0 -17
- relationalai/clients/azure.py +0 -477
- relationalai/clients/client.py +0 -912
- relationalai/clients/config.py +0 -673
- relationalai/clients/direct_access_client.py +0 -118
- relationalai/clients/export_procedure.py.jinja +0 -249
- relationalai/clients/hash_util.py +0 -31
- relationalai/clients/local.py +0 -571
- relationalai/clients/profile_polling.py +0 -73
- relationalai/clients/result_helpers.py +0 -420
- relationalai/clients/snowflake.py +0 -3869
- relationalai/clients/types.py +0 -113
- relationalai/clients/use_index_poller.py +0 -980
- relationalai/clients/util.py +0 -356
- relationalai/debugging.py +0 -389
- relationalai/dsl.py +0 -1749
- relationalai/early_access/builder/__init__.py +0 -30
- relationalai/early_access/builder/builder/__init__.py +0 -35
- relationalai/early_access/builder/snowflake/__init__.py +0 -12
- relationalai/early_access/builder/std/__init__.py +0 -25
- relationalai/early_access/builder/std/decimals/__init__.py +0 -12
- relationalai/early_access/builder/std/integers/__init__.py +0 -12
- relationalai/early_access/builder/std/math/__init__.py +0 -12
- relationalai/early_access/builder/std/strings/__init__.py +0 -14
- relationalai/early_access/devtools/__init__.py +0 -12
- relationalai/early_access/devtools/benchmark_lqp/__init__.py +0 -12
- relationalai/early_access/devtools/extract_lqp/__init__.py +0 -12
- relationalai/early_access/dsl/adapters/orm/adapter_qb.py +0 -427
- relationalai/early_access/dsl/adapters/orm/parser.py +0 -636
- relationalai/early_access/dsl/adapters/owl/adapter.py +0 -176
- relationalai/early_access/dsl/adapters/owl/parser.py +0 -160
- relationalai/early_access/dsl/bindings/common.py +0 -402
- relationalai/early_access/dsl/bindings/csv.py +0 -170
- relationalai/early_access/dsl/bindings/legacy/binding_models.py +0 -143
- relationalai/early_access/dsl/bindings/snowflake.py +0 -64
- relationalai/early_access/dsl/codegen/binder.py +0 -411
- relationalai/early_access/dsl/codegen/common.py +0 -79
- relationalai/early_access/dsl/codegen/helpers.py +0 -23
- relationalai/early_access/dsl/codegen/relations.py +0 -700
- relationalai/early_access/dsl/codegen/weaver.py +0 -417
- relationalai/early_access/dsl/core/builders/__init__.py +0 -47
- relationalai/early_access/dsl/core/builders/logic.py +0 -19
- relationalai/early_access/dsl/core/builders/scalar_constraint.py +0 -11
- relationalai/early_access/dsl/core/constraints/predicate/atomic.py +0 -455
- relationalai/early_access/dsl/core/constraints/predicate/universal.py +0 -73
- relationalai/early_access/dsl/core/constraints/scalar.py +0 -310
- relationalai/early_access/dsl/core/context.py +0 -13
- relationalai/early_access/dsl/core/cset.py +0 -132
- relationalai/early_access/dsl/core/exprs/__init__.py +0 -116
- relationalai/early_access/dsl/core/exprs/relational.py +0 -18
- relationalai/early_access/dsl/core/exprs/scalar.py +0 -412
- relationalai/early_access/dsl/core/instances.py +0 -44
- relationalai/early_access/dsl/core/logic/__init__.py +0 -193
- relationalai/early_access/dsl/core/logic/aggregation.py +0 -98
- relationalai/early_access/dsl/core/logic/exists.py +0 -223
- relationalai/early_access/dsl/core/logic/helper.py +0 -163
- relationalai/early_access/dsl/core/namespaces.py +0 -32
- relationalai/early_access/dsl/core/relations.py +0 -276
- relationalai/early_access/dsl/core/rules.py +0 -112
- relationalai/early_access/dsl/core/std/__init__.py +0 -45
- relationalai/early_access/dsl/core/temporal/recall.py +0 -6
- relationalai/early_access/dsl/core/types/__init__.py +0 -270
- relationalai/early_access/dsl/core/types/concepts.py +0 -128
- relationalai/early_access/dsl/core/types/constrained/__init__.py +0 -267
- relationalai/early_access/dsl/core/types/constrained/nominal.py +0 -143
- relationalai/early_access/dsl/core/types/constrained/subtype.py +0 -124
- relationalai/early_access/dsl/core/types/standard.py +0 -92
- relationalai/early_access/dsl/core/types/unconstrained.py +0 -50
- relationalai/early_access/dsl/core/types/variables.py +0 -203
- relationalai/early_access/dsl/ir/compiler.py +0 -318
- relationalai/early_access/dsl/ir/executor.py +0 -260
- relationalai/early_access/dsl/ontologies/constraints.py +0 -88
- relationalai/early_access/dsl/ontologies/export.py +0 -30
- relationalai/early_access/dsl/ontologies/models.py +0 -453
- relationalai/early_access/dsl/ontologies/python_printer.py +0 -303
- relationalai/early_access/dsl/ontologies/readings.py +0 -60
- relationalai/early_access/dsl/ontologies/relationships.py +0 -322
- relationalai/early_access/dsl/ontologies/roles.py +0 -87
- relationalai/early_access/dsl/ontologies/subtyping.py +0 -55
- relationalai/early_access/dsl/orm/constraints.py +0 -438
- relationalai/early_access/dsl/orm/measures/dimensions.py +0 -200
- relationalai/early_access/dsl/orm/measures/initializer.py +0 -16
- relationalai/early_access/dsl/orm/measures/measure_rules.py +0 -275
- relationalai/early_access/dsl/orm/measures/measures.py +0 -299
- relationalai/early_access/dsl/orm/measures/role_exprs.py +0 -268
- relationalai/early_access/dsl/orm/models.py +0 -256
- relationalai/early_access/dsl/orm/object_oriented_printer.py +0 -344
- relationalai/early_access/dsl/orm/printer.py +0 -469
- relationalai/early_access/dsl/orm/reasoners.py +0 -480
- relationalai/early_access/dsl/orm/relations.py +0 -19
- relationalai/early_access/dsl/orm/relationships.py +0 -251
- relationalai/early_access/dsl/orm/types.py +0 -42
- relationalai/early_access/dsl/orm/utils.py +0 -79
- relationalai/early_access/dsl/orm/verb.py +0 -204
- relationalai/early_access/dsl/physical_metadata/tables.py +0 -133
- relationalai/early_access/dsl/relations.py +0 -170
- relationalai/early_access/dsl/rulesets.py +0 -69
- relationalai/early_access/dsl/schemas/__init__.py +0 -450
- relationalai/early_access/dsl/schemas/builder.py +0 -48
- relationalai/early_access/dsl/schemas/comp_names.py +0 -51
- relationalai/early_access/dsl/schemas/components.py +0 -203
- relationalai/early_access/dsl/schemas/contexts.py +0 -156
- relationalai/early_access/dsl/schemas/exprs.py +0 -89
- relationalai/early_access/dsl/schemas/fragments.py +0 -464
- relationalai/early_access/dsl/serialization.py +0 -79
- relationalai/early_access/dsl/serialize/exporter.py +0 -163
- relationalai/early_access/dsl/snow/api.py +0 -104
- relationalai/early_access/dsl/snow/common.py +0 -76
- relationalai/early_access/dsl/state_mgmt/__init__.py +0 -129
- relationalai/early_access/dsl/state_mgmt/state_charts.py +0 -125
- relationalai/early_access/dsl/state_mgmt/transitions.py +0 -130
- relationalai/early_access/dsl/types/__init__.py +0 -40
- relationalai/early_access/dsl/types/concepts.py +0 -12
- relationalai/early_access/dsl/types/entities.py +0 -135
- relationalai/early_access/dsl/types/values.py +0 -17
- relationalai/early_access/dsl/utils.py +0 -102
- relationalai/early_access/graphs/__init__.py +0 -13
- relationalai/early_access/lqp/__init__.py +0 -12
- relationalai/early_access/lqp/compiler/__init__.py +0 -12
- relationalai/early_access/lqp/constructors/__init__.py +0 -18
- relationalai/early_access/lqp/executor/__init__.py +0 -12
- relationalai/early_access/lqp/ir/__init__.py +0 -12
- relationalai/early_access/lqp/passes/__init__.py +0 -12
- relationalai/early_access/lqp/pragmas/__init__.py +0 -12
- relationalai/early_access/lqp/primitives/__init__.py +0 -12
- relationalai/early_access/lqp/types/__init__.py +0 -12
- relationalai/early_access/lqp/utils/__init__.py +0 -12
- relationalai/early_access/lqp/validators/__init__.py +0 -12
- relationalai/early_access/metamodel/__init__.py +0 -58
- relationalai/early_access/metamodel/builtins/__init__.py +0 -12
- relationalai/early_access/metamodel/compiler/__init__.py +0 -12
- relationalai/early_access/metamodel/dependency/__init__.py +0 -12
- relationalai/early_access/metamodel/factory/__init__.py +0 -17
- relationalai/early_access/metamodel/helpers/__init__.py +0 -12
- relationalai/early_access/metamodel/ir/__init__.py +0 -14
- relationalai/early_access/metamodel/rewrite/__init__.py +0 -7
- relationalai/early_access/metamodel/typer/__init__.py +0 -3
- relationalai/early_access/metamodel/typer/typer/__init__.py +0 -12
- relationalai/early_access/metamodel/types/__init__.py +0 -15
- relationalai/early_access/metamodel/util/__init__.py +0 -15
- relationalai/early_access/metamodel/visitor/__init__.py +0 -12
- relationalai/early_access/rel/__init__.py +0 -12
- relationalai/early_access/rel/executor/__init__.py +0 -12
- relationalai/early_access/rel/rel_utils/__init__.py +0 -12
- relationalai/early_access/rel/rewrite/__init__.py +0 -7
- relationalai/early_access/solvers/__init__.py +0 -19
- relationalai/early_access/sql/__init__.py +0 -11
- relationalai/early_access/sql/executor/__init__.py +0 -3
- relationalai/early_access/sql/rewrite/__init__.py +0 -3
- relationalai/early_access/tests/logging/__init__.py +0 -12
- relationalai/early_access/tests/test_snapshot_base/__init__.py +0 -12
- relationalai/early_access/tests/utils/__init__.py +0 -12
- relationalai/environments/__init__.py +0 -35
- relationalai/environments/base.py +0 -381
- relationalai/environments/colab.py +0 -14
- relationalai/environments/generic.py +0 -71
- relationalai/environments/ipython.py +0 -68
- relationalai/environments/jupyter.py +0 -9
- relationalai/environments/snowbook.py +0 -169
- relationalai/errors.py +0 -2455
- relationalai/experimental/SF.py +0 -38
- relationalai/experimental/inspect.py +0 -47
- relationalai/experimental/pathfinder/__init__.py +0 -158
- relationalai/experimental/pathfinder/api.py +0 -160
- relationalai/experimental/pathfinder/automaton.py +0 -584
- relationalai/experimental/pathfinder/bridge.py +0 -226
- relationalai/experimental/pathfinder/compiler.py +0 -416
- relationalai/experimental/pathfinder/datalog.py +0 -214
- relationalai/experimental/pathfinder/diagnostics.py +0 -56
- relationalai/experimental/pathfinder/filter.py +0 -236
- relationalai/experimental/pathfinder/glushkov.py +0 -439
- relationalai/experimental/pathfinder/options.py +0 -265
- relationalai/experimental/pathfinder/pathfinder-v0.7.0.rel +0 -1951
- relationalai/experimental/pathfinder/rpq.py +0 -344
- relationalai/experimental/pathfinder/transition.py +0 -200
- relationalai/experimental/pathfinder/utils.py +0 -26
- relationalai/experimental/paths/README.md +0 -107
- relationalai/experimental/paths/api.py +0 -143
- relationalai/experimental/paths/benchmarks/grid_graph.py +0 -37
- relationalai/experimental/paths/code_organization.md +0 -2
- relationalai/experimental/paths/examples/Movies.ipynb +0 -16328
- relationalai/experimental/paths/examples/basic_example.py +0 -40
- relationalai/experimental/paths/examples/minimal_engine_warmup.py +0 -3
- relationalai/experimental/paths/examples/movie_example.py +0 -77
- relationalai/experimental/paths/examples/movies_data/actedin.csv +0 -193
- relationalai/experimental/paths/examples/movies_data/directed.csv +0 -45
- relationalai/experimental/paths/examples/movies_data/follows.csv +0 -7
- relationalai/experimental/paths/examples/movies_data/movies.csv +0 -39
- relationalai/experimental/paths/examples/movies_data/person.csv +0 -134
- relationalai/experimental/paths/examples/movies_data/produced.csv +0 -16
- relationalai/experimental/paths/examples/movies_data/ratings.csv +0 -10
- relationalai/experimental/paths/examples/movies_data/wrote.csv +0 -11
- relationalai/experimental/paths/examples/paths_benchmark.py +0 -115
- relationalai/experimental/paths/examples/paths_example.py +0 -116
- relationalai/experimental/paths/examples/pattern_to_automaton.py +0 -28
- relationalai/experimental/paths/find_paths_via_automaton.py +0 -85
- relationalai/experimental/paths/graph.py +0 -185
- relationalai/experimental/paths/path_algorithms/find_paths.py +0 -280
- relationalai/experimental/paths/path_algorithms/one_sided_ball_repetition.py +0 -26
- relationalai/experimental/paths/path_algorithms/one_sided_ball_upto.py +0 -111
- relationalai/experimental/paths/path_algorithms/single.py +0 -59
- relationalai/experimental/paths/path_algorithms/two_sided_balls_repetition.py +0 -39
- relationalai/experimental/paths/path_algorithms/two_sided_balls_upto.py +0 -103
- relationalai/experimental/paths/path_algorithms/usp-old.py +0 -130
- relationalai/experimental/paths/path_algorithms/usp-tuple.py +0 -183
- relationalai/experimental/paths/path_algorithms/usp.py +0 -150
- relationalai/experimental/paths/product_graph.py +0 -93
- relationalai/experimental/paths/rpq/automaton.py +0 -584
- relationalai/experimental/paths/rpq/diagnostics.py +0 -56
- relationalai/experimental/paths/rpq/rpq.py +0 -378
- relationalai/experimental/paths/tests/tests_limit_sp_max_length.py +0 -90
- relationalai/experimental/paths/tests/tests_limit_sp_multiple.py +0 -119
- relationalai/experimental/paths/tests/tests_limit_sp_single.py +0 -104
- relationalai/experimental/paths/tests/tests_limit_walks_multiple.py +0 -113
- relationalai/experimental/paths/tests/tests_limit_walks_single.py +0 -149
- relationalai/experimental/paths/tests/tests_one_sided_ball_repetition_multiple.py +0 -70
- relationalai/experimental/paths/tests/tests_one_sided_ball_repetition_single.py +0 -64
- relationalai/experimental/paths/tests/tests_one_sided_ball_upto_multiple.py +0 -115
- relationalai/experimental/paths/tests/tests_one_sided_ball_upto_single.py +0 -75
- relationalai/experimental/paths/tests/tests_single_paths.py +0 -152
- relationalai/experimental/paths/tests/tests_single_walks.py +0 -208
- relationalai/experimental/paths/tests/tests_single_walks_undirected.py +0 -297
- relationalai/experimental/paths/tests/tests_two_sided_balls_repetition_multiple.py +0 -107
- relationalai/experimental/paths/tests/tests_two_sided_balls_repetition_single.py +0 -76
- relationalai/experimental/paths/tests/tests_two_sided_balls_upto_multiple.py +0 -76
- relationalai/experimental/paths/tests/tests_two_sided_balls_upto_single.py +0 -110
- relationalai/experimental/paths/tests/tests_usp_nsp_multiple.py +0 -229
- relationalai/experimental/paths/tests/tests_usp_nsp_single.py +0 -108
- relationalai/experimental/paths/tree_agg.py +0 -168
- relationalai/experimental/paths/utilities/iterators.py +0 -27
- relationalai/experimental/paths/utilities/prefix_sum.py +0 -91
- relationalai/experimental/solvers.py +0 -1087
- relationalai/loaders/csv.py +0 -195
- relationalai/loaders/loader.py +0 -177
- relationalai/loaders/types.py +0 -23
- relationalai/rel_emitter.py +0 -373
- relationalai/rel_utils.py +0 -185
- relationalai/semantics/designs/query_builder/identify_by.md +0 -106
- relationalai/semantics/devtools/benchmark_lqp.py +0 -536
- relationalai/semantics/devtools/compilation_manager.py +0 -294
- relationalai/semantics/devtools/extract_lqp.py +0 -110
- relationalai/semantics/internal/internal.py +0 -3785
- relationalai/semantics/internal/snowflake.py +0 -324
- relationalai/semantics/lqp/README.md +0 -34
- relationalai/semantics/lqp/builtins.py +0 -16
- relationalai/semantics/lqp/compiler.py +0 -22
- relationalai/semantics/lqp/constructors.py +0 -68
- relationalai/semantics/lqp/executor.py +0 -469
- relationalai/semantics/lqp/intrinsics.py +0 -24
- relationalai/semantics/lqp/model2lqp.py +0 -839
- relationalai/semantics/lqp/passes.py +0 -680
- relationalai/semantics/lqp/primitives.py +0 -252
- relationalai/semantics/lqp/result_helpers.py +0 -202
- relationalai/semantics/lqp/rewrite/annotate_constraints.py +0 -57
- relationalai/semantics/lqp/rewrite/cdc.py +0 -216
- relationalai/semantics/lqp/rewrite/extract_common.py +0 -338
- relationalai/semantics/lqp/rewrite/extract_keys.py +0 -449
- relationalai/semantics/lqp/rewrite/function_annotations.py +0 -114
- relationalai/semantics/lqp/rewrite/functional_dependencies.py +0 -314
- relationalai/semantics/lqp/rewrite/quantify_vars.py +0 -296
- relationalai/semantics/lqp/rewrite/splinter.py +0 -76
- relationalai/semantics/lqp/types.py +0 -101
- relationalai/semantics/lqp/utils.py +0 -160
- relationalai/semantics/lqp/validators.py +0 -57
- relationalai/semantics/metamodel/compiler.py +0 -133
- relationalai/semantics/metamodel/dependency.py +0 -862
- relationalai/semantics/metamodel/executor.py +0 -61
- relationalai/semantics/metamodel/factory.py +0 -287
- relationalai/semantics/metamodel/helpers.py +0 -361
- relationalai/semantics/metamodel/rewrite/discharge_constraints.py +0 -39
- relationalai/semantics/metamodel/rewrite/dnf_union_splitter.py +0 -210
- relationalai/semantics/metamodel/rewrite/extract_nested_logicals.py +0 -78
- relationalai/semantics/metamodel/rewrite/flatten.py +0 -549
- relationalai/semantics/metamodel/rewrite/format_outputs.py +0 -165
- relationalai/semantics/metamodel/typer/checker.py +0 -353
- relationalai/semantics/metamodel/typer/typer.py +0 -1395
- relationalai/semantics/reasoners/__init__.py +0 -10
- relationalai/semantics/reasoners/graph/README.md +0 -620
- relationalai/semantics/reasoners/graph/__init__.py +0 -37
- relationalai/semantics/reasoners/graph/core.py +0 -9020
- relationalai/semantics/reasoners/graph/design/beyond_demand_transform.md +0 -797
- relationalai/semantics/reasoners/graph/tests/README.md +0 -21
- relationalai/semantics/reasoners/optimization/__init__.py +0 -68
- relationalai/semantics/reasoners/optimization/common.py +0 -88
- relationalai/semantics/reasoners/optimization/solvers_dev.py +0 -568
- relationalai/semantics/reasoners/optimization/solvers_pb.py +0 -1163
- relationalai/semantics/rel/builtins.py +0 -40
- relationalai/semantics/rel/compiler.py +0 -989
- relationalai/semantics/rel/executor.py +0 -359
- relationalai/semantics/rel/rel.py +0 -482
- relationalai/semantics/rel/rel_utils.py +0 -276
- relationalai/semantics/snowflake/__init__.py +0 -3
- relationalai/semantics/sql/compiler.py +0 -2503
- relationalai/semantics/sql/executor/duck_db.py +0 -52
- relationalai/semantics/sql/executor/result_helpers.py +0 -64
- relationalai/semantics/sql/executor/snowflake.py +0 -145
- relationalai/semantics/sql/rewrite/denormalize.py +0 -222
- relationalai/semantics/sql/rewrite/double_negation.py +0 -49
- relationalai/semantics/sql/rewrite/recursive_union.py +0 -127
- relationalai/semantics/sql/rewrite/sort_output_query.py +0 -246
- relationalai/semantics/sql/sql.py +0 -504
- relationalai/semantics/std/pragmas.py +0 -11
- relationalai/semantics/std/std.py +0 -14
- relationalai/semantics/tests/test_snapshot_abstract.py +0 -143
- relationalai/semantics/tests/test_snapshot_base.py +0 -9
- relationalai/semantics/tests/utils.py +0 -46
- relationalai/std/__init__.py +0 -70
- relationalai/tools/cli.py +0 -1940
- relationalai/tools/cli_controls.py +0 -1826
- relationalai/tools/cli_helpers.py +0 -390
- relationalai/tools/debugger_client.py +0 -109
- relationalai/tools/debugger_server.py +0 -302
- relationalai/tools/dev.py +0 -685
- relationalai/tools/notes +0 -7
- relationalai/tools/qb_debugger.py +0 -425
- relationalai/util/clean_up_databases.py +0 -95
- relationalai/util/list_databases.py +0 -9
- relationalai/util/otel_configuration.py +0 -25
- relationalai/util/otel_handler.py +0 -484
- relationalai/util/snowflake_handler.py +0 -88
- relationalai/util/span_format_test.py +0 -43
- relationalai/util/span_tracker.py +0 -207
- relationalai/util/spans_file_handler.py +0 -72
- relationalai/util/tracing_handler.py +0 -34
- relationalai-0.12.13.dist-info/METADATA +0 -74
- relationalai-0.12.13.dist-info/RECORD +0 -449
- relationalai-0.12.13.dist-info/WHEEL +0 -4
- relationalai-0.12.13.dist-info/entry_points.txt +0 -3
- relationalai-0.12.13.dist-info/licenses/LICENSE +0 -202
- relationalai_test_util/__init__.py +0 -4
- relationalai_test_util/fixtures.py +0 -228
- relationalai_test_util/snapshot.py +0 -252
- relationalai_test_util/traceback.py +0 -118
- /relationalai/{analysis → semantics/frontend}/__init__.py +0 -0
- /relationalai/{auth/__init__.py → semantics/metamodel/metamodel_compiler.py} +0 -0
- /relationalai/{early_access → shims}/__init__.py +0 -0
- {relationalai/early_access/dsl/adapters → v0/relationalai/analysis}/__init__.py +0 -0
- {relationalai → v0/relationalai}/analysis/mechanistic.py +0 -0
- {relationalai → v0/relationalai}/analysis/whynot.py +0 -0
- {relationalai/early_access/dsl/adapters/orm → v0/relationalai/auth}/__init__.py +0 -0
- {relationalai → v0/relationalai}/auth/jwt_generator.py +0 -0
- {relationalai → v0/relationalai}/auth/oauth_callback_server.py +0 -0
- {relationalai → v0/relationalai}/auth/token_handler.py +0 -0
- {relationalai → v0/relationalai}/auth/util.py +0 -0
- {relationalai → v0/relationalai}/clients/__init__.py +0 -0
- {relationalai → v0/relationalai}/clients/cache_store.py +0 -0
- {relationalai → v0/relationalai}/compiler.py +0 -0
- {relationalai → v0/relationalai}/dependencies.py +0 -0
- {relationalai → v0/relationalai}/docutils.py +0 -0
- {relationalai/early_access/dsl/adapters/owl → v0/relationalai/early_access}/__init__.py +0 -0
- {relationalai → v0/relationalai}/early_access/dsl/__init__.py +0 -0
- {relationalai/early_access/dsl/bindings → v0/relationalai/early_access/dsl/adapters}/__init__.py +0 -0
- {relationalai/early_access/dsl/bindings/legacy → v0/relationalai/early_access/dsl/adapters/orm}/__init__.py +0 -0
- {relationalai → v0/relationalai}/early_access/dsl/adapters/orm/model.py +0 -0
- {relationalai/early_access/dsl/codegen → v0/relationalai/early_access/dsl/adapters/owl}/__init__.py +0 -0
- {relationalai → v0/relationalai}/early_access/dsl/adapters/owl/model.py +0 -0
- {relationalai/early_access/dsl/core/temporal → v0/relationalai/early_access/dsl/bindings}/__init__.py +0 -0
- {relationalai/early_access/dsl/ir → v0/relationalai/early_access/dsl/bindings/legacy}/__init__.py +0 -0
- {relationalai/early_access/dsl/ontologies → v0/relationalai/early_access/dsl/codegen}/__init__.py +0 -0
- {relationalai → v0/relationalai}/early_access/dsl/constants.py +0 -0
- {relationalai → v0/relationalai}/early_access/dsl/core/__init__.py +0 -0
- {relationalai → v0/relationalai}/early_access/dsl/core/constraints/__init__.py +0 -0
- {relationalai → v0/relationalai}/early_access/dsl/core/constraints/predicate/__init__.py +0 -0
- {relationalai → v0/relationalai}/early_access/dsl/core/stack.py +0 -0
- {relationalai/early_access/dsl/orm → v0/relationalai/early_access/dsl/core/temporal}/__init__.py +0 -0
- {relationalai → v0/relationalai}/early_access/dsl/core/utils.py +0 -0
- {relationalai/early_access/dsl/orm/measures → v0/relationalai/early_access/dsl/ir}/__init__.py +0 -0
- {relationalai/early_access/dsl/physical_metadata → v0/relationalai/early_access/dsl/ontologies}/__init__.py +0 -0
- {relationalai → v0/relationalai}/early_access/dsl/ontologies/raw_source.py +0 -0
- {relationalai/early_access/dsl/serialize → v0/relationalai/early_access/dsl/orm}/__init__.py +0 -0
- {relationalai/early_access/dsl/snow → v0/relationalai/early_access/dsl/orm/measures}/__init__.py +0 -0
- {relationalai → v0/relationalai}/early_access/dsl/orm/reasoner_errors.py +0 -0
- {relationalai/loaders → v0/relationalai/early_access/dsl/physical_metadata}/__init__.py +0 -0
- {relationalai/semantics/tests → v0/relationalai/early_access/dsl/serialize}/__init__.py +0 -0
- {relationalai → v0/relationalai}/early_access/dsl/serialize/binding_model.py +0 -0
- {relationalai → v0/relationalai}/early_access/dsl/serialize/model.py +0 -0
- {relationalai/tools → v0/relationalai/early_access/dsl/snow}/__init__.py +0 -0
- {relationalai → v0/relationalai}/early_access/tests/__init__.py +0 -0
- {relationalai → v0/relationalai}/environments/ci.py +0 -0
- {relationalai → v0/relationalai}/environments/hex.py +0 -0
- {relationalai → v0/relationalai}/environments/terminal.py +0 -0
- {relationalai → v0/relationalai}/experimental/__init__.py +0 -0
- {relationalai → v0/relationalai}/experimental/graphs.py +0 -0
- {relationalai → v0/relationalai}/experimental/paths/__init__.py +0 -0
- {relationalai → v0/relationalai}/experimental/paths/benchmarks/__init__.py +0 -0
- {relationalai → v0/relationalai}/experimental/paths/path_algorithms/__init__.py +0 -0
- {relationalai → v0/relationalai}/experimental/paths/rpq/__init__.py +0 -0
- {relationalai → v0/relationalai}/experimental/paths/rpq/filter.py +0 -0
- {relationalai → v0/relationalai}/experimental/paths/rpq/glushkov.py +0 -0
- {relationalai → v0/relationalai}/experimental/paths/rpq/transition.py +0 -0
- {relationalai → v0/relationalai}/experimental/paths/utilities/__init__.py +0 -0
- {relationalai → v0/relationalai}/experimental/paths/utilities/utilities.py +0 -0
- {relationalai → v0/relationalai}/metagen.py +0 -0
- {relationalai → v0/relationalai}/metamodel.py +0 -0
- {relationalai → v0/relationalai}/rel.py +0 -0
- {relationalai → v0/relationalai}/semantics/devtools/__init__.py +0 -0
- {relationalai → v0/relationalai}/semantics/internal/__init__.py +0 -0
- {relationalai → v0/relationalai}/semantics/internal/annotations.py +0 -0
- {relationalai → v0/relationalai}/semantics/lqp/__init__.py +0 -0
- {relationalai → v0/relationalai}/semantics/lqp/ir.py +0 -0
- {relationalai → v0/relationalai}/semantics/lqp/pragmas.py +0 -0
- {relationalai → v0/relationalai}/semantics/lqp/rewrite/__init__.py +0 -0
- {relationalai → v0/relationalai}/semantics/metamodel/dataflow.py +0 -0
- {relationalai → v0/relationalai}/semantics/metamodel/ir.py +0 -0
- {relationalai → v0/relationalai}/semantics/metamodel/rewrite/__init__.py +0 -0
- {relationalai → v0/relationalai}/semantics/metamodel/typer/__init__.py +0 -0
- {relationalai → v0/relationalai}/semantics/metamodel/types.py +0 -0
- {relationalai → v0/relationalai}/semantics/metamodel/util.py +0 -0
- {relationalai → v0/relationalai}/semantics/metamodel/visitor.py +0 -0
- {relationalai → v0/relationalai}/semantics/reasoners/experimental/__init__.py +0 -0
- {relationalai → v0/relationalai}/semantics/rel/__init__.py +0 -0
- {relationalai → v0/relationalai}/semantics/sql/__init__.py +0 -0
- {relationalai → v0/relationalai}/semantics/sql/executor/__init__.py +0 -0
- {relationalai → v0/relationalai}/semantics/sql/rewrite/__init__.py +0 -0
- {relationalai → v0/relationalai}/semantics/tests/logging.py +0 -0
- {relationalai → v0/relationalai}/std/aggregates.py +0 -0
- {relationalai → v0/relationalai}/std/dates.py +0 -0
- {relationalai → v0/relationalai}/std/graphs.py +0 -0
- {relationalai → v0/relationalai}/std/inspect.py +0 -0
- {relationalai → v0/relationalai}/std/math.py +0 -0
- {relationalai → v0/relationalai}/std/re.py +0 -0
- {relationalai → v0/relationalai}/std/strings.py +0 -0
- {relationalai → v0/relationalai}/tools/cleanup_snapshots.py +0 -0
- {relationalai → v0/relationalai}/tools/constants.py +0 -0
- {relationalai → v0/relationalai}/tools/query_utils.py +0 -0
- {relationalai → v0/relationalai}/tools/snapshot_viewer.py +0 -0
- {relationalai → v0/relationalai}/util/__init__.py +0 -0
- {relationalai → v0/relationalai}/util/constants.py +0 -0
- {relationalai → v0/relationalai}/util/graph.py +0 -0
- {relationalai → v0/relationalai}/util/timeout.py +0 -0
v0/relationalai/clients/snowflake.py (new file)
@@ -0,0 +1,3869 @@
+# pyright: reportUnusedExpression=false
+from __future__ import annotations
+import base64
+import decimal
+import importlib.resources
+import io
+from numbers import Number
+import re
+import json
+import time
+import textwrap
+import ast
+import uuid
+import warnings
+import atexit
+import hashlib
+
+
+from v0.relationalai.auth.token_handler import TokenHandler
+from v0.relationalai.clients.use_index_poller import DirectUseIndexPoller, UseIndexPoller
+import snowflake.snowpark
+
+from v0.relationalai.rel_utils import sanitize_identifier, to_fqn_relation_name
+from v0.relationalai.tools.constants import FIELD_PLACEHOLDER, RAI_APP_NAME, SNOWFLAKE_AUTHS, USE_GRAPH_INDEX, USE_DIRECT_ACCESS, DEFAULT_QUERY_TIMEOUT_MINS, WAIT_FOR_STREAM_SYNC, Generation
+from .. import std
+from collections import defaultdict
+import requests
+import snowflake.connector
+import pyarrow as pa
+
+from snowflake.snowpark import Session
+from snowflake.snowpark.context import get_active_session
+from . import result_helpers
+from .. import debugging
+from typing import Any, Dict, Iterable, Optional, Tuple, List, Literal, Union, cast
+
+from pandas import DataFrame
+
+from ..tools.cli_controls import Spinner
+from ..clients.types import AvailableModel, EngineState, Import, ImportSource, ImportSourceTable, ImportsStatus, SourceInfo, TransactionAsyncResponse
+from ..clients.config import Config, ConfigStore, ENDPOINT_FILE
+from ..clients.client import Client, ExportParams, ProviderBase, ResourcesBase
+from ..clients.direct_access_client import DirectAccessClient
+from ..clients.util import IdentityParser, escape_for_f_string, get_pyrel_version, get_with_retries, poll_with_specified_overhead, safe_json_loads, sanitize_module_name, scrub_exception, wrap_with_request_id, ms_to_timestamp, normalize_datetime
+from ..environments import runtime_env, HexEnvironment, SnowbookEnvironment
+from .. import dsl, rel, metamodel as m
+from ..errors import DuoSecurityFailed, EngineProvisioningFailed, EngineNameValidationException, EngineNotFoundException, EnginePending, EngineSizeMismatchWarning, EngineResumeFailed, Errors, InvalidAliasError, InvalidEngineSizeError, InvalidSourceTypeWarning, RAIAbortedTransactionError, RAIException, HexSessionException, SnowflakeAppMissingException, SnowflakeChangeTrackingNotEnabledException, SnowflakeDatabaseException, SnowflakeImportMissingException, SnowflakeInvalidSource, SnowflakeMissingConfigValuesException, SnowflakeProxyAPIDeprecationWarning, SnowflakeProxySourceError, SnowflakeRaiAppNotStarted, ModelNotFoundException, UnknownSourceWarning, ResponseStatusException, RowsDroppedFromTargetTableWarning, QueryTimeoutExceededException
+from concurrent.futures import ThreadPoolExecutor
+from datetime import datetime, date, timedelta
+from snowflake.snowpark.types import StringType, StructField, StructType
+
+# warehouse-based snowflake notebooks currently don't have hazmat
+crypto_disabled = False
+try:
+    from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
+    from cryptography.hazmat.backends import default_backend
+    from cryptography.hazmat.primitives import padding
+except (ModuleNotFoundError, ImportError):
+    crypto_disabled = True
+
+#--------------------------------------------------
+# Constants
+#--------------------------------------------------
+
+VALID_POOL_STATUS = ["ACTIVE", "IDLE", "SUSPENDED"]
+# transaction list and get return different fields (duration vs timings)
+LIST_TXN_SQL_FIELDS = ["id", "database_name", "engine_name", "state", "abort_reason", "read_only", "created_by", "created_on", "finished_at", "duration"]
+GET_TXN_SQL_FIELDS = ["id", "database", "engine", "state", "abort_reason", "read_only", "created_by", "created_on", "finished_at", "timings"]
+IMPORT_STREAM_FIELDS = ["ID", "CREATED_AT", "CREATED_BY", "STATUS", "REFERENCE_NAME", "REFERENCE_ALIAS", "FQ_OBJECT_NAME", "RAI_DATABASE",
+                        "RAI_RELATION", "DATA_SYNC_STATUS", "PENDING_BATCHES_COUNT", "NEXT_BATCH_STATUS", "NEXT_BATCH_UNLOADED_TIMESTAMP",
+                        "NEXT_BATCH_DETAILS", "LAST_BATCH_DETAILS", "LAST_BATCH_UNLOADED_TIMESTAMP", "CDC_STATUS"]
+VALID_ENGINE_STATES = ["READY", "PENDING"]
+
+# Cloud-specific engine sizes
+INTERNAL_ENGINE_SIZES = ["XS", "S", "M", "L"]
+ENGINE_SIZES_AWS = ["HIGHMEM_X64_S", "HIGHMEM_X64_M", "HIGHMEM_X64_L"]
+ENGINE_SIZES_AZURE = ["HIGHMEM_X64_S", "HIGHMEM_X64_M", "HIGHMEM_X64_SL"]
+
+FIELD_MAP = {
+    "database_name": "database",
+    "engine_name": "engine",
+}
+VALID_IMPORT_STATES = ["PENDING", "PROCESSING", "QUARANTINED", "LOADED"]
+ENGINE_ERRORS = ["engine is suspended", "create/resume", "engine not found", "no engines found", "engine was deleted"]
+ENGINE_NOT_READY_MSGS = ["engine is in pending", "engine is provisioning"]
+DATABASE_ERRORS = ["database not found"]
+PYREL_ROOT_DB = 'pyrel_root_db'
+
+TERMINAL_TXN_STATES = ["COMPLETED", "ABORTED"]
+
+DUO_TEXT = "duo security"
+
+TXN_ABORT_REASON_TIMEOUT = "transaction timeout"
+
+#--------------------------------------------------
+# Helpers
+#--------------------------------------------------
+
+def process_jinja_template(template: str, indent_spaces = 0, **substitutions) -> str:
+    """Process a Jinja-like template.
+
+    Supports:
+    - Variable substitution {{ var }}
+    - Conditional blocks {% if condition %} ... {% endif %}
+    - For loops {% for item in items %} ... {% endfor %}
+    - Comments {# ... #}
+    - Whitespace control with {%- and -%}
+
+    Args:
+        template: The template string
+        indent_spaces: Number of spaces to indent the result
+        **substitutions: Variable substitutions
+    """
+
+    def evaluate_condition(condition: str, context: dict) -> bool:
+        """Safely evaluate a condition string using the context."""
+        # Replace variables with their values
+        for k, v in context.items():
+            if isinstance(v, str):
+                condition = condition.replace(k, f"'{v}'")
+            else:
+                condition = condition.replace(k, str(v))
+        try:
+            return bool(eval(condition, {"__builtins__": {}}, {}))
+        except Exception:
+            return False
+
+    def process_expression(expr: str, context: dict) -> str:
+        """Process a {{ expression }} block."""
+        expr = expr.strip()
+        if expr in context:
+            return str(context[expr])
+        return ""
+
+    def process_block(lines: List[str], context: dict, indent: int = 0) -> List[str]:
+        """Process a block of template lines recursively."""
+        result = []
+        i = 0
+        while i < len(lines):
+            line = lines[i]
+
+            # Handle comments
+            line = re.sub(r'{#.*?#}', '', line)
+
+            # Handle if blocks
+            if_match = re.search(r'{%\s*if\s+(.+?)\s*%}', line)
+            if if_match:
+                condition = if_match.group(1)
+                if_block = []
+                else_block = []
+                i += 1
+                nesting = 1
+                in_else_block = False
+                while i < len(lines) and nesting > 0:
+                    if re.search(r'{%\s*if\s+', lines[i]):
+                        nesting += 1
+                    elif re.search(r'{%\s*endif\s*%}', lines[i]):
+                        nesting -= 1
+                    elif nesting == 1 and re.search(r'{%\s*else\s*%}', lines[i]):
+                        in_else_block = True
+                        i += 1
+                        continue
+
+                    if nesting > 0:
+                        if in_else_block:
+                            else_block.append(lines[i])
+                        else:
+                            if_block.append(lines[i])
+                    i += 1
+                if evaluate_condition(condition, context):
+                    result.extend(process_block(if_block, context, indent))
+                else:
+                    result.extend(process_block(else_block, context, indent))
+                continue
+
+            # Handle for loops
+            for_match = re.search(r'{%\s*for\s+(\w+)\s+in\s+(\w+)\s*%}', line)
+            if for_match:
+                var_name, iterable_name = for_match.groups()
+                for_block = []
+                i += 1
+                nesting = 1
+                while i < len(lines) and nesting > 0:
+                    if re.search(r'{%\s*for\s+', lines[i]):
+                        nesting += 1
+                    elif re.search(r'{%\s*endfor\s*%}', lines[i]):
+                        nesting -= 1
+                    if nesting > 0:
+                        for_block.append(lines[i])
+                    i += 1
+                if iterable_name in context and isinstance(context[iterable_name], (list, tuple)):
+                    for item in context[iterable_name]:
+                        loop_context = dict(context)
+                        loop_context[var_name] = item
+                        result.extend(process_block(for_block, loop_context, indent))
+                continue
continue
|
|
197
|
+
|
|
198
|
+
# Handle variable substitution
|
|
199
|
+
line = re.sub(r'{{\s*(\w+)\s*}}', lambda m: process_expression(m.group(1), context), line)
|
|
200
|
+
|
|
201
|
+
# Handle whitespace control
|
|
202
|
+
line = re.sub(r'{%-', '{%', line)
|
|
203
|
+
line = re.sub(r'-%}', '%}', line)
|
|
204
|
+
|
|
205
|
+
# Add line with proper indentation, preserving blank lines
|
|
206
|
+
if line.strip():
|
|
207
|
+
result.append(" " * (indent_spaces + indent) + line)
|
|
208
|
+
else:
|
|
209
|
+
result.append("")
|
|
210
|
+
|
|
211
|
+
i += 1
|
|
212
|
+
|
|
213
|
+
return result
|
|
214
|
+
|
|
215
|
+
# Split template into lines and process
|
|
216
|
+
lines = template.split('\n')
|
|
217
|
+
processed_lines = process_block(lines, substitutions)
|
|
218
|
+
|
|
219
|
+
return '\n'.join(processed_lines)
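
# Illustrative usage sketch (not part of the original module): the template below
# exercises variable substitution, a conditional, and a loop, with the result
# indented by two spaces.
#
#   >>> tmpl = "\n".join([
#   ...     "name: {{ name }}",
#   ...     "{% if debug %}",
#   ...     "debug: on",
#   ...     "{% endif %}",
#   ...     "{% for col in cols %}",
#   ...     "col: {{ col }}",
#   ...     "{% endfor %}",
#   ... ])
#   >>> print(process_jinja_template(tmpl, indent_spaces=2, name="orders", debug=True, cols=["id", "total"]))
#     name: orders
#     debug: on
#     col: id
#     col: total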

def type_to_sql(type) -> str:
    if type is str:
        return "VARCHAR"
    if type is int:
        return "NUMBER"
    if type is Number:
        return "DECIMAL(38, 15)"
    if type is float:
        return "FLOAT"
    if type is decimal.Decimal:
        return "DECIMAL(38, 15)"
    if type is bool:
        return "BOOLEAN"
    if type is dict:
        return "VARIANT"
    if type is list:
        return "ARRAY"
    if type is bytes:
        return "BINARY"
    if type is datetime:
        return "TIMESTAMP"
    if type is date:
        return "DATE"
    if isinstance(type, dsl.Type):
        return "VARCHAR"
    raise ValueError(f"Unknown type {type}")

def type_to_snowpark(type) -> str:
    if type is str:
        return "StringType()"
    if type is int:
        return "IntegerType()"
    if type is float:
        return "FloatType()"
    if type is Number:
        return "DecimalType(38, 15)"
    if type is decimal.Decimal:
        return "DecimalType(38, 15)"
    if type is bool:
        return "BooleanType()"
    if type is dict:
        return "MapType()"
    if type is list:
        return "ArrayType()"
    if type is bytes:
        return "BinaryType()"
    if type is datetime:
        return "TimestampType()"
    if type is date:
        return "DateType()"
    if isinstance(type, dsl.Type):
        return "StringType()"
    raise ValueError(f"Unknown type {type}")
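
# Illustrative sketch (not part of the original module): how Python types from an
# export signature map to the SQL and Snowpark type spellings used when generating
# the stored procedure.
#
#   >>> [type_to_sql(t) for t in (str, int, float, bool)]
#   ['VARCHAR', 'NUMBER', 'FLOAT', 'BOOLEAN']
#   >>> [type_to_snowpark(t) for t in (str, datetime, date)]
#   ['StringType()', 'TimestampType()', 'DateType()']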

def _sanitize_user_name(user: str) -> str:
    # Extract the part before the '@'
    sanitized_user = user.split('@')[0]
    # Replace any character that is not a letter, number, or underscore with '_'
    sanitized_user = re.sub(r'[^a-zA-Z0-9_]', '_', sanitized_user)
    return sanitized_user
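
# Illustrative usage sketch (not part of the original module): the helper keeps
# only the local part of a user identifier and normalizes it to an identifier-safe
# name that can double as a default engine name.
#
#   >>> _sanitize_user_name("jane.doe@example.com")
#   'jane_doe'
#   >>> _sanitize_user_name("svc-account+test@corp.io")
#   'svc_account_test'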

def _is_engine_issue(response_message: str) -> bool:
    return any(kw in response_message.lower() for kw in ENGINE_ERRORS + ENGINE_NOT_READY_MSGS)

def _is_database_issue(response_message: str) -> bool:
    return any(kw in response_message.lower() for kw in DATABASE_ERRORS)

#--------------------------------------------------
# Resources
#--------------------------------------------------

APP_NAME = "___RAI_APP___"

class Resources(ResourcesBase):
    def __init__(
        self,
        profile: str | None = None,
        config: Config | None = None,
        connection: Session | None = None,
        dry_run: bool = False,
        reset_session: bool = False,
        generation: Generation | None = None,
        language: str = "rel",
    ):
        super().__init__(profile, config=config)
        self._token_handler: TokenHandler | None = None
        self._session = connection
        self.generation = generation
        if self._session is None and not dry_run:
            try:
                # we may still be constructing the config, so this can fail here;
                # if so, we'll create the session later
                self._session = self.get_sf_session(reset_session)
            except Exception:
                pass
        self._pending_transactions: list[str] = []
        self._ns_cache = {}
        # self.sources contains fully qualified Snowflake table/view names
        self.sources: set[str] = set()
        self._sproc_models = None
        self.database = ""
        self.language = language
        atexit.register(self.cancel_pending_transactions)

    @property
    def token_handler(self) -> TokenHandler:
        if not self._token_handler:
            self._token_handler = TokenHandler.from_config(self.config)
        return self._token_handler

    def is_erp_running(self, app_name: str) -> bool:
        """Check if the ERP is running. The app.service_status() returns a single row/column containing an array of JSON service status objects."""
        query = f"CALL {app_name}.app.service_status();"
        try:
            result = self._exec(query)
            # The result is a list of dictionaries, each with a "STATUS" key
            # The column name containing the result is "SERVICE_STATUS"
            services_status = json.loads(result[0]["SERVICE_STATUS"])
            # Find the dictionary with "name" of "main" and check if its "status" is "READY"
            for service in services_status:
                if service.get("name") == "main" and service.get("status") == "READY":
                    return True
            return False
        except Exception:
            return False

    def get_sf_session(self, reset_session: bool = False):
        if self._session:
            return self._session

        if isinstance(runtime_env, HexEnvironment):
            raise HexSessionException()
        if isinstance(runtime_env, SnowbookEnvironment):
            return get_active_session()
        else:
            # if there's already been a session created, try using that
            # if reset_session is true always try to get the new session
            if not reset_session:
                try:
                    return get_active_session()
                except Exception:
                    pass

            # otherwise, create a new session
            missing_keys = []
            connection_parameters = {}

            authenticator = self.config.get('authenticator', None)
            passcode = self.config.get("passcode", "")
            private_key_file = self.config.get("private_key_file", "")

            # If the authenticator is not set, we need to set it based on the provided parameters
            if authenticator is None:
                if private_key_file != "":
                    authenticator = "snowflake_jwt"
                elif passcode != "":
                    authenticator = "username_password_mfa"
                else:
                    authenticator = "snowflake"
                # set the default authenticator in the config so we can skip it when we check for missing keys
                self.config.set("authenticator", authenticator)

            if authenticator in SNOWFLAKE_AUTHS:
                required_keys = {
                    key for key, value in SNOWFLAKE_AUTHS[authenticator].items() if value.get("required", True)
                }
                for key in required_keys:
                    if self.config.get(key, None) is None:
                        default = SNOWFLAKE_AUTHS[authenticator][key].get("value", None)
                        if default is None or default == FIELD_PLACEHOLDER:
                            # No default value and no value in the config, add to missing keys
                            missing_keys.append(key)
                        else:
                            # Set the default value in the config from the auth defaults
                            self.config.set(key, default)
                if missing_keys:
                    profile = getattr(self.config, 'profile', None)
                    config_file_path = getattr(self.config, 'file_path', None)
                    raise SnowflakeMissingConfigValuesException(missing_keys, profile, config_file_path)
                for key in SNOWFLAKE_AUTHS[authenticator]:
                    connection_parameters[key] = self.config.get(key, None)
            else:
                raise ValueError(f'Authenticator "{authenticator}" not supported')

            return self._build_snowflake_session(connection_parameters)
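
    # Illustrative summary (not part of the original module) of how the default
    # authenticator is chosen above when none is configured:
    #
    #   private_key_file set                 -> "snowflake_jwt"
    #   passcode set (and no key file)       -> "username_password_mfa"
    #   neither                              -> "snowflake" (username/password)
    #
    # The chosen value is written back to the config before the required keys for
    # that authenticator are validated against SNOWFLAKE_AUTHS.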

    def _build_snowflake_session(self, connection_parameters: Dict[str, Any]) -> Session:
        try:
            tmp = {
                "client_session_keep_alive": True,
                "client_session_keep_alive_heartbeat_frequency": 60 * 5,
            }
            tmp.update(connection_parameters)
            connection_parameters = tmp
            # authenticator programmatic access token needs to be upper cased to work...
            connection_parameters["authenticator"] = connection_parameters["authenticator"].upper()
            if "authenticator" in connection_parameters and connection_parameters["authenticator"] == "OAUTH_AUTHORIZATION_CODE":
                # we are replicating OAUTH_AUTHORIZATION_CODE by first retrieving the token
                # and then authenticating with the token via the OAUTH authenticator
                connection_parameters["token"] = self.token_handler.get_session_login_token()
                connection_parameters["authenticator"] = "OAUTH"
            return Session.builder.configs(connection_parameters).create()
        except snowflake.connector.errors.Error as e:
            raise SnowflakeDatabaseException(e)
        except Exception as e:
            raise e

    def _exec_sql(self, code: str, params: List[Any] | None, raw=False):
        assert self._session is not None
        sess_results = self._session.sql(
            code.replace(APP_NAME, self.get_app_name()),
            params
        )
        if raw:
            return sess_results
        return sess_results.collect()

    def _exec(
        self,
        code: str,
        params: List[Any] | Any | None = None,
        raw: bool = False,
        help: bool = True,
        skip_engine_db_error_retry: bool = False
    ) -> Any:
        # print(f"\n--- sql---\n{code}\n--- end sql---\n")
        if not self._session:
            self._session = self.get_sf_session()

        try:
            if params is not None and not isinstance(params, list):
                params = cast(List[Any], [params])
            return self._exec_sql(code, params, raw=raw)
        except Exception as e:
            if not help:
                raise e
            orig_message = str(e).lower()
            rai_app = self.config.get("rai_app_name", "")
            current_role = self.config.get("role")
            engine = self.get_default_engine_name()
            engine_size = self.config.get_default_engine_size()
            assert isinstance(rai_app, str), f"rai_app_name must be a string, not {type(rai_app)}"
            assert isinstance(engine, str), f"engine must be a string, not {type(engine)}"
            print("\n")
            if DUO_TEXT in orig_message:
                raise DuoSecurityFailed(e)
            if re.search(f"database '{rai_app}' does not exist or not authorized.".lower(), orig_message):
                exception = SnowflakeAppMissingException(rai_app, current_role)
                raise exception from None
            # skip initializing the index if the query is a user transaction. exec_raw/exec_lqp will handle that case with the correct request headers.
            if (_is_engine_issue(orig_message) or _is_database_issue(orig_message)) and not skip_engine_db_error_retry:
                try:
                    self._poll_use_index(
                        app_name=self.get_app_name(),
                        sources=self.sources,
                        model=self.database,
                        engine_name=engine,
                        engine_size=engine_size
                    )
                    return self._exec(code, params, raw=raw, help=help)
                except EngineNameValidationException as e:
                    raise EngineNameValidationException(engine) from e
                except Exception as e:
                    raise EngineProvisioningFailed(engine, e) from e
            elif re.search(r"javascript execution error", orig_message):
                match = re.search(r"\"message\":\"(.*)\"", orig_message)
                if match:
                    message = match.group(1)
                    if "engine is in pending" in message or "engine is provisioning" in message:
                        raise EnginePending(engine)
                    else:
                        raise RAIException(message) from None

            if re.search(r"the relationalai service has not been started.", orig_message):
                app_name = self.config.get("rai_app_name", "")
                assert isinstance(app_name, str), f"rai_app_name must be a string, not {type(app_name)}"
                raise SnowflakeRaiAppNotStarted(app_name)

            if re.search(r"state:\s*aborted", orig_message):
                txn_id_match = re.search(r"id:\s*([0-9a-f\-]+)", orig_message)
                if txn_id_match:
                    txn_id = txn_id_match.group(1)
                    problems = self.get_transaction_problems(txn_id)
                    if problems:
                        for problem in problems:
                            if isinstance(problem, dict):
                                type_field = problem.get('TYPE')
                                message_field = problem.get('MESSAGE')
                                report_field = problem.get('REPORT')
                            else:
                                type_field = problem.TYPE
                                message_field = problem.MESSAGE
                                report_field = problem.REPORT

                            raise RAIAbortedTransactionError(type_field, message_field, report_field)
                    raise RAIException(str(e))
            raise RAIException(str(e))

    def reset(self):
        self._session = None

    #--------------------------------------------------
    # Check direct access is enabled
    #--------------------------------------------------

    def is_direct_access_enabled(self) -> bool:
        try:
            feature_enabled = self._exec(
                f"call {APP_NAME}.APP.DIRECT_INGRESS_ENABLED();"
            )
            if not feature_enabled:
                return False

            # Even if the feature is enabled, customers still need to reactivate the ERP to ensure the endpoint is available.
            endpoint = self._exec(
                f"call {APP_NAME}.APP.SERVICE_ENDPOINT(true);"
            )
            if not endpoint or endpoint[0][0] is None:
                return False

            return feature_enabled[0][0]
        except Exception as e:
            raise Exception(f"Unable to determine if direct access is enabled. Detailed error: {e}") from e

    #--------------------------------------------------
    # Snowflake Account Flags
    #--------------------------------------------------

    def is_account_flag_set(self, flag: str) -> bool:
        results = self._exec(
            f"SHOW PARAMETERS LIKE '%{flag}%' IN ACCOUNT;"
        )
        if not results:
            return False
        return results[0]["value"] == "true"
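
    # Illustrative usage sketch (not part of the original module): the flag is
    # considered set only when SHOW PARAMETERS reports its current value as the
    # literal string "true". `resources` is a hypothetical instance and the
    # parameter name is just an example of a boolean account parameter.
    #
    #   >>> resources.is_account_flag_set("QUOTED_IDENTIFIERS_IGNORE_CASE")
    #   False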

    #--------------------------------------------------
    # Databases
    #--------------------------------------------------

    def get_database(self, database: str):
        try:
            results = self._exec(
                f"call {APP_NAME}.api.get_database('{database}');"
            )
        except Exception as e:
            if "Database does not exist" in str(e):
                return None
            raise e

        if not results:
            return None
        db = results[0]
        if not db:
            return None
        return {
            "id": db["ID"],
            "name": db["NAME"],
            "created_by": db["CREATED_BY"],
            "created_on": db["CREATED_ON"],
            "deleted_by": db["DELETED_BY"],
            "deleted_on": db["DELETED_ON"],
            "state": db["STATE"],
        }

    def get_installed_packages(self, database: str) -> Dict | None:
        query = f"call {APP_NAME}.api.get_installed_package_versions('{database}');"
        try:
            results = self._exec(query)
        except Exception as e:
            if "Database does not exist" in str(e):
                return None
            # fallback to None for old sql-lib versions
            if "Unknown user-defined function" in str(e):
                return None
            raise e

        if not results:
            return None

        row = results[0]
        if not row:
            return None

        return safe_json_loads(row["PACKAGE_VERSIONS"])

    #--------------------------------------------------
    # Engines
    #--------------------------------------------------

    def get_engine_sizes(self, cloud_provider: str|None=None):
        sizes = []
        if cloud_provider is None:
            cloud_provider = self.get_cloud_provider()
        if cloud_provider == 'azure':
            sizes = ENGINE_SIZES_AZURE
        else:
            sizes = ENGINE_SIZES_AWS
        if self.config.show_all_engine_sizes():
            return INTERNAL_ENGINE_SIZES + sizes
        else:
            return sizes
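
    # Illustrative usage sketch (not part of the original module), assuming
    # show_all_engine_sizes() is disabled so internal sizes are hidden;
    # `resources` is a hypothetical instance.
    #
    #   >>> resources.get_engine_sizes(cloud_provider="aws")
    #   ['HIGHMEM_X64_S', 'HIGHMEM_X64_M', 'HIGHMEM_X64_L']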

    def list_engines(self, state: str | None = None):
        where_clause = f"WHERE STATUS = '{state.upper()}'" if state else ""
        statement = f"SELECT NAME, ID, SIZE, STATUS, CREATED_BY, CREATED_ON, UPDATED_ON FROM {APP_NAME}.api.engines {where_clause} ORDER BY NAME ASC;"
        results = self._exec(statement)
        if not results:
            return []
        return [
            {
                "name": row["NAME"],
                "id": row["ID"],
                "size": row["SIZE"],
                "state": row["STATUS"],  # callers are expecting 'state'
                "created_by": row["CREATED_BY"],
                "created_on": row["CREATED_ON"],
                "updated_on": row["UPDATED_ON"],
            }
            for row in results
        ]

    def get_engine(self, name: str):
        results = self._exec(
            f"SELECT NAME, ID, SIZE, STATUS, CREATED_BY, CREATED_ON, UPDATED_ON, VERSION, AUTO_SUSPEND_MINS, SUSPENDS_AT FROM {APP_NAME}.api.engines WHERE NAME='{name}';"
        )
        if not results:
            return None
        engine = results[0]
        if not engine:
            return None
        engine_state: EngineState = {
            "name": engine["NAME"],
            "id": engine["ID"],
            "size": engine["SIZE"],
            "state": engine["STATUS"],  # callers are expecting 'state'
            "created_by": engine["CREATED_BY"],
            "created_on": engine["CREATED_ON"],
            "updated_on": engine["UPDATED_ON"],
            "version": engine["VERSION"],
            "auto_suspend": engine["AUTO_SUSPEND_MINS"],
            "suspends_at": engine["SUSPENDS_AT"]
        }
        return engine_state

    def get_default_engine_name(self) -> str:
        if self.config.get("engine_name", None) is not None:
            profile = self.config.profile
            raise InvalidAliasError(f"""
            'engine_name' is not a valid config option.
            If you meant to use a specific engine, use 'engine' instead.
            Otherwise, remove it from your '{profile}' configuration profile.
            """)
        engine = self.config.get("engine", None)
        if not engine and self.config.get("user", None):
            engine = _sanitize_user_name(str(self.config.get("user")))
        if not engine:
            engine = self.get_user_based_engine_name()
        self.config.set("engine", engine)
        return engine
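
    # Resolution-order sketch (not part of the original module): an explicit
    # 'engine' config value wins; otherwise the name is derived from the configured
    # 'user' via _sanitize_user_name (e.g. "jane.doe@corp.com" -> "jane_doe");
    # otherwise get_user_based_engine_name() is consulted. Whatever is resolved is
    # written back to the config under 'engine'.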

    def is_valid_engine_state(self, name:str):
        return name in VALID_ENGINE_STATES

    def _create_engine(
        self,
        name: str,
        size: str | None = None,
        auto_suspend_mins: int | None= None,
        is_async: bool = False,
        headers: Dict | None = None,
    ):
        api = "create_engine_async" if is_async else "create_engine"
        if size is None:
            size = self.config.get_default_engine_size()
        # if auto_suspend_mins is None, get the default value from the config
        if auto_suspend_mins is None:
            auto_suspend_mins = self.config.get_default_auto_suspend_mins()
        try:
            headers = debugging.gen_current_propagation_headers()
            with debugging.span(api, name=name, size=size, auto_suspend_mins=auto_suspend_mins):
                # check in case the config default is missing
                if auto_suspend_mins is None:
                    self._exec(f"call {APP_NAME}.api.{api}('{name}', '{size}', null, {headers});")
                else:
                    self._exec(f"call {APP_NAME}.api.{api}('{name}', '{size}', PARSE_JSON('{{\"auto_suspend_mins\": {auto_suspend_mins}}}'), {headers});")
        except Exception as e:
            raise EngineProvisioningFailed(name, e) from e

    def create_engine(self, name:str, size:str|None=None, auto_suspend_mins:int|None=None, headers: Dict | None = None):
        self._create_engine(name, size, auto_suspend_mins, headers=headers)

    def create_engine_async(self, name:str, size:str|None=None, auto_suspend_mins:int|None=None):
        self._create_engine(name, size, auto_suspend_mins, True)

    def delete_engine(self, name:str, force:bool = False, headers: Dict | None = None):
        request_headers = debugging.add_current_propagation_headers(headers)
        self._exec(f"call {APP_NAME}.api.delete_engine('{name}', {force},{request_headers});")

    def suspend_engine(self, name:str):
        self._exec(f"call {APP_NAME}.api.suspend_engine('{name}');")

    def resume_engine(self, name:str, headers: Dict | None = None) -> Dict:
        request_headers = debugging.add_current_propagation_headers(headers)
        self._exec(f"call {APP_NAME}.api.resume_engine('{name}',{request_headers});")
        # returning empty dict to match the expected return type
        return {}

    def resume_engine_async(self, name:str, headers: Dict | None = None) -> Dict:
        if headers is None:
            headers = {}
        self._exec(f"call {APP_NAME}.api.resume_engine_async('{name}',{headers});")
        # returning empty dict to match the expected return type
        return {}

    def alter_engine_pool(self, size:str|None=None, mins:int|None=None, maxs:int|None=None):
        """Alter engine pool node limits for Snowflake."""
        self._exec(f"call {APP_NAME}.api.alter_engine_pool_node_limits('{size}', {mins}, {maxs});")

    #--------------------------------------------------
    # Graphs
    #--------------------------------------------------

    def list_graphs(self) -> List[AvailableModel]:
        with debugging.span("list_models"):
            query = textwrap.dedent(f"""
                SELECT NAME, ID, CREATED_BY, CREATED_ON, STATE, DELETED_BY, DELETED_ON
                FROM {APP_NAME}.api.databases
                WHERE state <> 'DELETED'
                ORDER BY NAME ASC;
            """)
            results = self._exec(query)
            if not results:
                return []
            return [
                {
                    "name": row["NAME"],
                    "id": row["ID"],
                    "created_by": row["CREATED_BY"],
                    "created_on": row["CREATED_ON"],
                    "state": row["STATE"],
                    "deleted_by": row["DELETED_BY"],
                    "deleted_on": row["DELETED_ON"],
                }
                for row in results
            ]

    def get_graph(self, name: str):
        res = self.get_database(name)
        if res and res.get("state") != "DELETED":
            return res

    def create_graph(self, name: str):
        with debugging.span("create_model", name=name):
            self._exec(f"call {APP_NAME}.api.create_database('{name}', false, {debugging.gen_current_propagation_headers()});")

    def delete_graph(self, name:str, force=False, language:str="rel"):
        prop_hdrs = debugging.gen_current_propagation_headers()
        if self.config.get("use_graph_index", USE_GRAPH_INDEX):
            keep_database = not force and self.config.get("reuse_model", True)
            with debugging.span("release_index", name=name, keep_database=keep_database, language=language):
                #TODO add headers to release_index
                response = self._exec(f"call {APP_NAME}.api.release_index('{name}', OBJECT_CONSTRUCT('keep_database', {keep_database}, 'language', '{language}', 'user_agent', '{get_pyrel_version(self.generation)}'));")
                if response:
                    result = next(iter(response))
                    obj = json.loads(result["RELEASE_INDEX"])
                    error = obj.get('error', None)
                    if error and "Model database not found" not in error:
                        raise Exception(f"Error releasing index: {error}")
                else:
                    raise Exception("There was no response from the release index call.")
        else:
            with debugging.span("delete_model", name=name):
                self._exec(f"call {APP_NAME}.api.delete_database('{name}', false, {prop_hdrs});")

    def clone_graph(self, target_name:str, source_name:str, nowait_durable=True, force=False):
        if force and self.get_graph(target_name):
            self.delete_graph(target_name)
        with debugging.span("clone_model", target_name=target_name, source_name=source_name):
            # not a mistake: the clone_database argument order is indeed target then source:
            headers = debugging.gen_current_propagation_headers()
            self._exec(f"call {APP_NAME}.api.clone_database('{target_name}', '{source_name}', {nowait_durable}, {headers});")

    def _poll_use_index(
        self,
        app_name: str,
        sources: Iterable[str],
        model: str,
        engine_name: str,
        engine_size: str | None = None,
        program_span_id: str | None = None,
        headers: Dict | None = None,
    ):
        return UseIndexPoller(
            self,
            app_name,
            sources,
            model,
            engine_name,
            engine_size,
            self.language,
            program_span_id,
            headers,
            self.generation
        ).poll()

    def maybe_poll_use_index(
        self,
        app_name: str,
        sources: Iterable[str],
        model: str,
        engine_name: str,
        engine_size: str | None = None,
        program_span_id: str | None = None,
        headers: Dict | None = None,
    ):
        """Only call poll() if there are sources to process and the cache is not valid."""
        sources_list = list(sources)
        self.database = model
        if sources_list:
            poller = UseIndexPoller(
                self,
                app_name,
                sources_list,
                model,
                engine_name,
                engine_size,
                self.language,
                program_span_id,
                headers,
                self.generation
            )
            # If cache is valid (data freshness has not expired), skip polling
            if poller.cache.is_valid():
                cached_sources = len(poller.cache.sources)
                total_sources = len(sources_list)
                cached_timestamp = poller.cache._metadata.get("cachedIndices", {}).get(poller.cache.key, {}).get("last_use_index_update_on", "")

                message = f"Using cached data for {cached_sources}/{total_sources} data streams"
                if cached_timestamp:
                    print(f"\n{message} (cached at {cached_timestamp})\n")
                else:
                    print(f"\n{message}\n")
            else:
                return poller.poll()

    #--------------------------------------------------
    # Models
    #--------------------------------------------------

    def list_models(self, database: str, engine: str):
        pass

    def create_models(self, database: str, engine: str | None, models:List[Tuple[str, str]]) -> List[Any]:
        rel_code = self.create_models_code(models)
        self.exec_raw(database, engine, rel_code, readonly=False)
        # TODO: handle SPCS errors once they're figured out
        return []

    def delete_model(self, database:str, engine:str | None, name:str):
        self.exec_raw(database, engine, f"def delete[:rel, :catalog, :model, \"{name}\"]: rel[:catalog, :model, \"{name}\"]", readonly=False)

    def create_models_code(self, models:List[Tuple[str, str]]) -> str:
        lines = []
        for (name, code) in models:
            name = name.replace("\"", "\\\"")
            assert "\"\"\"\"\"\"\"" not in code, "Code literals must use fewer than 7 quotes."

            lines.append(textwrap.dedent(f"""
            def delete[:rel, :catalog, :model, "{name}"]: rel[:catalog, :model, "{name}"]
            def insert[:rel, :catalog, :model, "{name}"]: raw\"\"\"\"\"\"\"
            """) + code + "\n\"\"\"\"\"\"\"")
        rel_code = "\n\n".join(lines)
        return rel_code
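
    # Illustrative sketch (not part of the original module): for a single model
    # ("my/model", 'def output: 1'), create_models_code emits roughly:
    #
    #   def delete[:rel, :catalog, :model, "my/model"]: rel[:catalog, :model, "my/model"]
    #   def insert[:rel, :catalog, :model, "my/model"]: raw"""""""
    #   def output: 1
    #   """""""
    #
    # The seven-quote raw delimiter is why a model body may not contain a run of
    # seven double quotes (see the assert above).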

    #--------------------------------------------------
    # Exports
    #--------------------------------------------------

    def list_exports(self, database: str, engine: str):
        return []

    def format_sproc_name(self, name: str, type:Any) -> str:
        if type is datetime:
            return f"{name}.astimezone(ZoneInfo('UTC')).isoformat(timespec='milliseconds')"
        else:
            return name

    def get_export_code(self, params: ExportParams, all_installs):
        sql_inputs = ", ".join([f"{name} {type_to_sql(type)}" for (name, _, type) in params.inputs])
        input_names = [name for (name, *_) in params.inputs]
        has_return_hint = params.out_fields and isinstance(params.out_fields[0], tuple)
        if has_return_hint:
            sql_out = ", ".join([f"\"{name}\" {type_to_sql(type)}" for (name, type) in params.out_fields])
            sql_out_names = ", ".join([f"('{name}', '{type_to_sql(type)}')" for (ix, (name, type)) in enumerate(params.out_fields)])
            py_outs = ", ".join([f"StructField(\"{name}\", {type_to_snowpark(type)})" for (name, type) in params.out_fields])
        else:
            sql_out = ""
            sql_out_names = ", ".join([f"'{name}'" for name in params.out_fields])
            py_outs = ", ".join([f"StructField(\"{name}\", {type_to_snowpark(str)})" for name in params.out_fields])
        py_inputs = ", ".join([name for (name, *_) in params.inputs])
        safe_rel = escape_for_f_string(params.code).strip()
        clean_inputs = []
        for (name, var, type) in params.inputs:
            if type is str:
                clean_inputs.append(f"{name} = '\"' + escape({name}) + '\"'")
            # Replace `var` with `name` and keep the following non-word character unchanged
            pattern = re.compile(re.escape(var) + r'(\W)')
            value = self.format_sproc_name(name, type)
            safe_rel = re.sub(pattern, rf"{{{value}}}\1", safe_rel)
        if py_inputs:
            py_inputs = f", {py_inputs}"
        clean_inputs = ("\n").join(clean_inputs)
        assert __package__ is not None, "Package name must be set"
        file = "export_procedure.py.jinja"
        with importlib.resources.open_text(
            __package__, file
        ) as f:
            template = f.read()
        def quote(s: str, f = False) -> str:
            return '"' + s + '"' if not f else 'f"' + s + '"'

        wait_for_stream_sync = self.config.get("wait_for_stream_sync", WAIT_FOR_STREAM_SYNC)
        # 1. Check the sources for stale sources
        # 2. Get the object references for the sources
        # TODO: this could be optimized to do it in the run time of the stored procedure
        # instead of doing it here. It will make it more reliable when sources are
        # modified after the stored procedure is created.
        checked_sources = self._check_source_updates(self.sources)
        source_obj_references = self._get_source_references(checked_sources)

        # Escape double quotes in the source object references
        escaped_source_obj_references = [source.replace('"', '\\"') for source in source_obj_references]
        escaped_proc_database = params.proc_database.replace('"', '\\"')

        normalized_func_name = IdentityParser(params.func_name).identity
        assert normalized_func_name is not None, "Function name must be set"
        skip_invalid_data = params.skip_invalid_data
        python_code = process_jinja_template(
            template,
            func_name=quote(normalized_func_name),
            database=quote(params.root_database),
            proc_database=quote(escaped_proc_database),
            engine=quote(params.engine),
            rel_code=quote(safe_rel, f=True),
            APP_NAME=quote(APP_NAME),
            input_names=input_names,
            outputs=sql_out,
            sql_out_names=sql_out_names,
            clean_inputs=clean_inputs,
            py_inputs=py_inputs,
            py_outs=py_outs,
            skip_invalid_data=skip_invalid_data,
            source_references=", ".join(escaped_source_obj_references),
            install_code=all_installs.replace("\\", "\\\\").replace("\n", "\\n"),
            has_return_hint=has_return_hint,
            wait_for_stream_sync=wait_for_stream_sync,
        ).strip()
        return_clause = f"TABLE({sql_out})" if sql_out else "STRING"
        destination_input = "" if sql_out else "save_as_table STRING DEFAULT NULL,"
        module_name = sanitize_module_name(normalized_func_name)
        stage = f"@{self.get_app_name()}.app_state.stored_proc_code_stage"
        file_loc = f"{stage}/{module_name}.py"
        python_code = python_code.replace(APP_NAME, self.get_app_name())

        hash = hashlib.sha256()
        hash.update(python_code.encode('utf-8'))
        code_hash = hash.hexdigest()
        print(code_hash)

        sql_code = textwrap.dedent(f"""
        CREATE OR REPLACE PROCEDURE {normalized_func_name}({sql_inputs}{sql_inputs and ',' or ''} {destination_input} engine STRING DEFAULT NULL)
        RETURNS {return_clause}
        LANGUAGE PYTHON
        RUNTIME_VERSION = '3.10'
        IMPORTS = ('{file_loc}')
        PACKAGES = ('snowflake-snowpark-python')
        HANDLER = 'checked_handle'
        EXECUTE AS CALLER
        AS
        $$
        import {module_name}
        import inspect, hashlib, os, sys
        def checked_handle(*args, **kwargs):
            import_dir = sys._xoptions["snowflake_import_directory"]
            wheel_path = os.path.join(import_dir, '{module_name}.py')
            h = hashlib.sha256()
            with open(wheel_path, 'rb') as f:
                for chunk in iter(lambda: f.read(1<<20), b''):
                    h.update(chunk)
            code_hash = h.hexdigest()
            if code_hash != '{code_hash}':
                raise RuntimeError("Code hash mismatch. The code has been modified since it was uploaded.")
            # Call the handle function with the provided arguments
            return {module_name}.handle(*args, **kwargs)

        $$;
        """)
        # print(f"\n--- python---\n{python_code}\n--- end python---\n")
        # This check helps catch invalid code early and for dry runs:
        try:
            ast.parse(python_code)
        except SyntaxError:
            raise ValueError(f"Internal error: invalid Python code generated:\n{python_code}")
        return (sql_code, python_code, file_loc)
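
    # Illustrative sketch (not part of the original module): the generated SQL
    # wraps the uploaded Python module in a `checked_handle` entry point that
    # re-hashes the staged file and refuses to run if the hash no longer matches
    # the one captured at export time. Invoking the export from SQL then looks
    # roughly like this (procedure and engine names are hypothetical):
    #
    #   CALL my_export('some_input', engine => 'my_engine');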

    def get_sproc_models(self, params: ExportParams):
        if self._sproc_models is not None:
            return self._sproc_models

        with debugging.span("get_sproc_models"):
            code = """
            def output(name, model):
                rel(:catalog, :model, name, model)
                and not starts_with(name, "rel/")
                and not starts_with(name, "pkg/rel")
                and not starts_with(name, "pkg/std")
                and starts_with(name, "pkg/")
            """
            res = self.exec_raw(params.model_database, params.engine, code, readonly=True, nowait_durable=True)
            df, errors = result_helpers.format_results(res, None, ["name", "model"])
            models = []
            for row in df.itertuples():
                models.append((row.name, row.model))
            self._sproc_models = models
            return models

    def create_export(self, params: ExportParams):
        with debugging.span("create_export") as span:
            if params.dry_run:
                (sql_code, python_code, file_loc) = self.get_export_code(params, params.install_code)
                span["sql"] = sql_code
                return

            start = time.perf_counter()
            use_graph_index = self.config.get("use_graph_index", USE_GRAPH_INDEX)
            # for the non graph index case we need to create the cloned proc database
            if not use_graph_index:
                raise RAIException(
                    "To ensure permissions are properly accounted for, stored procedures require using the graph index. "
                    "Set use_graph_index=True in your config to proceed."
                )

            models = self.get_sproc_models(params)
            lib_installs = self.create_models_code(models)
            all_installs = lib_installs + "\n\n" + params.install_code

            (sql_code, python_code, file_loc) = self.get_export_code(params, all_installs)

            span["sql"] = sql_code
            assert self._session

            with debugging.span("upload_sproc_code"):
                code_bytes = python_code.encode('utf-8')
                code_stream = io.BytesIO(code_bytes)
                self._session.file.put_stream(code_stream, file_loc, auto_compress=False, overwrite=True)

            with debugging.span("sql_install"):
                self._exec(sql_code)

            debugging.time("export", time.perf_counter() - start, DataFrame(), code=sql_code.replace(APP_NAME, self.get_app_name()))

    def create_export_table(self, database: str, engine: str, table: str, relation: str, columns: Dict[str, str], code: str, refresh: str|None=None):
        print("Snowflake doesn't support creating export tables yet. Try creating the table manually first.")
        pass

    def delete_export(self, database: str, engine: str, name: str):
        pass

    #--------------------------------------------------
    # Imports
    #--------------------------------------------------

    def is_valid_import_state(self, state:str):
        return state in VALID_IMPORT_STATES

    def imports_to_dicts(self, results):
        parsed_results = [
            {field.lower(): row[field] for field in IMPORT_STREAM_FIELDS}
            for row in results
        ]
        return parsed_results

    def change_stream_status(self, stream_id: str, model:str, suspend: bool):
        if stream_id and model:
            if suspend:
                self._exec(f"CALL {APP_NAME}.api.suspend_data_stream('{stream_id}', '{model}');")
            else:
                self._exec(f"CALL {APP_NAME}.api.resume_data_stream('{stream_id}', '{model}');")

    def change_imports_status(self, suspend: bool):
        if suspend:
            self._exec(f"CALL {APP_NAME}.app.suspend_cdc();")
        else:
            self._exec(f"CALL {APP_NAME}.app.resume_cdc();")

    def get_imports_status(self) -> ImportsStatus|None:
        # NOTE: We expect there to only ever be one result?
        results = self._exec(f"CALL {APP_NAME}.app.cdc_status();")
        if results:
            result = next(iter(results))
            engine = result['CDC_ENGINE_NAME']
            engine_status = result['CDC_ENGINE_STATUS']
            engine_size = result['CDC_ENGINE_SIZE']
            task_status = result['CDC_TASK_STATUS']
            info = result['CDC_TASK_INFO']
            enabled = result['CDC_ENABLED']
            return {"engine": engine, "engine_size": engine_size, "engine_status": engine_status, "status": task_status, "enabled": enabled, "info": info }
        return None

    def set_imports_engine_size(self, size:str):
        try:
            self._exec(f"CALL {APP_NAME}.app.alter_cdc_engine_size('{size}');")
        except Exception as e:
            raise e

    def list_imports(
        self,
        id:str|None = None,
        name:str|None = None,
        model:str|None = None,
        status:str|None = None,
        creator:str|None = None,
    ) -> list[Import]:
        where = []
        if id and isinstance(id, str):
            where.append(f"LOWER(ID) = '{id.lower()}'")
        if name and isinstance(name, str):
            where.append(f"LOWER(FQ_OBJECT_NAME) = '{name.lower()}'")
        if model and isinstance(model, str):
            where.append(f"LOWER(RAI_DATABASE) = '{model.lower()}'")
        if creator and isinstance(creator, str):
            where.append(f"LOWER(CREATED_BY) = '{creator.lower()}'")
        if status and isinstance(status, str):
            where.append(f"LOWER(batch_status) = '{status.lower()}'")
        where_clause = " AND ".join(where)

        # This is roughly inspired by the native app code because we don't have a way to
        # get the status of multiple streams at once and doing them individually is way
        # too slow. We use window functions to get the status of the stream and the batch
        # details.
        statement = f"""
            SELECT
                ID,
                RAI_DATABASE,
                FQ_OBJECT_NAME,
                CREATED_AT,
                CREATED_BY,
                CASE
                    WHEN nextBatch.quarantined > 0 THEN 'quarantined'
                    ELSE nextBatch.status
                END as batch_status,
                nextBatch.processing_errors,
                nextBatch.batches
            FROM {APP_NAME}.api.data_streams as ds
            LEFT JOIN (
                SELECT DISTINCT
                    data_stream_id,
                    -- Get status from the progress record using window functions
                    FIRST_VALUE(status) OVER (
                        PARTITION BY data_stream_id
                        ORDER BY
                            CASE WHEN unloaded IS NOT NULL THEN 1 ELSE 0 END DESC,
                            unloaded ASC
                    ) as status,
                    -- Get batch_details from the same record
                    FIRST_VALUE(batch_details) OVER (
                        PARTITION BY data_stream_id
                        ORDER BY
                            CASE WHEN unloaded IS NOT NULL THEN 1 ELSE 0 END DESC,
                            unloaded ASC
                    ) as batch_details,
                    -- Aggregate the other fields
                    FIRST_VALUE(processing_details:processingErrors) OVER (
                        PARTITION BY data_stream_id
                        ORDER BY
                            CASE WHEN unloaded IS NOT NULL THEN 1 ELSE 0 END DESC,
                            unloaded ASC
                    ) as processing_errors,
                    MIN(unloaded) OVER (PARTITION BY data_stream_id) as unloaded,
                    COUNT(*) OVER (PARTITION BY data_stream_id) as batches,
                    COUNT_IF(status = 'quarantined') OVER (PARTITION BY data_stream_id) as quarantined
                FROM {APP_NAME}.api.data_stream_batches
            ) nextBatch
            ON ds.id = nextBatch.data_stream_id
            {f"where {where_clause}" if where_clause else ""}
            ORDER BY FQ_OBJECT_NAME ASC;
        """
        results = self._exec(statement)
        items = []
        if results:
            for stream in results:
                (id, db, name, created_at, created_by, status, processing_errors, batches) = stream
                if status and isinstance(status, str):
                    status = status.upper()
                if processing_errors:
                    if status in ["QUARANTINED", "PENDING"]:
                        start = processing_errors.rfind("Error")
                        if start != -1:
                            processing_errors = processing_errors[start:-1]
                    else:
                        processing_errors = None
                items.append(cast(Import, {
                    "id": id,
                    "model": db,
                    "name": name,
                    "created": created_at,
                    "creator": created_by,
                    "status": status.upper() if status else None,
                    "errors": processing_errors if processing_errors != "[]" else None,
                    "batches": f"{batches}" if batches else "",
                }))
        return items

    def poll_imports(self, sources:List[str], model:str):
        source_set = self._create_source_set(sources)
        def check_imports():
            imports = [
                import_
                for import_ in self.list_imports(model=model)
                if import_["name"] in source_set
            ]
            # loop through printing status for each in the format (index): (name) - (status)
            statuses = [import_["status"] for import_ in imports]
            if all(status == "LOADED" for status in statuses):
                return True
            if any(status == "QUARANTINED" for status in statuses):
                failed_imports = [import_["name"] for import_ in imports if import_["status"] == "QUARANTINED"]
                raise RAIException("Imports failed: " + ", ".join(failed_imports)) from None
            # this check is necessary in case some of the tables are empty;
            # such tables may be synced even though their status is None:
            def synced(import_):
                if import_["status"] == "LOADED":
                    return True
                if import_["status"] is None:
                    import_full_status = self.get_import_stream(import_["name"], model)
                    if import_full_status and import_full_status[0]["data_sync_status"] == "SYNCED":
                        return True
                return False
            if all(synced(import_) for import_ in imports):
                return True
        poll_with_specified_overhead(check_imports, overhead_rate=0.1, max_delay=10)
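
    # Illustrative usage sketch (not part of the original module): block until the
    # named streams report LOADED/SYNCED, raising if any are quarantined.
    # `resources` is a hypothetical instance and the table name is an example.
    #
    #   >>> resources.poll_imports(["SALES.PUBLIC.ORDERS"], model="my_model")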

    def _create_source_set(self, sources: List[str]) -> set:
        return {
            source.upper() if not IdentityParser(source).has_double_quoted_identifier else IdentityParser(source).identity
            for source in sources
        }

    def get_import_stream(self, name:str|None, model:str|None):
        results = self._exec(f"CALL {APP_NAME}.api.get_data_stream('{name}', '{model}');")
        if not results:
            return None
        return self.imports_to_dicts(results)

    def create_import_stream(self, source:ImportSource, model:str, rate = 1, options: dict|None = None):
        assert isinstance(source, ImportSourceTable), "Snowflake integration only supports loading from SF Tables. Try loading your data as a table via the Snowflake interface first."
        object = source.fqn

        # Parse only to the schema level
        schemaParser = IdentityParser(f"{source.database}.{source.schema}")

        if object.lower() in [x["name"].lower() for x in self.list_imports(model=model)]:
            return

        query = f"SHOW OBJECTS LIKE '{source.table}' IN {schemaParser.identity}"

        info = self._exec(query)
        if not info:
            raise ValueError(f"Object {source.table} not found in schema {schemaParser.identity}")
        else:
            data = info[0]
            if not data:
                raise ValueError(f"Object {source.table} not found in {schemaParser.identity}")
            # (time, name, db_name, schema_name, kind, *rest)
            kind = data["kind"]

        relation_name = to_fqn_relation_name(object)

        command = f"""call {APP_NAME}.api.create_data_stream(
            {APP_NAME}.api.object_reference('{kind}', '{object}'),
            '{model}',
            '{relation_name}');"""

        def create_stream(tracking_just_changed=False):
            try:
                self._exec(command)
            except Exception as e:
                if "ensure that CHANGE_TRACKING is enabled on the source object" in str(e):
                    if self.config.get("ensure_change_tracking", False) and not tracking_just_changed:
                        try:
                            self._exec(f"ALTER {kind} {object} SET CHANGE_TRACKING = TRUE;")
                            create_stream(tracking_just_changed=True)
                        except Exception:
                            pass
                    else:
                        print("\n")
                        exception = SnowflakeChangeTrackingNotEnabledException((object, kind))
                        raise exception from None
                elif "Database does not exist" in str(e):
                    print("\n")
                    raise ModelNotFoundException(model) from None
                raise e

        create_stream()

    def create_import_snapshot(self, source:ImportSource, model:str, options: dict|None = None):
        raise Exception("Snowflake integration doesn't support snapshot imports yet")

    def delete_import(self, import_name:str, model:str, force = False):
        engine = self.get_default_engine_name()
        rel_name = to_fqn_relation_name(import_name)
        try:
            self._exec(f"""call {APP_NAME}.api.delete_data_stream(
                '{import_name}',
                '{model}'
            );""")
        except RAIException as err:
            if "streams do not exist" not in str(err) or not force:
                raise

        # if force is true, we delete the leftover relation to free up the name (in case the user re-creates the stream)
        if force:
            self.exec_raw(model, engine, f"""
                declare ::{rel_name}
                def delete[:\"{rel_name}\"]: {{ {rel_name} }}
            """, readonly=False, bypass_index=True)

    #--------------------------------------------------
    # Exec Async
    #--------------------------------------------------

    def _check_exec_async_status(self, txn_id: str, headers: Dict | None = None):
        """Check whether the given transaction has completed."""
        if headers is None:
            headers = {}

        with debugging.span("check_status"):
            response = self._exec(f"CALL {APP_NAME}.api.get_transaction('{txn_id}',{headers});")
            assert response, f"No results from get_transaction('{txn_id}')"

        response_row = next(iter(response)).asDict()
        status: str = response_row['STATE']

        # remove the transaction from the pending list if it's completed or aborted
        if status in ["COMPLETED", "ABORTED"]:
            if txn_id in self._pending_transactions:
                self._pending_transactions.remove(txn_id)

        if status == "ABORTED" and response_row.get("ABORT_REASON", "") == TXN_ABORT_REASON_TIMEOUT:
            config_file_path = getattr(self.config, 'file_path', None)
            # todo: use the timeout returned alongside the transaction as soon as it's exposed
            timeout_mins = int(self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS) or DEFAULT_QUERY_TIMEOUT_MINS)
            raise QueryTimeoutExceededException(
                timeout_mins=timeout_mins,
                query_id=txn_id,
                config_file_path=config_file_path,
            )

        # @TODO: Find some way to tunnel the ABORT_REASON out. Azure doesn't have this, but it's handy
        return status == "COMPLETED" or status == "ABORTED"
|
|
1385
|
+
|
|
1386
|
+
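# ----------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the package diff above.
# `_check_exec_async_status` is meant to be called repeatedly by a polling
# helper (`poll_with_specified_overhead`, used further down in this file).
# That helper is defined elsewhere in the package, so the loop below is
# only an assumption about the general idea: sleep roughly in proportion
# to the time already spent waiting, so short transactions return quickly
# while long ones don't flood the API with status calls.
# ----------------------------------------------------------------------
import time

def poll_proportional(check, overhead_rate: float = 0.1,
                      max_delay: float = 30.0, timeout: float | None = None):
    """Call `check()` until it returns truthy, with proportional back-off."""
    start = time.monotonic()
    while not check():
        elapsed = time.monotonic() - start
        if timeout is not None and elapsed > timeout:
            raise TimeoutError(f"condition not met after {elapsed:.1f}s")
        # sleep ~overhead_rate of the elapsed time, bounded to a sane range
        time.sleep(min(max(elapsed * overhead_rate, 0.1), max_delay))
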
    def decrypt_stream(self, key: bytes, iv: bytes, src: bytes) -> bytes:
        """Decrypt the provided stream with PKCS#5 padding handling."""

        if crypto_disabled:
            if isinstance(runtime_env, SnowbookEnvironment) and runtime_env.runner == "warehouse":
                raise Exception("Please open the navigation-bar dropdown labeled *Packages* and select `cryptography` under the *Anaconda Packages* section, and then re-run your query.")
            else:
                raise Exception("library `cryptography.hazmat` missing; please install")

        # `type:ignore`s are because of the conditional import, which
        # we have because warehouse-based snowflake notebooks don't support
        # the crypto library we're using.
        cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend()) # type: ignore
        decryptor = cipher.decryptor()

        # Decrypt the data
        decrypted_padded_data = decryptor.update(src) + decryptor.finalize()

        # Unpad the decrypted data using PKCS#5
        unpadder = padding.PKCS7(128).unpadder() # type: ignore # Use 128 directly for AES
        unpadded_data = unpadder.update(decrypted_padded_data) + unpadder.finalize()

        return unpadded_data

    def _decrypt_artifact(self, data: bytes, encryption_material: str) -> bytes:
        """Decrypts the artifact data using provided encryption material."""
        encryption_material_parts = encryption_material.split("|")
        assert len(encryption_material_parts) == 3, "Invalid encryption material"

        algorithm, key_base64, iv_base64 = encryption_material_parts
        assert algorithm == "AES_128_CBC", f"Unsupported encryption algorithm {algorithm}"

        key = base64.standard_b64decode(key_base64)
        iv = base64.standard_b64decode(iv_base64)

        return self.decrypt_stream(key, iv, data)

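# ----------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the package diff above.
# `decrypt_stream` / `_decrypt_artifact` use AES-128-CBC with PKCS#7
# ("PKCS#5") padding via the `cryptography` package, and the encryption
# material string has the shape "ALGO|base64(key)|base64(iv)". A minimal
# standalone round trip with that library looks like this; all names here
# are made up for the example.
# ----------------------------------------------------------------------
import base64
import os
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes

def _demo_roundtrip(plaintext: bytes) -> bytes:
    key, iv = os.urandom(16), os.urandom(16)           # AES-128 key + CBC IV
    padder = padding.PKCS7(128).padder()                # pad to the 16-byte block size
    padded = padder.update(plaintext) + padder.finalize()
    enc = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend()).encryptor()
    ciphertext = enc.update(padded) + enc.finalize()
    material = "AES_128_CBC|%s|%s" % (
        base64.standard_b64encode(key).decode(), base64.standard_b64encode(iv).decode())
    _algo, key_b64, iv_b64 = material.split("|")
    dec = Cipher(algorithms.AES(base64.standard_b64decode(key_b64)),
                 modes.CBC(base64.standard_b64decode(iv_b64)),
                 backend=default_backend()).decryptor()
    unpadder = padding.PKCS7(128).unpadder()
    padded_out = dec.update(ciphertext) + dec.finalize()
    return unpadder.update(padded_out) + unpadder.finalize()  # == plaintext
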
    def _list_exec_async_artifacts(self, txn_id: str, headers: Dict | None = None) -> Dict[str, Dict]:
        """Grab the list of artifacts produced in the transaction and the URLs to retrieve their contents."""
        if headers is None:
            headers = {}
        with debugging.span("list_results"):
            response = self._exec(
                f"CALL {APP_NAME}.api.get_own_transaction_artifacts('{txn_id}',{headers});"
            )
        assert response, f"No results from get_own_transaction_artifacts('{txn_id}')"
        return {row["FILENAME"]: row for row in response}

    def _fetch_exec_async_artifacts(
        self, artifact_info: Dict[str, Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Grab the contents of the given artifacts from SF in parallel using threads."""

        with requests.Session() as session:
            def _fetch_data(name_info):
                filename, metadata = name_info

                try:
                    # Extract the presigned URL and encryption material from metadata
                    url_key = self.get_url_key(metadata)
                    presigned_url = metadata[url_key]
                    encryption_material = metadata["ENCRYPTION_MATERIAL"]

                    response = get_with_retries(session, presigned_url, config=self.config)
                    response.raise_for_status()  # Throw if something goes wrong

                    decrypted = self._maybe_decrypt(response.content, encryption_material)
                    return (filename, decrypted)

                except requests.RequestException as e:
                    raise scrub_exception(wrap_with_request_id(e))

            # Create a list of tuples for the map function
            name_info_pairs = list(artifact_info.items())

            with ThreadPoolExecutor(max_workers=5) as executor:
                results = executor.map(_fetch_data, name_info_pairs)

            return {name: data for (name, data) in results}

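# ----------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the package diff above.
# `_fetch_exec_async_artifacts` fans artifact downloads out over a shared
# requests.Session with a small thread pool. The same pattern in isolation
# (URL/name mapping here is hypothetical):
# ----------------------------------------------------------------------
from concurrent.futures import ThreadPoolExecutor
import requests

def fetch_all(urls: dict[str, str], max_workers: int = 5) -> dict[str, bytes]:
    """Download {name: url} concurrently and return {name: body}."""
    with requests.Session() as session:
        def fetch_one(item):
            name, url = item
            resp = session.get(url, timeout=30)
            resp.raise_for_status()          # surface HTTP errors per artifact
            return name, resp.content
        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            return dict(pool.map(fetch_one, urls.items()))
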
    def _maybe_decrypt(self, content: bytes, encryption_material: str) -> bytes:
        # Decrypt if encryption material is present
        if encryption_material:
            # if there's no padding, the initial file was empty
            if len(content) == 0:
                return b""

            return self._decrypt_artifact(content, encryption_material)

        # otherwise, return content directly
        return content

    def _parse_exec_async_results(self, arrow_files: List[Tuple[str, bytes]]):
        """Mimics the logic in _parse_arrow_results of railib/api.py#L303 without requiring a wrapping multipart form."""
        results = []

        for file_name, file_content in arrow_files:
            with pa.ipc.open_stream(file_content) as reader:
                schema = reader.schema
                batches = [batch for batch in reader]
                table = pa.Table.from_batches(batches=batches, schema=schema)
                results.append({"relationId": file_name, "table": table})

        return results

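# ----------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the package diff above.
# `_parse_exec_async_results` treats each artifact as an Arrow IPC stream.
# Reading such a byte payload back into a pyarrow Table in isolation:
# ----------------------------------------------------------------------
import pyarrow as pa

def table_from_ipc_bytes(payload: bytes) -> pa.Table:
    """Decode an Arrow IPC stream (bytes) into a single Table."""
    with pa.ipc.open_stream(payload) as reader:
        batches = [batch for batch in reader]             # RecordBatch objects
        return pa.Table.from_batches(batches, schema=reader.schema)

# Round-trip check: serialize a small table, then read it back.
# sink = pa.BufferOutputStream()
# with pa.ipc.new_stream(sink, pa.table({"x": [1, 2, 3]}).schema) as writer:
#     writer.write_table(pa.table({"x": [1, 2, 3]}))
# assert table_from_ipc_bytes(sink.getvalue().to_pybytes()).num_rows == 3
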
    def _download_results(
        self, artifact_info: Dict[str, Dict], txn_id: str, state: str
    ) -> TransactionAsyncResponse:
        with debugging.span("download_results"):
            # Fetch artifacts
            artifacts = self._fetch_exec_async_artifacts(artifact_info)

            # Directly use meta_json as it is fetched
            meta_json_bytes = artifacts["metadata.json"]

            # Decode the bytes and parse the JSON
            meta_json_str = meta_json_bytes.decode('utf-8')
            meta_json = json.loads(meta_json_str)  # Parse the JSON string

            # Use the metadata to map arrow files to the relations they contain
            try:
                arrow_files_to_relations = {
                    artifact["filename"]: artifact["relationId"]
                    for artifact in meta_json
                }
            except KeyError:
                # TODO: Remove this fallback mechanism later once several engine versions are updated
                arrow_files_to_relations = {
                    f"{ix}.arrow": artifact["relationId"]
                    for ix, artifact in enumerate(meta_json)
                }

            # Hydrate the arrow files into tables
            results = self._parse_exec_async_results(
                [
                    (arrow_files_to_relations[name], content)
                    for name, content in artifacts.items()
                    if name.endswith(".arrow")
                ]
            )

            # Create and return the response
            rsp = TransactionAsyncResponse()
            rsp.transaction = {
                "id": txn_id,
                "state": state,
                "response_format_version": None,
            }
            rsp.metadata = meta_json
            rsp.problems = artifacts.get(
                "problems.json"
            )  # Safely access possible missing keys
            rsp.results = results
            return rsp

    def get_transaction_problems(self, txn_id: str) -> List[Dict[str, Any]]:
        with debugging.span("get_own_transaction_problems"):
            response = self._exec(
                f"select * from table({APP_NAME}.api.get_own_transaction_problems('{txn_id}'));"
            )
        if not response:
            return []
        return response

    def get_url_key(self, metadata) -> str:
        # In Azure, there is only one type of URL, which is used for both internal and
        # external access; always use that one
        if self.is_azure(metadata['PRESIGNED_URL']):
            return 'PRESIGNED_URL'

        configured = self.config.get("download_url_type", None)
        if configured == "internal":
            return 'PRESIGNED_URL_AP'
        elif configured == "external":
            return "PRESIGNED_URL"

        if self.is_container_runtime():
            return 'PRESIGNED_URL_AP'

        return 'PRESIGNED_URL'

    def is_azure(self, url) -> bool:
        return "blob.core.windows.net" in url

    def is_container_runtime(self) -> bool:
        return isinstance(runtime_env, SnowbookEnvironment) and runtime_env.runner == "container"

    def _exec_rai_app(
        self,
        database: str,
        engine: str | None,
        raw_code: str,
        inputs: Dict,
        readonly=True,
        nowait_durable=False,
        request_headers: Dict | None = None,
        bypass_index=False,
        language: str = "rel",
        query_timeout_mins: int | None = None,
    ):
        assert language == "rel" or language == "lqp", "Only 'rel' and 'lqp' languages are supported"
        if query_timeout_mins is None and (timeout_value := self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS)) is not None:
            query_timeout_mins = int(timeout_value)
        # Depending on the shape of the input, the behavior of exec_async_v2 changes.
        # When using the new format (with an object), the function retrieves the
        # 'rai' database by hashing the model and username. In contrast, the
        # current version directly uses the passed database value.
        # Therefore, we must use the original exec_async_v2 when not using the
        # graph index to ensure the correct database is utilized.
        use_graph_index = self.config.get("use_graph_index", USE_GRAPH_INDEX)
        if use_graph_index and not bypass_index:
            payload = {
                'database': database,
                'engine': engine,
                'inputs': inputs,
                'readonly': readonly,
                'nowait_durable': nowait_durable,
                'language': language,
                'headers': request_headers
            }
            if query_timeout_mins is not None:
                payload["timeout_mins"] = query_timeout_mins
            sql_string = f"CALL {APP_NAME}.api.exec_async_v2(?, {payload});"
        else:
            if query_timeout_mins is not None:
                sql_string = f"CALL {APP_NAME}.api.exec_async_v2('{database}','{engine}', ?, {inputs}, {readonly}, {nowait_durable}, '{language}', {query_timeout_mins}, {request_headers});"
            else:
                sql_string = f"CALL {APP_NAME}.api.exec_async_v2('{database}','{engine}', ?, {inputs}, {readonly}, {nowait_durable}, '{language}', {request_headers});"
        # Don't let exec setup GI on failure, exec_raw and exec_lqp will do that and add the correct headers.
        response = self._exec(
            sql_string,
            raw_code,
            skip_engine_db_error_retry=True,
        )
        if not response:
            raise Exception("Failed to create transaction")
        return response

    def _exec_async_v2(
        self,
        database: str,
        engine: str | None,
        raw_code: str,
        inputs: Dict | None = None,
        readonly=True,
        nowait_durable=False,
        headers: Dict | None = None,
        bypass_index=False,
        language: str = "rel",
        query_timeout_mins: int | None = None,
        gi_setup_skipped: bool = False,
    ):
        if inputs is None:
            inputs = {}
        request_headers = debugging.add_current_propagation_headers(headers)
        query_attrs_dict = json.loads(request_headers.get("X-Query-Attributes", "{}"))

        with debugging.span("transaction", **query_attrs_dict) as txn_span:
            with debugging.span("create_v2", **query_attrs_dict) as create_span:
                request_headers['user-agent'] = get_pyrel_version(self.generation)
                request_headers['gi_setup_skipped'] = str(gi_setup_skipped)
                request_headers['pyrel_program_id'] = debugging.get_program_span_id() or ""
                response = self._exec_rai_app(
                    database=database,
                    engine=engine,
                    raw_code=raw_code,
                    inputs=inputs,
                    readonly=readonly,
                    nowait_durable=nowait_durable,
                    request_headers=request_headers,
                    bypass_index=bypass_index,
                    language=language,
                    query_timeout_mins=query_timeout_mins,
                )

                artifact_info = {}
                rows = list(iter(response))

                # process the first row since txn_id and state are the same for all rows
                first_row = rows[0]
                txn_id = first_row['ID']
                state = first_row['STATE']
                filename = first_row['FILENAME']

                txn_span["txn_id"] = txn_id
                create_span["txn_id"] = txn_id
                debugging.event("transaction_created", txn_span, txn_id=txn_id)

                # fast path: transaction already finished
                if state in ["COMPLETED", "ABORTED"]:
                    if txn_id in self._pending_transactions:
                        self._pending_transactions.remove(txn_id)

                    # Process rows to get the rest of the artifacts
                    for row in rows:
                        filename = row['FILENAME']
                        artifact_info[filename] = row

                # Slow path: transaction not done yet; start polling
                else:
                    self._pending_transactions.append(txn_id)
                    with debugging.span("wait", txn_id=txn_id):
                        poll_with_specified_overhead(
                            lambda: self._check_exec_async_status(txn_id, headers=request_headers), 0.1
                        )
                    artifact_info = self._list_exec_async_artifacts(txn_id, headers=request_headers)

            with debugging.span("fetch"):
                return self._download_results(artifact_info, txn_id, state)

    def get_user_based_engine_name(self):
        if not self._session:
            self._session = self.get_sf_session()
        user_table = self._session.sql("select current_user()").collect()
        user = user_table[0][0]
        assert isinstance(user, str), f"current_user() must return a string, not {type(user)}"
        return _sanitize_user_name(user)

    def is_engine_ready(self, engine_name: str):
        engine = self.get_engine(engine_name)
        return engine and engine["state"] == "READY"

    def auto_create_engine(self, name: str | None = None, size: str | None = None, headers: Dict | None = None):
        from v0.relationalai.tools.cli_helpers import validate_engine_name
        with debugging.span("auto_create_engine", active=self._active_engine) as span:
            active = self._get_active_engine()
            if active:
                return active

            engine_name = name or self.get_default_engine_name()

            # Use the provided size or fall back to the config
            if size:
                engine_size = size
            else:
                engine_size = self.config.get("engine_size", None)

            # Validate engine size
            if engine_size:
                is_size_valid, sizes = self.validate_engine_size(engine_size)
                if not is_size_valid:
                    raise Exception(f"Invalid engine size '{engine_size}'. Valid sizes are: {', '.join(sizes)}")

            # Validate engine name
            is_name_valid, _ = validate_engine_name(engine_name)
            if not is_name_valid:
                raise EngineNameValidationException(engine_name)

            try:
                engine = self.get_engine(engine_name)
                if engine:
                    span.update(cast(dict, engine))

                # if engine is in the pending state, poll until its status changes
                # if engine is gone, delete it and create new one
                # if engine is in the ready state, return engine name
                if engine:
                    if engine["state"] == "PENDING":
                        # if the user explicitly specified a size, warn if the pending engine size doesn't match it
                        if size is not None and engine["size"] != size:
                            EngineSizeMismatchWarning(engine_name, engine["size"], size)
                        # poll until engine is ready
                        with Spinner(
                            "Waiting for engine to be initialized",
                            "Engine ready",
                        ):
                            poll_with_specified_overhead(lambda: self.is_engine_ready(engine_name), overhead_rate=0.1, max_delay=0.5, timeout=900)

                    elif engine["state"] == "SUSPENDED":
                        with Spinner(f"Resuming engine '{engine_name}'", f"Engine '{engine_name}' resumed", f"Failed to resume engine '{engine_name}'"):
                            try:
                                self.resume_engine_async(engine_name, headers=headers)
                                poll_with_specified_overhead(lambda: self.is_engine_ready(engine_name), overhead_rate=0.1, max_delay=0.5, timeout=900)
                            except Exception:
                                raise EngineResumeFailed(engine_name)
                    elif engine["state"] == "READY":
                        # if the user explicitly specified a size, warn if the ready engine size doesn't match it
                        if size is not None and engine["size"] != size:
                            EngineSizeMismatchWarning(engine_name, engine["size"], size)
                        self._set_active_engine(engine)
                        return engine_name
                    elif engine["state"] == "GONE":
                        try:
                            # "Gone" is abnormal condition when metadata and SF service don't match
                            # Therefore, we have to delete the engine and create a new one
                            # it could be case that engine is already deleted, so we have to catch the exception
                            self.delete_engine(engine_name, headers=headers)
                            # After deleting the engine, set it to None so that we can create a new engine
                            engine = None
                        except Exception as e:
                            # if engine is already deleted, we will get an exception
                            # we can ignore this exception and create a new engine
                            if isinstance(e, EngineNotFoundException):
                                engine = None
                                pass
                            else:
                                raise EngineProvisioningFailed(engine_name, e) from e

                if not engine:
                    with Spinner(
                        f"Auto-creating engine {engine_name}",
                        f"Auto-created engine {engine_name}",
                        "Engine creation failed",
                    ):
                        self.create_engine(engine_name, size=engine_size, headers=headers)
            except Exception as e:
                print(e)
                if DUO_TEXT in str(e).lower():
                    raise DuoSecurityFailed(e)
                raise EngineProvisioningFailed(engine_name, e) from e
            return engine_name

    def auto_create_engine_async(self, name: str | None = None):
        active = self._get_active_engine()
        if active and (active == name or name is None):
            return # @NOTE: This method weirdly doesn't return engine name even though all the other ones do?

        with Spinner(
            "Checking engine status",
            leading_newline=True,
        ) as spinner:
            from v0.relationalai.tools.cli_helpers import validate_engine_name
            with debugging.span("auto_create_engine_async", active=self._active_engine):
                engine_name = name or self.get_default_engine_name()
                engine_size = self.config.get("engine_size", None)
                if engine_size:
                    is_size_valid, sizes = self.validate_engine_size(engine_size)
                    if not is_size_valid:
                        raise Exception(f"Invalid engine size in config: '{engine_size}'. Valid sizes are: {', '.join(sizes)}")
                else:
                    engine_size = self.config.get_default_engine_size()

                is_name_valid, _ = validate_engine_name(engine_name)
                if not is_name_valid:
                    raise EngineNameValidationException(engine_name)
                try:
                    engine = self.get_engine(engine_name)
                    # if engine is gone, delete it and create new one
                    # in case of pending state, do nothing, it is use_index responsibility to wait for engine to be ready
                    if engine:
                        if engine["state"] == "PENDING":
                            spinner.update_messages(
                                {
                                    "finished_message": f"Starting engine {engine_name}",
                                }
                            )
                            pass
                        elif engine["state"] == "SUSPENDED":
                            spinner.update_messages(
                                {
                                    "finished_message": f"Resuming engine {engine_name}",
                                }
                            )
                            try:
                                self.resume_engine_async(engine_name)
                            except Exception:
                                raise EngineResumeFailed(engine_name)
                        elif engine["state"] == "READY":
                            spinner.update_messages(
                                {
                                    "finished_message": f"Engine {engine_name} initialized",
                                }
                            )
                            pass
                        elif engine["state"] == "GONE":
                            spinner.update_messages(
                                {
                                    "message": f"Restarting engine {engine_name}",
                                }
                            )
                            try:
                                # "Gone" is abnormal condition when metadata and SF service don't match
                                # Therefore, we have to delete the engine and create a new one
                                # it could be case that engine is already deleted, so we have to catch the exception
                                # set it to None so that we can create a new engine
                                engine = None
                                self.delete_engine(engine_name)
                            except Exception as e:
                                # if engine is already deleted, we will get an exception
                                # we can ignore this exception and create a new engine asynchronously
                                if isinstance(e, EngineNotFoundException):
                                    engine = None
                                    pass
                                else:
                                    print(e)
                                    raise EngineProvisioningFailed(engine_name, e) from e

                    if not engine:
                        self.create_engine_async(engine_name, size=self.config.get("engine_size", None))
                        spinner.update_messages(
                            {
                                "finished_message": f"Starting engine {engine_name}...",
                            }
                        )
                    else:
                        self._set_active_engine(engine)

                except Exception as e:
                    spinner.update_messages(
                        {
                            "finished_message": f"Failed to create engine {engine_name}",
                        }
                    )
                    if DUO_TEXT in str(e).lower():
                        raise DuoSecurityFailed(e)
                    if isinstance(e, RAIException):
                        raise e
                    print(e)
                    raise EngineProvisioningFailed(engine_name, e) from e

    def validate_engine_size(self, size: str) -> Tuple[bool, List[str]]:
        if size is not None:
            sizes = self.get_engine_sizes()
            if size not in sizes:
                return False, sizes
        return True, []

    #--------------------------------------------------
    # Exec
    #--------------------------------------------------

    def _exec_with_gi_retry(
        self,
        database: str,
        engine: str | None,
        raw_code: str,
        inputs: Dict | None,
        readonly: bool,
        nowait_durable: bool,
        headers: Dict | None,
        bypass_index: bool,
        language: str,
        query_timeout_mins: int | None,
    ):
        """Execute with graph index retry logic.

        Attempts execution with gi_setup_skipped=True first. If an engine or database
        issue occurs, polls use_index and retries with gi_setup_skipped=False.
        """
        try:
            return self._exec_async_v2(
                database, engine, raw_code, inputs, readonly, nowait_durable,
                headers=headers, bypass_index=bypass_index, language=language,
                query_timeout_mins=query_timeout_mins, gi_setup_skipped=True,
            )
        except Exception as e:
            err_message = str(e).lower()
            if _is_engine_issue(err_message) or _is_database_issue(err_message):
                engine_name = engine or self.get_default_engine_name()
                engine_size = self.config.get_default_engine_size()
                self._poll_use_index(
                    app_name=self.get_app_name(),
                    sources=self.sources,
                    model=database,
                    engine_name=engine_name,
                    engine_size=engine_size,
                    headers=headers,
                )

                return self._exec_async_v2(
                    database, engine, raw_code, inputs, readonly, nowait_durable,
                    headers=headers, bypass_index=bypass_index, language=language,
                    query_timeout_mins=query_timeout_mins, gi_setup_skipped=False,
                )
            else:
                raise e

    def exec_lqp(
        self,
        database: str,
        engine: str | None,
        raw_code: bytes,
        readonly=True,
        *,
        inputs: Dict | None = None,
        nowait_durable=False,
        headers: Dict | None = None,
        bypass_index=False,
        query_timeout_mins: int | None = None,
    ):
        raw_code_b64 = base64.b64encode(raw_code).decode("utf-8")
        return self._exec_with_gi_retry(
            database, engine, raw_code_b64, inputs, readonly, nowait_durable,
            headers, bypass_index, 'lqp', query_timeout_mins
        )


    def exec_raw(
        self,
        database: str,
        engine: str | None,
        raw_code: str,
        readonly=True,
        *,
        inputs: Dict | None = None,
        nowait_durable=False,
        headers: Dict | None = None,
        bypass_index=False,
        query_timeout_mins: int | None = None,
    ):
        raw_code = raw_code.replace("'", "\\'")
        return self._exec_with_gi_retry(
            database, engine, raw_code, inputs, readonly, nowait_durable,
            headers, bypass_index, 'rel', query_timeout_mins
        )


    def format_results(self, results, task:m.Task|None=None) -> Tuple[DataFrame, List[Any]]:
        return result_helpers.format_results(results, task)

    #--------------------------------------------------
    # Exec format
    #--------------------------------------------------

    def exec_format(
        self,
        database: str,
        engine: str,
        raw_code: str,
        cols: List[str],
        format: str,
        inputs: Dict | None = None,
        readonly=True,
        nowait_durable=False,
        skip_invalid_data=False,
        headers: Dict | None = None,
        query_timeout_mins: int | None = None,
    ):
        if inputs is None:
            inputs = {}
        if headers is None:
            headers = {}
        if 'user-agent' not in headers:
            headers['user-agent'] = get_pyrel_version(self.generation)
        if query_timeout_mins is None and (timeout_value := self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS)) is not None:
            query_timeout_mins = int(timeout_value)
        # TODO: add headers
        start = time.perf_counter()
        output_table = "out" + str(uuid.uuid4()).replace("-", "_")
        temp_table = f"temp_{output_table}"
        use_graph_index = self.config.get("use_graph_index", USE_GRAPH_INDEX)
        txn_id = None
        rejected_rows = None
        col_names_map = None
        artifacts = None
        assert self._session
        temp = self._session.createDataFrame([], StructType([StructField(name, StringType()) for name in cols]))
        with debugging.span("transaction") as txn_span:
            try:
                # In the graph index case we need to use the new exec_into_table proc as it obfuscates the db name
                with debugging.span("exec_format"):
                    if use_graph_index:
                        # we do not provide a default value for query_timeout_mins so that we can control the default on app level
                        if query_timeout_mins is not None:
                            res = self._exec(f"call {APP_NAME}.api.exec_into_table(?, ?, ?, ?, ?, NULL, ?, {headers}, ?, ?);", [database, engine, raw_code, output_table, readonly, nowait_durable, skip_invalid_data, query_timeout_mins])
                        else:
                            res = self._exec(f"call {APP_NAME}.api.exec_into_table(?, ?, ?, ?, ?, NULL, ?, {headers}, ?);", [database, engine, raw_code, output_table, readonly, nowait_durable, skip_invalid_data])
                        txn_id = json.loads(res[0]["EXEC_INTO_TABLE"])["rai_transaction_id"]
                        rejected_rows = json.loads(res[0]["EXEC_INTO_TABLE"]).get("rejected_rows", [])
                        rejected_rows_count = json.loads(res[0]["EXEC_INTO_TABLE"]).get("rejected_rows_count", 0)
                    else:
                        if query_timeout_mins is not None:
                            res = self._exec(f"call {APP_NAME}.api.exec_into(?, ?, ?, ?, ?, {inputs}, ?, {headers}, ?, ?);", [database, engine, raw_code, output_table, readonly, nowait_durable, skip_invalid_data, query_timeout_mins])
                        else:
                            res = self._exec(f"call {APP_NAME}.api.exec_into(?, ?, ?, ?, ?, {inputs}, ?, {headers}, ?);", [database, engine, raw_code, output_table, readonly, nowait_durable, skip_invalid_data])
                        txn_id = json.loads(res[0]["EXEC_INTO"])["rai_transaction_id"]
                        rejected_rows = json.loads(res[0]["EXEC_INTO"]).get("rejected_rows", [])
                        rejected_rows_count = json.loads(res[0]["EXEC_INTO"]).get("rejected_rows_count", 0)
                    debugging.event("transaction_created", txn_span, txn_id=txn_id)
                    debugging.time("exec_format", time.perf_counter() - start, DataFrame())

                with debugging.span("temp_table_swap", txn_id=txn_id):
                    out_sample = self._exec(f"select * from {APP_NAME}.results.{output_table} limit 1;")
                    if out_sample:
                        keys = set([k.lower() for k in out_sample[0].as_dict().keys()])
                        col_names_map = {}
                        for ix, name in enumerate(cols):
                            col_key = f"col{ix:03}"
                            if col_key in keys:
                                col_names_map[col_key] = IdentityParser(name).identity
                            else:
                                col_names_map[col_key] = name

                        names = ", ".join([
                            f"{col_key} as {alias}" if col_key in keys else f"NULL as {alias}"
                            for col_key, alias in col_names_map.items()
                        ])
                        self._exec(f"CREATE TEMPORARY TABLE {APP_NAME}.results.{temp_table} AS SELECT {names} FROM {APP_NAME}.results.{output_table};")
                        self._exec(f"call {APP_NAME}.api.drop_result_table(?)", [output_table])
                        temp = cast(snowflake.snowpark.DataFrame, self._exec(f"select * from {APP_NAME}.results.{temp_table}", raw=True))
                    if rejected_rows:
                        debugging.warn(RowsDroppedFromTargetTableWarning(rejected_rows, rejected_rows_count, col_names_map))
            except Exception as e:
                msg = str(e).lower()
                if "no columns returned" in msg or "columns of results could not be determined" in msg:
                    pass
                else:
                    raise e
        if txn_id:
            artifact_info = self._list_exec_async_artifacts(txn_id)
            with debugging.span("fetch"):
                artifacts = self._download_results(artifact_info, txn_id, "ABORTED")
        return (temp, artifacts)

    #--------------------------------------------------
    # Custom model types
    #--------------------------------------------------

    def _get_ns(self, model:dsl.Graph):
        if model not in self._ns_cache:
            self._ns_cache[model] = _Snowflake(model)
        return self._ns_cache[model]

    def to_model_type(self, model:dsl.Graph, name: str, source:str):
        parser = IdentityParser(source)
        if not parser.is_complete:
            raise SnowflakeInvalidSource(Errors.call_source(), source)
        ns = self._get_ns(model)
        # skip the last item in the list (the full identifier)
        for part in parser.to_list()[:-1]:
            ns = ns._safe_get(part)
        assert parser.identity, f"Error parsing source in to_model_type: {source}"
        self.sources.add(parser.identity)
        return ns

    def _check_source_updates(self, sources: Iterable[str]):
        if not sources:
            return {}
        app_name = self.get_app_name()

        source_types = dict[str, SourceInfo]()
        partitioned_sources: dict[str, dict[str, list[dict[str, str]]]] = defaultdict(
            lambda: defaultdict(list)
        )
        fqn_to_parts: dict[str, tuple[str, str, str]] = {}

        for source in sources:
            parser = IdentityParser(source, True)
            parsed = parser.to_list()
            assert len(parsed) == 4, f"Invalid source: {source}"
            db, schema, entity, identity = parsed
            assert db and schema and entity and identity, f"Invalid source: {source}"
            source_types[identity] = cast(
                SourceInfo,
                {
                    "type": None,
                    "state": "",
                    "columns_hash": None,
                    "table_created_at": None,
                    "stream_created_at": None,
                    "last_ddl": None,
                },
            )
            partitioned_sources[db][schema].append({"entity": entity, "identity": identity})
            fqn_to_parts[identity] = (db, schema, entity)

        if not partitioned_sources:
            return source_types

        state_queries: list[str] = []
        for db, schemas in partitioned_sources.items():
            select_rows: list[str] = []
            for schema, tables in schemas.items():
                for table_info in tables:
                    select_rows.append(
                        "SELECT "
                        f"{IdentityParser.to_sql_value(db)} AS catalog_name, "
                        f"{IdentityParser.to_sql_value(schema)} AS schema_name, "
                        f"{IdentityParser.to_sql_value(table_info['entity'])} AS table_name"
                    )

            if not select_rows:
                continue

            target_entities_clause = "\n  UNION ALL\n  ".join(select_rows)
            # Main query:
            # 1. Enumerate the target tables via target_entities.
            # 2. Pull their metadata (last_altered, type) from INFORMATION_SCHEMA.TABLES.
            # 3. Look up the most recent stream activity for those FQNs only.
            # 4. Capture creation timestamps and use last_ddl vs created_at to classify each target,
            #    so we mark tables as stale when they were recreated even if column hashes still match.
            state_queries.append(
                f"""WITH target_entities AS (
  {target_entities_clause}
),
table_info AS (
  SELECT
    {app_name}.api.normalize_fq_ids(
      ARRAY_CONSTRUCT(
        CASE
          WHEN t.table_catalog = UPPER(t.table_catalog) THEN t.table_catalog
          ELSE '"' || t.table_catalog || '"'
        END || '.' ||
        CASE
          WHEN t.table_schema = UPPER(t.table_schema) THEN t.table_schema
          ELSE '"' || t.table_schema || '"'
        END || '.' ||
        CASE
          WHEN t.table_name = UPPER(t.table_name) THEN t.table_name
          ELSE '"' || t.table_name || '"'
        END
      )
    )[0]:identifier::string AS fqn,
    CONVERT_TIMEZONE('UTC', t.last_altered) AS last_ddl,
    CONVERT_TIMEZONE('UTC', t.created) AS table_created_at,
    t.table_type AS kind
  FROM {db}.INFORMATION_SCHEMA.tables t
  JOIN target_entities te
    ON t.table_catalog = te.catalog_name
    AND t.table_schema = te.schema_name
    AND t.table_name = te.table_name
),
stream_activity AS (
  SELECT
    sa.fqn,
    MAX(sa.created_at) AS created_at
  FROM (
    SELECT
      {app_name}.api.normalize_fq_ids(ARRAY_CONSTRUCT(fq_object_name))[0]:identifier::string AS fqn,
      created_at
    FROM {app_name}.api.data_streams
    WHERE rai_database = '{PYREL_ROOT_DB}'
  ) sa
  JOIN table_info ti
    ON sa.fqn = ti.fqn
  GROUP BY sa.fqn
)
SELECT
  ti.fqn,
  ti.kind,
  ti.last_ddl,
  ti.table_created_at,
  sa.created_at AS stream_created_at,
  IFF(
    DATEDIFF(second, sa.created_at::timestamp, ti.last_ddl::timestamp) > 0,
    'STALE',
    'CURRENT'
  ) AS state
FROM table_info ti
LEFT JOIN stream_activity sa
  ON sa.fqn = ti.fqn
"""
            )

        stale_fqns: list[str] = []
        for state_query in state_queries:
            for row in self._exec(state_query):
                row_dict = row.as_dict() if hasattr(row, "as_dict") else dict(row)
                row_fqn = row_dict["FQN"]
                parser = IdentityParser(row_fqn, True)
                fqn = parser.identity
                assert fqn, f"Error parsing returned FQN: {row_fqn}"

                source_types[fqn]["type"] = (
                    "TABLE" if row_dict["KIND"] == "BASE TABLE" else row_dict["KIND"]
                )
                source_types[fqn]["state"] = row_dict["STATE"]
                source_types[fqn]["last_ddl"] = normalize_datetime(row_dict.get("LAST_DDL"))
                source_types[fqn]["table_created_at"] = normalize_datetime(row_dict.get("TABLE_CREATED_AT"))
                source_types[fqn]["stream_created_at"] = normalize_datetime(row_dict.get("STREAM_CREATED_AT"))
                if row_dict["STATE"] == "STALE":
                    stale_fqns.append(fqn)

        if not stale_fqns:
            return source_types

        # We batch stale tables by database/schema so each Snowflake query can hash
        # multiple objects at once instead of issuing one statement per table.
        stale_partitioned: dict[str, dict[str, list[dict[str, str]]]] = defaultdict(
            lambda: defaultdict(list)
        )
        for fqn in stale_fqns:
            db, schema, table = fqn_to_parts[fqn]
            stale_partitioned[db][schema].append({"table": table, "identity": fqn})

        # Build one hash query per database, grouping schemas/tables inside so we submit
        # at most a handful of set-based statements to Snowflake.
        for db, schemas in stale_partitioned.items():
            column_select_rows: list[str] = []
            for schema, tables in schemas.items():
                for table_info in tables:
                    # Build the literal rows for this db/schema so we can join back
                    # against INFORMATION_SCHEMA.COLUMNS in a single statement.
                    column_select_rows.append(
                        "SELECT "
                        f"{IdentityParser.to_sql_value(db)} AS catalog_name, "
                        f"{IdentityParser.to_sql_value(schema)} AS schema_name, "
                        f"{IdentityParser.to_sql_value(table_info['table'])} AS table_name"
                    )

            if not column_select_rows:
                continue

            target_entities_clause = "\n  UNION ALL\n  ".join(column_select_rows)
            # Main query: compute deterministic column hashes for every stale table
            # in this database/schema batch so we can compare schemas without a round trip per table.
            column_query = f"""WITH target_entities AS (
  {target_entities_clause}
),
column_info AS (
  SELECT
    {app_name}.api.normalize_fq_ids(
      ARRAY_CONSTRUCT(
        CASE
          WHEN c.table_catalog = UPPER(c.table_catalog) THEN c.table_catalog
          ELSE '"' || c.table_catalog || '"'
        END || '.' ||
        CASE
          WHEN c.table_schema = UPPER(c.table_schema) THEN c.table_schema
          ELSE '"' || c.table_schema || '"'
        END || '.' ||
        CASE
          WHEN c.table_name = UPPER(c.table_name) THEN c.table_name
          ELSE '"' || c.table_name || '"'
        END
      )
    )[0]:identifier::string AS fqn,
    c.column_name,
    CASE
      WHEN c.numeric_precision IS NOT NULL AND c.numeric_scale IS NOT NULL
        THEN c.data_type || '(' || c.numeric_precision || ',' || c.numeric_scale || ')'
      WHEN c.datetime_precision IS NOT NULL
        THEN c.data_type || '(0,' || c.datetime_precision || ')'
      WHEN c.character_maximum_length IS NOT NULL
        THEN c.data_type || '(' || c.character_maximum_length || ')'
      ELSE c.data_type
    END AS type_signature,
    IFF(c.is_nullable = 'YES', 'YES', 'NO') AS nullable_flag
  FROM {db}.INFORMATION_SCHEMA.COLUMNS c
  JOIN target_entities te
    ON c.table_catalog = te.catalog_name
    AND c.table_schema = te.schema_name
    AND c.table_name = te.table_name
)
SELECT
  fqn,
  HEX_ENCODE(
    HASH_AGG(
      HASH(
        column_name,
        type_signature,
        nullable_flag
      )
    )
  ) AS columns_hash
FROM column_info
GROUP BY fqn
"""

            for row in self._exec(column_query):
                row_fqn = row["FQN"]
                parser = IdentityParser(row_fqn, True)
                fqn = parser.identity
                assert fqn, f"Error parsing returned FQN: {row_fqn}"
                source_types[fqn]["columns_hash"] = row["COLUMNS_HASH"]

        return source_types

    def _get_source_references(self, source_info: dict[str, SourceInfo]):
        app_name = self.get_app_name()
        missing_sources = []
        invalid_sources = {}
        source_references = []
        for source, info in source_info.items():
            source_type = info.get("type")
            if source_type is None:
                missing_sources.append(source)
            elif source_type not in ("TABLE", "VIEW"):
                invalid_sources[source] = source_type
            else:
                source_references.append(f"{app_name}.api.object_reference('{source_type}', '{source}')")

        if missing_sources:
            current_role = self.get_sf_session().get_current_role()
            if current_role is None:
                current_role = self.config.get("role", None)
            debugging.warn(UnknownSourceWarning(missing_sources, current_role))

        if invalid_sources:
            debugging.warn(InvalidSourceTypeWarning(invalid_sources))

        self.source_references = source_references
        return source_references

    #--------------------------------------------------
    # Transactions
    #--------------------------------------------------
    def txn_list_to_dicts(self, transactions):
        dicts = []
        for txn in transactions:
            dict = {}
            txn_dict = txn.asDict()
            for key in txn_dict:
                mapValue = FIELD_MAP.get(key.lower())
                if mapValue:
                    dict[mapValue] = txn_dict[key]
                else:
                    dict[key.lower()] = txn_dict[key]
            dicts.append(dict)
        return dicts

    def get_transaction(self, transaction_id):
        results = self._exec(
            f"CALL {APP_NAME}.api.get_transaction(?);", [transaction_id])
        if not results:
            return None

        results = self.txn_list_to_dicts(results)

        txn = {field: results[0][field] for field in GET_TXN_SQL_FIELDS}

        state = txn.get("state")
        created_on = txn.get("created_on")
        finished_at = txn.get("finished_at")
        if created_on:
            # Transaction is still running
            if state not in TERMINAL_TXN_STATES:
                tz_info = created_on.tzinfo
                txn['duration'] = datetime.now(tz_info) - created_on
            # Transaction is terminal
            elif finished_at:
                txn['duration'] = finished_at - created_on
            # Transaction is still running and we have no state or finished_at
            else:
                txn['duration'] = timedelta(0)
        return txn

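# ----------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the package diff above.
# `get_transaction` derives `duration` by subtracting a timezone-aware
# `created_on` from `datetime.now(...)` constructed with the same tzinfo;
# mixing a naive and an aware datetime would raise a TypeError. The same
# idea in isolation (names here are made up for the example):
# ----------------------------------------------------------------------
from datetime import datetime, timedelta, timezone

def running_duration(created_on: datetime, finished_at: datetime | None = None) -> timedelta:
    """Duration of a finished or still-running transaction."""
    if finished_at is not None:
        return finished_at - created_on
    # reuse created_on's tzinfo so both operands are aware (or both naive)
    return datetime.now(created_on.tzinfo) - created_on

# running_duration(datetime.now(timezone.utc) - timedelta(seconds=5))  # ~5 seconds
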
    def list_transactions(self, **kwargs):
        id = kwargs.get("id", None)
        state = kwargs.get("state", None)
        engine = kwargs.get("engine", None)
        limit = kwargs.get("limit", 100)
        all_users = kwargs.get("all_users", False)
        created_by = kwargs.get("created_by", None)
        only_active = kwargs.get("only_active", False)
        where_clause_arr = []

        if id:
            where_clause_arr.append(f"id = '{id}'")
        if state:
            where_clause_arr.append(f"state = '{state.upper()}'")
        if engine:
            where_clause_arr.append(f"LOWER(engine_name) = '{engine.lower()}'")
        else:
            if only_active:
                where_clause_arr.append("state in ('CREATED', 'RUNNING', 'PENDING')")
            if not all_users and created_by is not None:
                where_clause_arr.append(f"LOWER(created_by) = '{created_by.lower()}'")

        if len(where_clause_arr):
            where_clause = f'WHERE {" AND ".join(where_clause_arr)}'
        else:
            where_clause = ""

        sql_fields = ", ".join(LIST_TXN_SQL_FIELDS)
        query = f"SELECT {sql_fields} from {APP_NAME}.api.transactions {where_clause} ORDER BY created_on DESC LIMIT ?"
        results = self._exec(query, [limit])
        if not results:
            return []
        return self.txn_list_to_dicts(results)

    def cancel_transaction(self, transaction_id):
        self._exec(f"CALL {APP_NAME}.api.cancel_own_transaction(?);", [transaction_id])
        if transaction_id in self._pending_transactions:
            self._pending_transactions.remove(transaction_id)

    def cancel_pending_transactions(self):
        for txn_id in self._pending_transactions:
            self.cancel_transaction(txn_id)

    def get_transaction_events(self, transaction_id: str, continuation_token:str=''):
        results = self._exec(
            f"SELECT {APP_NAME}.api.get_own_transaction_events(?, ?);",
            [transaction_id, continuation_token],
        )
        if not results:
            return {
                "events": [],
                "continuation_token": None
            }
        row = results[0][0]
        return json.loads(row)

    #--------------------------------------------------
    # Snowflake specific
    #--------------------------------------------------

    def get_version(self):
        results = self._exec(f"SELECT {APP_NAME}.app.get_release()")
        if not results:
            return None
        return results[0][0]

    def list_warehouses(self):
        results = self._exec("SHOW WAREHOUSES")
        if not results:
            return []
        return [{"name":name}
                for (name, *rest) in results]

    def list_compute_pools(self):
        results = self._exec("SHOW COMPUTE POOLS")
        if not results:
            return []
        return [{"name":name, "status":status, "min_nodes":min_nodes, "max_nodes":max_nodes, "instance_family":instance_family}
                for (name, status, min_nodes, max_nodes, instance_family, *rest) in results]

    def list_roles(self):
        results = self._exec("SELECT CURRENT_AVAILABLE_ROLES()")
        if not results:
            return []
        # the response is a single row with a single column containing
        # a stringified JSON array of role names:
        row = results[0]
        if not row:
            return []
        return [{"name": name} for name in json.loads(row[0])]

    def list_apps(self):
        all_apps = self._exec(f"SHOW APPLICATIONS LIKE '{RAI_APP_NAME}'")
        if not all_apps:
            all_apps = self._exec("SHOW APPLICATIONS")
        if not all_apps:
            return []
        return [{"name":name}
                for (time, name, *rest) in all_apps]

    def list_databases(self):
        results = self._exec("SHOW DATABASES")
        if not results:
            return []
        return [{"name":name}
                for (time, name, *rest) in results]

    def list_sf_schemas(self, database:str):
        results = self._exec(f"SHOW SCHEMAS IN {database}")
        if not results:
            return []
        return [{"name":name}
                for (time, name, *rest) in results]

    def list_tables(self, database:str, schema:str):
        results = self._exec(f"SHOW OBJECTS IN {database}.{schema}")
        items = []
        if results:
            for (time, name, db_name, schema_name, kind, *rest) in results:
                items.append({"name":name, "kind":kind.lower()})
        return items

    def schema_info(self, database:str, schema:str, tables:Iterable[str]):
        app_name = self.get_app_name()
        # Only pass the db + schema as the identifier so that the resulting identity is correct
        parser = IdentityParser(f"{database}.{schema}")

        with debugging.span("schema_info"):
            with debugging.span("primary_keys") as span:
                pk_query = f"SHOW PRIMARY KEYS IN SCHEMA {parser.identity};"
                pks = self._exec(pk_query)
                span["sql"] = pk_query

            with debugging.span("foreign_keys") as span:
                fk_query = f"SHOW IMPORTED KEYS IN SCHEMA {parser.identity};"
                fks = self._exec(fk_query)
                span["sql"] = fk_query

            # IdentityParser will parse a single value (with no ".") and store it in this case in the db field
            with debugging.span("columns") as span:
                tables = ", ".join([f"'{IdentityParser(t).db}'" for t in tables])
                query = textwrap.dedent(f"""
                    begin
                        SHOW COLUMNS IN SCHEMA {parser.identity};
                        let r resultset := (
                            SELECT
                                CASE
                                    WHEN "table_name" = UPPER("table_name") THEN "table_name"
                                    ELSE '"' || "table_name" || '"'
                                END as "table_name",
                                "column_name",
                                "data_type",
                                CASE
                                    WHEN ARRAY_CONTAINS(PARSE_JSON("data_type"):"type", {app_name}.app.get_supported_column_types()) THEN TRUE
                                    ELSE FALSE
                                END as "supported_type"
                            FROM table(result_scan(-1)) as t
                            WHERE "table_name" in ({tables})
                        );
                        return table(r);
                    end;
                    """)
                span["sql"] = query
                columns = self._exec(query)

        results = defaultdict(lambda: {"pks": [], "fks": {}, "columns": {}, "invalid_columns": {}})
        if pks:
            for row in pks:
                results[row[3]]["pks"].append(row[4]) # type: ignore
        if fks:
            for row in fks:
                results[row[7]]["fks"][row[8]] = row[3]
        if columns:
            # It seems that a SF parameter (QUOTED_IDENTIFIERS_IGNORE_CASE) can control
            # whether snowflake will ignore case on `row.data_type`,
            # so we have to use column indexes instead :(
            for row in columns:
                table_name = row[0]
                column_name = row[1]
                data_type = row[2]
                supported_type = row[3]
                # Filter out unsupported types
                if supported_type:
                    results[table_name]["columns"][column_name] = data_type
                else:
                    results[table_name]["invalid_columns"][column_name] = data_type
        return results

    def get_cloud_provider(self) -> str:
        """
        Detect whether this is Snowflake on Azure or AWS, using Snowflake's CURRENT_REGION().
        Returns 'azure' or 'aws'.
        """
        if self._session:
            try:
                # Query Snowflake's current region using the built-in function
                result = self._session.sql("SELECT CURRENT_REGION()").collect()
                if result:
                    region_info = result[0][0]
                    # Check whether the region string contains the cloud provider name
                    if isinstance(region_info, str):
                        region_str = region_info.lower()
                        if 'azure' in region_str:
                            return 'azure'
                        else:
                            return 'aws'
            except Exception:
                pass

        # Fall back to AWS as the default if detection fails
        return 'aws'

#--------------------------------------------------
# Snowflake Wrapper
#--------------------------------------------------

class PrimaryKey:
    pass

class _Snowflake:
    def __init__(self, model, auto_import=False):
        self._model = model
        self._auto_import = auto_import
        if not isinstance(model._client.resources, Resources):
            raise ValueError("Snowflake model must be used with a snowflake config")
        self._dbs = {}
        imports = model._client.resources.list_imports(model=model.name)
        self._import_structure(imports)

    def _import_structure(self, imports: list[Import]):
        tree = self._dbs
        # pre-create existing imports
        schemas = set()
        for item in imports:
            parser = IdentityParser(item["name"])
            database_name, schema_name, table_name = parser.to_list()[:-1]
            database = getattr(self, database_name)
            schema = getattr(database, schema_name)
            schemas.add(schema)
            schema._add(table_name, is_imported=True)
        return tree

    def _safe_get(self, name: str) -> 'SnowflakeDB':
        if name in self._dbs:
            return self._dbs[name]
        self._dbs[name] = SnowflakeDB(self, name)
        return self._dbs[name]

    def __getattr__(self, name: str) -> 'SnowflakeDB':
        return self._safe_get(name)
class Snowflake(_Snowflake):
    def __init__(self, model: dsl.Graph, auto_import=False):
        if model._config.get_bool("use_graph_index", USE_GRAPH_INDEX):
            raise SnowflakeProxySourceError()
        else:
            debugging.warn(SnowflakeProxyAPIDeprecationWarning())

        super().__init__(model, auto_import)

class SnowflakeDB:
    def __init__(self, parent, name):
        self._name = name
        self._parent = parent
        self._model = parent._model
        self._schemas = {}

    def _safe_get(self, name: str) -> 'SnowflakeSchema':
        if name in self._schemas:
            return self._schemas[name]
        self._schemas[name] = SnowflakeSchema(self, name)
        return self._schemas[name]

    def __getattr__(self, name: str) -> 'SnowflakeSchema':
        return self._safe_get(name)

class SnowflakeSchema:
    def __init__(self, parent, name):
        self._name = name
        self._parent = parent
        self._model = parent._model
        self._tables = {}
        self._imported = set()
        self._table_info = defaultdict(lambda: {"pks": [], "fks": {}, "columns": {}, "invalid_columns": {}})
        self._dirty = True

    def _fetch_info(self):
        if not self._dirty:
            return
        self._table_info = self._model._client.resources.schema_info(self._parent._name, self._name, list(self._tables.keys()))

        check_column_types = self._model._config.get("check_column_types", True)

        if check_column_types:
            self._check_and_confirm_invalid_columns()

        self._dirty = False

    def _check_and_confirm_invalid_columns(self):
        """Check for invalid columns across the schema's tables."""
        tables_with_invalid_columns = {}
        for table_name, table_info in self._table_info.items():
            if table_info.get("invalid_columns"):
                tables_with_invalid_columns[table_name] = table_info["invalid_columns"]

        if tables_with_invalid_columns:
            from ..errors import UnsupportedColumnTypesWarning
            UnsupportedColumnTypesWarning(tables_with_invalid_columns)

    def _add(self, name, is_imported=False):
        if name in self._tables:
            return self._tables[name]
        self._dirty = True
        if is_imported:
            self._imported.add(name)
        else:
            self._tables[name] = SnowflakeTable(self, name)
        return self._tables.get(name)

    def _safe_get(self, name: str) -> 'SnowflakeTable | None':
        table = self._add(name)
        return table

    def __getattr__(self, name: str) -> 'SnowflakeTable | None':
        return self._safe_get(name)
class SnowflakeTable(dsl.Type):
    def __init__(self, parent, name):
        super().__init__(parent._model, f"sf_{name}")
        # hack to make this work for pathfinder
        self._type.parents.append(m.Builtins.PQFilterAnnotation)
        self._name = name
        self._model = parent._model
        self._parent = parent
        self._aliases = {}
        self._finalzed = False
        self._source = runtime_env.get_source()
        relation_name = to_fqn_relation_name(self.fqname())
        self._model.install_raw(f"declare {relation_name}")

    def __call__(self, *args, **kwargs):
        self._lazy_init()
        return super().__call__(*args, **kwargs)

    def add(self, *args, **kwargs):
        self._lazy_init()
        return super().add(*args, **kwargs)

    def extend(self, *args, **kwargs):
        self._lazy_init()
        return super().extend(*args, **kwargs)

    def known_properties(self):
        self._lazy_init()
        return super().known_properties()

    def _lazy_init(self):
        if self._finalzed:
            return

        parent = self._parent
        name = self._name
        use_graph_index = self._model._config.get("use_graph_index", USE_GRAPH_INDEX)

        if not use_graph_index and name not in parent._imported:
            if self._parent._parent._parent._auto_import:
                with Spinner(f"Creating stream for {self.fqname()}", f"Stream for {self.fqname()} created successfully"):
                    db_name = parent._parent._name
                    schema_name = parent._name
                    self._model._client.resources.create_import_stream(ImportSourceTable(db_name, schema_name, name), self._model.name)
                print("")
                parent._imported.add(name)
            else:
                imports = self._model._client.resources.list_imports(model=self._model.name)
                for item in imports:
                    cur_name = item["name"].lower().split(".")[-1]
                    parent._imported.add(cur_name)
                if name not in parent._imported:
                    exception = SnowflakeImportMissingException(runtime_env.get_source(), self.fqname(), self._model.name)
                    raise exception from None

        parent._fetch_info()
        self._finalize()
    def _finalize(self):
        if self._finalzed:
            return

        self._finalzed = True
        self._schema = self._parent._table_info[self._name]

        # Set the relation name to the sanitized version of the fully qualified name
        relation_name = to_fqn_relation_name(self.fqname())

        model: dsl.Graph = self._model
        edb = getattr(std.rel, relation_name)
        edb._rel.parents.append(m.Builtins.EDB)
        id_rel = getattr(std.rel, f"{relation_name}_pyrel_id")

        with model.rule(globalize=True, source=self._source):
            id, val = dsl.create_vars(2)
            edb(dsl.Symbol("METADATA$ROW_ID"), id, val)
            std.rel.SHA1(id)
            id_rel.add(id)

        with model.rule(dynamic=True, globalize=True, source=self._source):
            prop, id, val = dsl.create_vars(3)
            id_rel(id)
            std.rel.SHA1(id)
            self.add(snowflake_id=id)

        for prop, prop_type in self._schema["columns"].items():
            _prop = prop
            if _prop.startswith("_"):
                _prop = "col" + prop

            prop_ident = sanitize_identifier(_prop.lower())

            with model.rule(dynamic=True, globalize=True, source=self._source):
                id, val = dsl.create_vars(2)
                edb(dsl.Symbol(prop), id, val)
                std.rel.SHA1(id)
                _prop = getattr(self, prop_ident)
                if not _prop:
                    raise ValueError(f"Property {_prop} couldn't be accessed on {self.fqname()}")
                if _prop.is_multi_valued:
                    inst = self(snowflake_id=id)
                    getattr(inst, prop_ident).add(val)
                else:
                    self(snowflake_id=id).set(**{prop_ident: val})

        # Because we're bypassing a bunch of the normal Type.add machinery here,
        # we need to manually account for the case where people are using value types.
        def wrapped(x):
            if not model._config.get("compiler.use_value_types", False):
                return x
            other_id = dsl.create_var()
            model._action(dsl.build.construct(self._type, [x, other_id]))
            return other_id

        # new UInt128 schema mapping rules
        with model.rule(dynamic=True, globalize=True, source=self._source):
            id = dsl.create_var()
            # This will generate an arity mismatch warning when used with the old SHA-1 Data Streams.
            # Ideally we would put the `@no_diagnostics(:ARITY_MISMATCH)` attribute on the relation using
            # the METADATA$KEY column, but that ended up being a more involved change than expected
            # just to avoid a non-blocking warning.
            edb(dsl.Symbol("METADATA$KEY"), id)
            std.rel.UInt128(id)
            self.add(wrapped(id), snowflake_id=id)

        for prop, prop_type in self._schema["columns"].items():
            _prop = prop
            if _prop.startswith("_"):
                _prop = "col" + prop

            prop_ident = sanitize_identifier(_prop.lower())
            with model.rule(dynamic=True, globalize=True, source=self._source):
                id, val = dsl.create_vars(2)
                edb(dsl.Symbol(prop), id, val)
                std.rel.UInt128(id)
                _prop = getattr(self, prop_ident)
                if not _prop:
                    raise ValueError(f"Property {_prop} couldn't be accessed on {self.fqname()}")
                if _prop.is_multi_valued:
                    inst = self(id)
                    getattr(inst, prop_ident).add(val)
                else:
                    model._check_property(_prop._prop)
                    raw_relation = getattr(std.rel, prop_ident)
                    dsl.tag(raw_relation, dsl.Builtins.FunctionAnnotation)
                    raw_relation.add(wrapped(id), val)

    def namespace(self):
        return f"{self._parent._parent._name}.{self._parent._name}"

    def fqname(self):
        return f"{self.namespace()}.{self._name}"
    def describe(self, **kwargs):
        model = self._model
        for k, v in kwargs.items():
            if v is PrimaryKey:
                self._schema["pks"] = [k]
            elif isinstance(v, tuple):
                (table, name) = v
                if isinstance(table, SnowflakeTable):
                    fk_table = table
                    pk = fk_table._schema["pks"]
                    with model.rule():
                        inst = fk_table()
                        me = self()
                        getattr(inst, pk[0]) == getattr(me, k)
                        if getattr(self, name).is_multi_valued:
                            getattr(me, name).add(inst)
                        else:
                            me.set(**{name: inst})
                else:
                    raise ValueError(f"Invalid foreign key {v}")
            else:
                raise ValueError(f"Invalid column {k}={v}")
        return self
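A minimal usage sketch of the wrapper classes above. The database, schema, and table names are hypothetical, and `model` is assumed to be an existing model configured without the graph index (otherwise `Snowflake(model)` raises `SnowflakeProxySourceError`):

    # Hypothetical names; attribute access lazily creates SnowflakeDB/SnowflakeSchema/SnowflakeTable wrappers.
    sf = Snowflake(model)
    Person = sf.sales_db.public.person
    Company = sf.sales_db.public.company

    # Mark `id` as the primary key and declare a foreign key that populates `person.employer`.
    Person.describe(id=PrimaryKey, employer_id=(Company, "employer"))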
class Provider(ProviderBase):
    def __init__(
        self,
        profile: str | None = None,
        config: Config | None = None,
        resources: Resources | None = None,
        generation: Generation | None = None,
    ):
        if resources:
            self.resources = resources
        else:
            resource_class = Resources
            if config and config.get("use_direct_access", USE_DIRECT_ACCESS):
                resource_class = DirectAccessResources
            self.resources = resource_class(profile=profile, config=config, generation=generation)

    def list_streams(self, model: str):
        return self.resources.list_imports(model=model)

    def create_streams(self, sources: List[str], model: str, force=False):
        if not self.resources.get_graph(model):
            self.resources.create_graph(model)
        def parse_source(raw: str):
            parser = IdentityParser(raw)
            assert parser.is_complete, "Snowflake table imports must be in `database.schema.table` format"
            return ImportSourceTable(*parser.to_list())
        for source in sources:
            source_table = parse_source(source)
            try:
                with Spinner(f"Creating stream for {source_table.name}", f"Stream for {source_table.name} created successfully"):
                    if force:
                        self.resources.delete_import(source_table.name, model, True)
                    self.resources.create_import_stream(source_table, model)
            except Exception as e:
                if "stream already exists" in f"{e}":
                    raise Exception(f"\n\nStream '{source_table.name.upper()}' already exists.")
                elif "engine not found" in f"{e}":
                    raise Exception("\n\nNo engines found in a READY state. Please use `engines:create` to create an engine that will be used to initialize the target relation.")
                else:
                    raise e
        with Spinner("Waiting for imports to complete", "Imports complete"):
            self.resources.poll_imports(sources, model)

    def delete_stream(self, stream_id: str, model: str):
        return self.resources.delete_import(stream_id, model)
    def sql(self, query: str, params: List[Any] = [], format: Literal["list", "pandas", "polars", "lazy"] = "list"):
        # note: default format cannot be pandas because .to_pandas() only works on SELECT queries
        result = self.resources._exec(query, params, raw=True, help=False)
        if format == "lazy":
            return cast(snowflake.snowpark.DataFrame, result)
        elif format == "list":
            return cast(list, result.collect())
        elif format == "pandas":
            import pandas as pd
            try:
                # use to_pandas for SELECT queries
                return cast(pd.DataFrame, result.to_pandas())
            except Exception:
                # handle non-SELECT queries like SHOW
                return pd.DataFrame(result.collect())
        elif format == "polars":
            import polars as pl # type: ignore
            return pl.DataFrame(
                [row.as_dict() for row in result.collect()],
                orient="row",
                strict=False,
                infer_schema_length=None
            )
        else:
            raise ValueError(f"Invalid format {format}. Should be one of 'list', 'pandas', 'polars', 'lazy'")

    def activate(self):
        with Spinner("Activating RelationalAI app...", "RelationalAI app activated"):
            self.sql("CALL RELATIONALAI.APP.ACTIVATE();")

    def deactivate(self):
        with Spinner("Deactivating RelationalAI app...", "RelationalAI app deactivated"):
            self.sql("CALL RELATIONALAI.APP.DEACTIVATE();")

    def drop_service(self):
        warnings.warn(
            "The drop_service method has been deprecated in favor of deactivate",
            DeprecationWarning,
            stacklevel=2,
        )
        self.deactivate()

    def resume_service(self):
        warnings.warn(
            "The resume_service method has been deprecated in favor of activate",
            DeprecationWarning,
            stacklevel=2,
        )
        self.activate()

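A usage sketch for the Provider above. The profile name, model name, and table identifiers are illustrative; Provider picks DirectAccessResources automatically when the config enables `use_direct_access`:

    provider = Provider(profile="default")

    # Stream two fully qualified tables into a model, then wait for the imports to finish.
    provider.create_streams(["sales_db.public.person", "sales_db.public.company"], model="my_model")
    print(provider.list_streams(model="my_model"))

    # sql() supports four result shapes; the default is a list because .to_pandas()
    # only works for SELECT queries.
    rows = provider.sql("SELECT CURRENT_REGION()")                         # list of rows
    df = provider.sql("SHOW TABLES IN SCHEMA sales_db.public", format="pandas")
    lazy = provider.sql("SELECT 1", format="lazy")                          # Snowpark DataFrame, not collected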
#--------------------------------------------------
# SnowflakeClient
#--------------------------------------------------
class SnowflakeClient(Client):
    def create_database(self, isolated=True, nowait_durable=True, headers: Dict | None = None):
        from v0.relationalai.tools.cli_helpers import validate_engine_name

        assert isinstance(self.resources, Resources)

        if self.last_database_version == len(self.resources.sources):
            return

        model = self._source_database
        app_name = self.resources.get_app_name()
        engine_name = self.resources.get_default_engine_name()
        engine_size = self.resources.config.get_default_engine_size()

        # Validate engine name
        is_name_valid, _ = validate_engine_name(engine_name)
        if not is_name_valid:
            raise EngineNameValidationException(engine_name)

        # Validate engine size
        valid_sizes = self.resources.get_engine_sizes()
        if not isinstance(engine_size, str) or engine_size not in valid_sizes:
            raise InvalidEngineSizeError(str(engine_size), valid_sizes)

        program_span_id = debugging.get_program_span_id()

        query_attrs_dict = json.loads(headers.get("X-Query-Attributes", "{}")) if headers else {}
        with debugging.span("poll_use_index", sources=self.resources.sources, model=model, engine=engine_name, **query_attrs_dict):
            self.maybe_poll_use_index(
                app_name=app_name,
                sources=self.resources.sources,
                model=model,
                engine_name=engine_name,
                engine_size=engine_size,
                program_span_id=program_span_id,
                headers=headers
            )

        self.last_database_version = len(self.resources.sources)
        self._manage_packages()

        if isolated and not self.keep_model:
            atexit.register(self.delete_database)

    def maybe_poll_use_index(
        self,
        app_name: str,
        sources: Iterable[str],
        model: str,
        engine_name: str,
        engine_size: str | None = None,
        program_span_id: str | None = None,
        headers: Dict | None = None,
    ):
        """Only call _poll_use_index if there are sources to process."""
        assert isinstance(self.resources, Resources)
        return self.resources.maybe_poll_use_index(
            app_name=app_name,
            sources=sources,
            model=model,
            engine_name=engine_name,
            engine_size=engine_size,
            program_span_id=program_span_id,
            headers=headers
        )


#--------------------------------------------------
# Graph
#--------------------------------------------------

def Graph(
    name,
    *,
    profile: str | None = None,
    config: Config,
    dry_run: bool = False,
    isolated: bool = True,
    connection: Session | None = None,
    keep_model: bool = False,
    nowait_durable: bool = True,
    format: str = "default",
):

    client_class = Client
    resource_class = Resources
    use_graph_index = config.get("use_graph_index", USE_GRAPH_INDEX)
    use_monotype_operators = config.get("compiler.use_monotype_operators", False)
    use_direct_access = config.get("use_direct_access", USE_DIRECT_ACCESS)

    if use_graph_index:
        client_class = SnowflakeClient
    if use_direct_access:
        resource_class = DirectAccessResources
    client = client_class(
        resource_class(generation=Generation.V0, profile=profile, config=config, connection=connection),
        rel.Compiler(config),
        name,
        config,
        dry_run=dry_run,
        isolated=isolated,
        keep_model=keep_model,
        nowait_durable=nowait_durable
    )
    base_rel = """
    @inline
    def make_identity(x..., z):
        rel_primitive_hash_tuple_uint128(x..., z)

    @inline
    def pyrel_default({F}, c, k..., v):
        F(k..., v) or (not F(k..., _) and v = c)

    @inline
    def pyrel_unwrap(x in UInt128, y): y = x

    @inline
    def pyrel_dates_period_days(x in Date, y in Date, z in Int):
        exists((u) | dates_period_days(x, y, u) and u = ::std::common::^Day[z])

    @inline
    def pyrel_datetimes_period_milliseconds(x in DateTime, y in DateTime, z in Int):
        exists((u) | datetimes_period_milliseconds(x, y, u) and u = ^Millisecond[z])

    @inline
    def pyrel_bool_filter(a, b, {F}, z): { z = if_then_else[F(a, b), boolean_true, boolean_false] }

    @inline
    def pyrel_strftime(v, fmt, tz in String, s in String):
        (Date(v) and s = format_date[v, fmt])
        or (DateTime(v) and s = format_datetime[v, fmt, tz])

    @inline
    def pyrel_regex_match_all(pattern, string in String, pos in Int, offset in Int, match in String):
        regex_match_all(pattern, string, offset, match) and offset >= pos

    @inline
    def pyrel_regex_match(pattern, string in String, pos in Int, offset in Int, match in String):
        pyrel_regex_match_all(pattern, string, pos, offset, match) and offset = pos

    @inline
    def pyrel_regex_search(pattern, string in String, pos in Int, offset in Int, match in String):
        enumerate(pyrel_regex_match_all[pattern, string, pos], 1, offset, match)

    @inline
    def pyrel_regex_sub(pattern, repl in String, string in String, result in String):
        string_replace_multiple(string, {(last[regex_match_all[pattern, string]], repl)}, result)

    @inline
    def pyrel_capture_group(regex in Pattern, string in String, pos in Int, index, match in String):
        (Integer(index) and capture_group_by_index(regex, string, pos, index, match)) or
        (String(index) and capture_group_by_name(regex, string, pos, index, match))

    declare __resource
    declare __compiled_patterns
    """
    if use_monotype_operators:
        base_rel += """

        // use monotyped operators
        from ::std::monotype import +, -, *, /, <, <=, >, >=
        """
    pyrel_base = dsl.build.raw_task(base_rel)
    debugging.set_source(pyrel_base)
    client.install("pyrel_base", pyrel_base)
    return dsl.Graph(client, name, format=format)

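A sketch of how the Graph() factory above is typically called. It assumes `config` is an already-loaded Config for a Snowflake profile; the model name is illustrative:

    # Assumption: `config` was loaded elsewhere. Graph() wires up the client,
    # installs the pyrel_base Rel helpers, and returns a dsl.Graph.
    model = Graph("my_model", config=config, isolated=True, keep_model=False)

    with model.rule():
        ...  # define rules against the model as usual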
#--------------------------------------------------
# Direct Access
#--------------------------------------------------
# Note: All direct access components should live in a separate file

class DirectAccessResources(Resources):
    """
    Resources class for Direct Service Access avoiding Snowflake service functions.
    """
    def __init__(
        self,
        profile: Union[str, None] = None,
        config: Union[Config, None] = None,
        connection: Union[Session, None] = None,
        dry_run: bool = False,
        reset_session: bool = False,
        generation: Optional[Generation] = None,
        language: str = "rel",
    ):
        super().__init__(
            generation=generation,
            profile=profile,
            config=config,
            connection=connection,
            reset_session=reset_session,
            dry_run=dry_run,
            language=language,
        )
        self._endpoint_info = ConfigStore(ENDPOINT_FILE)
        self._service_endpoint = ""
        self._direct_access_client = None
        self.generation = generation
        self.database = ""

    @property
    def service_endpoint(self) -> str:
        return self._retrieve_service_endpoint()

    def _retrieve_service_endpoint(self, enforce_update=False) -> str:
        account = self.config.get("account")
        app_name = self.config.get("rai_app_name")
        service_endpoint_key = f"{account}.{app_name}.service_endpoint"
        if self._service_endpoint and not enforce_update:
            return self._service_endpoint
        if self._endpoint_info.get(service_endpoint_key, "") and not enforce_update:
            self._service_endpoint = str(self._endpoint_info.get(service_endpoint_key, ""))
            return self._service_endpoint

        is_snowflake_notebook = isinstance(runtime_env, SnowbookEnvironment)
        query = f"CALL {self.get_app_name()}.app.service_endpoint({not is_snowflake_notebook});"
        result = self._exec(query)
        assert result, f"Could not retrieve service endpoint for {self.get_app_name()}"
        if is_snowflake_notebook:
            self._service_endpoint = f"http://{result[0]['SERVICE_ENDPOINT']}"
        else:
            self._service_endpoint = f"https://{result[0]['SERVICE_ENDPOINT']}"

        self._endpoint_info.set(service_endpoint_key, self._service_endpoint)
        # save the endpoint to `ENDPOINT_FILE` to avoid calling the endpoint with every
        # pyrel execution
        try:
            self._endpoint_info.save()
        except Exception:
            print("Failed to persist endpoints to file. This might slow down future executions.")

        return self._service_endpoint

    @property
    def direct_access_client(self) -> DirectAccessClient:
        if self._direct_access_client:
            return self._direct_access_client
        try:
            service_endpoint = self.service_endpoint
            self._direct_access_client = DirectAccessClient(
                self.config, self.token_handler, service_endpoint, self.generation,
            )
        except Exception as e:
            raise e
        return self._direct_access_client
    def request(
        self,
        endpoint: str,
        payload: Dict[str, Any] | None = None,
        headers: Dict[str, str] | None = None,
        path_params: Dict[str, str] | None = None,
        query_params: Dict[str, str] | None = None,
        skip_auto_create: bool = False,
        skip_engine_db_error_retry: bool = False,
    ) -> requests.Response:
        with debugging.span("direct_access_request"):
            def _send_request():
                return self.direct_access_client.request(
                    endpoint=endpoint,
                    payload=payload,
                    headers=headers,
                    path_params=path_params,
                    query_params=query_params,
                )
            try:
                response = _send_request()
                if response.status_code != 200:
                    # For 404 responses with skip_auto_create=True, return immediately to let the caller handle it
                    # (e.g., get_engine needs to check 404 and return None for auto_create_engine).
                    # For skip_auto_create=False, continue to the auto-creation logic below.
                    if response.status_code == 404 and skip_auto_create:
                        return response

                    try:
                        message = response.json().get("message", "")
                    except requests.exceptions.JSONDecodeError:
                        # Can't parse the JSON response. For skip_auto_create=True (e.g., get_engine),
                        # this should have been caught by the 404 check above, so this is an error.
                        # For skip_auto_create=False, we explicitly check status_code below,
                        # so we don't need to parse the message.
                        if skip_auto_create:
                            raise ResponseStatusException(
                                f"Failed to parse error response from endpoint {endpoint}.", response
                            )
                        message = ""  # Not used when we check status_code directly

                    # Fix the engine on an engine error and retry.
                    # Skip setting up the graph index if skip_auto_create is True (to avoid recursion) or if
                    # skip_engine_db_error_retry is True, to let _exec_async_v2 perform the retry with the correct headers.
                    if ((_is_engine_issue(message) and not skip_auto_create) or _is_database_issue(message)) and not skip_engine_db_error_retry:
                        engine_name = payload.get("caller_engine_name", "") if payload else ""
                        engine_name = engine_name or self.get_default_engine_name()
                        engine_size = self.config.get_default_engine_size()
                        self._poll_use_index(
                            app_name=self.get_app_name(),
                            sources=self.sources,
                            model=self.database,
                            engine_name=engine_name,
                            engine_size=engine_size,
                            headers=headers,
                        )
                        response = _send_request()
            except requests.exceptions.ConnectionError as e:
                if "NameResolutionError" in str(e):
                    # When we cannot resolve the service endpoint, we assume it is outdated;
                    # retrieve it again and retry the request.
                    self.direct_access_client.service_endpoint = self._retrieve_service_endpoint(
                        enforce_update=True,
                    )
                    return _send_request()
                # raise in all other cases
                raise e
            return response
    def _txn_request_with_gi_retry(
        self,
        payload: Dict,
        headers: Dict[str, str],
        query_params: Dict,
        engine: Union[str, None],
    ):
        """Make request with graph index retry logic.

        Attempts request with gi_setup_skipped=True first. If an engine or database
        issue occurs, polls use_index and retries with gi_setup_skipped=False.
        """
        response = self.request(
            "create_txn", payload=payload, headers=headers, query_params=query_params, skip_auto_create=True, skip_engine_db_error_retry=True
        )

        if response.status_code != 200:
            try:
                message = response.json().get("message", "")
            except requests.exceptions.JSONDecodeError:
                message = ""

            if _is_engine_issue(message) or _is_database_issue(message):
                engine_name = engine or self.get_default_engine_name()
                engine_size = self.config.get_default_engine_size()
                self._poll_use_index(
                    app_name=self.get_app_name(),
                    sources=self.sources,
                    model=self.database,
                    engine_name=engine_name,
                    engine_size=engine_size,
                    headers=headers,
                )
                headers['gi_setup_skipped'] = 'False'
                response = self.request(
                    "create_txn", payload=payload, headers=headers, query_params=query_params, skip_auto_create=True, skip_engine_db_error_retry=True
                )
            else:
                raise ResponseStatusException("Failed to create transaction.", response)

        return response
    def _exec_async_v2(
        self,
        database: str,
        engine: Union[str, None],
        raw_code: str,
        inputs: Dict | None = None,
        readonly=True,
        nowait_durable=False,
        headers: Dict[str, str] | None = None,
        bypass_index=False,
        language: str = "rel",
        query_timeout_mins: int | None = None,
        gi_setup_skipped: bool = False,
    ):

        with debugging.span("transaction") as txn_span:
            with debugging.span("create_v2") as create_span:

                use_graph_index = self.config.get("use_graph_index", USE_GRAPH_INDEX)

                payload = {
                    "dbname": database,
                    "engine_name": engine,
                    "query": raw_code,
                    "v1_inputs": inputs,
                    "nowait_durable": nowait_durable,
                    "readonly": readonly,
                    "language": language,
                }
                if query_timeout_mins is None and (timeout_value := self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS)) is not None:
                    query_timeout_mins = int(timeout_value)
                if query_timeout_mins is not None:
                    payload["timeout_mins"] = query_timeout_mins
                query_params = {"use_graph_index": str(use_graph_index and not bypass_index)}

                # Add gi_setup_skipped to headers
                if headers is None:
                    headers = {}
                headers["gi_setup_skipped"] = str(gi_setup_skipped)
                headers['pyrel_program_id'] = debugging.get_program_span_id() or ""

                response = self._txn_request_with_gi_retry(
                    payload, headers, query_params, engine
                )

                artifact_info = {}
                response_content = response.json()

                txn_id = response_content["transaction"]['id']
                state = response_content["transaction"]['state']

                txn_span["txn_id"] = txn_id
                create_span["txn_id"] = txn_id
                debugging.event("transaction_created", txn_span, txn_id=txn_id)

                # fast path: transaction already finished
                if state in ["COMPLETED", "ABORTED"]:
                    if txn_id in self._pending_transactions:
                        self._pending_transactions.remove(txn_id)

                    # Process rows to get the rest of the artifacts
                    for result in response_content.get("results", []):
                        filename = result['filename']
                        # making keys uppercase to match the old behavior
                        artifact_info[filename] = {k.upper(): v for k, v in result.items()}

                # slow path: transaction not done yet; start polling
                else:
                    self._pending_transactions.append(txn_id)
                    with debugging.span("wait", txn_id=txn_id):
                        poll_with_specified_overhead(
                            lambda: self._check_exec_async_status(txn_id, headers=headers), 0.1
                        )
                    artifact_info = self._list_exec_async_artifacts(txn_id, headers=headers)

            with debugging.span("fetch"):
                return self._download_results(artifact_info, txn_id, state)
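The slow path above hands polling off to `poll_with_specified_overhead` with an overhead factor of 0.1. Conceptually this is proportional backoff; the stand-in below only illustrates the idea and is not the library's actual implementation:

    import time

    def poll_proportional(check, overhead=0.1, max_delay=30.0):
        # Sleep roughly `overhead` times the elapsed time between checks, so short
        # transactions return quickly while long ones are polled less and less often.
        start = time.monotonic()
        while not check():
            elapsed = time.monotonic() - start
            time.sleep(min(max(elapsed * overhead, 0.1), max_delay))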
    def _prepare_index(
        self,
        model: str,
        engine_name: str,
        engine_size: str = "",
        language: str = "rel",
        rai_relations: List[str] | None = None,
        pyrel_program_id: str | None = None,
        skip_pull_relations: bool = False,
        headers: Dict | None = None,
    ):
        """
        Prepare the index for the given engine and model.
        """
        with debugging.span("prepare_index"):
            if headers is None:
                headers = {}

            payload = {
                "model_name": model,
                "caller_engine_name": engine_name,
                "language": language,
                "pyrel_program_id": pyrel_program_id,
                "skip_pull_relations": skip_pull_relations,
                "rai_relations": rai_relations or [],
                "user_agent": get_pyrel_version(self.generation),
            }
            # Only include engine_size if it has a non-empty string value
            if engine_size and engine_size.strip():
                payload["caller_engine_size"] = engine_size

            response = self.request(
                "prepare_index", payload=payload, headers=headers
            )

            if response.status_code != 200:
                raise ResponseStatusException("Failed to prepare index.", response)

            return response.json()
    def _poll_use_index(
        self,
        app_name: str,
        sources: Iterable[str],
        model: str,
        engine_name: str,
        engine_size: str | None = None,
        program_span_id: str | None = None,
        headers: Dict | None = None,
    ):
        return DirectUseIndexPoller(
            self,
            app_name=app_name,
            sources=sources,
            model=model,
            engine_name=engine_name,
            engine_size=engine_size,
            language=self.language,
            program_span_id=program_span_id,
            headers=headers,
            generation=self.generation,
        ).poll()

    def maybe_poll_use_index(
        self,
        app_name: str,
        sources: Iterable[str],
        model: str,
        engine_name: str,
        engine_size: str | None = None,
        program_span_id: str | None = None,
        headers: Dict | None = None,
    ):
        """Only call poll() if there are sources to process and the cache is not valid."""
        sources_list = list(sources)
        self.database = model
        if sources_list:
            poller = DirectUseIndexPoller(
                self,
                app_name=app_name,
                sources=sources_list,
                model=model,
                engine_name=engine_name,
                engine_size=engine_size,
                language=self.language,
                program_span_id=program_span_id,
                headers=headers,
                generation=self.generation,
            )
            # If the cache is valid (data freshness has not expired), skip polling
            if poller.cache.is_valid():
                cached_sources = len(poller.cache.sources)
                total_sources = len(sources_list)
                cached_timestamp = poller.cache._metadata.get("cachedIndices", {}).get(poller.cache.key, {}).get("last_use_index_update_on", "")

                message = f"Using cached data for {cached_sources}/{total_sources} data streams"
                if cached_timestamp:
                    print(f"\n{message} (cached at {cached_timestamp})\n")
                else:
                    print(f"\n{message}\n")
            else:
                return poller.poll()
    def _check_exec_async_status(self, txn_id: str, headers: Dict[str, str] | None = None) -> bool:
        """Check whether the given transaction has completed."""

        with debugging.span("check_status"):
            response = self.request(
                "get_txn",
                headers=headers,
                path_params={"txn_id": txn_id},
            )
            assert response, f"No results from get_transaction('{txn_id}')"

            response_content = response.json()
            transaction = response_content["transaction"]
            status: str = transaction['state']

            # remove the transaction from the pending list if it's completed or aborted
            if status in ["COMPLETED", "ABORTED"]:
                if txn_id in self._pending_transactions:
                    self._pending_transactions.remove(txn_id)

            if status == "ABORTED" and transaction.get("abort_reason", "") == TXN_ABORT_REASON_TIMEOUT:
                config_file_path = getattr(self.config, 'file_path', None)
                timeout_ms = int(transaction.get("timeout_ms", 0))
                timeout_mins = timeout_ms // 60000 if timeout_ms > 0 else int(self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS) or DEFAULT_QUERY_TIMEOUT_MINS)
                raise QueryTimeoutExceededException(
                    timeout_mins=timeout_mins,
                    query_id=txn_id,
                    config_file_path=config_file_path,
                )

            # @TODO: Find some way to tunnel the ABORT_REASON out. Azure doesn't have this, but it's handy
            return status == "COMPLETED" or status == "ABORTED"

    def _list_exec_async_artifacts(self, txn_id: str, headers: Dict[str, str] | None = None) -> Dict[str, Dict]:
        """Grab the list of artifacts produced in the transaction and the URLs to retrieve their contents."""
        with debugging.span("list_results"):
            response = self.request(
                "get_txn_artifacts",
                headers=headers,
                path_params={"txn_id": txn_id},
            )
            assert response, f"No results from get_transaction_artifacts('{txn_id}')"
            artifact_info = {}
            for result in response.json()["results"]:
                filename = result['filename']
                # making keys uppercase to match the old behavior
                artifact_info[filename] = {k.upper(): v for k, v in result.items()}
            return artifact_info

    def get_transaction_problems(self, txn_id: str) -> List[Dict[str, Any]]:
        with debugging.span("get_transaction_problems"):
            response = self.request(
                "get_txn_problems",
                path_params={"txn_id": txn_id},
            )
            response_content = response.json()
            if not response_content:
                return []
            return response_content.get("problems", [])

    def get_transaction_events(self, transaction_id: str, continuation_token: str = ''):
        response = self.request(
            "get_txn_events",
            path_params={"txn_id": transaction_id, "stream_name": "profiler"},
            query_params={"continuation_token": continuation_token},
        )
        response_content = response.json()
        if not response_content:
            return {
                "events": [],
                "continuation_token": None
            }
        return response_content
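A sketch of how the transaction helpers above fit together. It assumes `resources` is a DirectAccessResources instance and uses a placeholder transaction id (in practice the id comes from the create_txn response):

    txn_id = "01234567-89ab-cdef-0123-456789abcdef"  # placeholder

    if resources._check_exec_async_status(txn_id):
        artifacts = resources._list_exec_async_artifacts(txn_id)  # filename -> artifact metadata
        problems = resources.get_transaction_problems(txn_id)
        events = resources.get_transaction_events(txn_id)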
    #--------------------------------------------------
    # Databases
    #--------------------------------------------------

    def get_installed_packages(self, database: str) -> Union[Dict, None]:
        use_graph_index = self.config.get("use_graph_index", USE_GRAPH_INDEX)
        if use_graph_index:
            response = self.request(
                "get_model_package_versions",
                payload={"model_name": database},
            )
        else:
            response = self.request(
                "get_package_versions",
                path_params={"db_name": database},
            )
        if response.status_code == 404 and response.json().get("message", "") == "database not found":
            return None
        if response.status_code != 200:
            raise ResponseStatusException(
                f"Failed to retrieve package versions for {database}.", response
            )

        content = response.json()
        if not content:
            return None

        return safe_json_loads(content["package_versions"])

    def get_database(self, database: str):
        with debugging.span("get_database", dbname=database):
            if not database:
                raise ValueError("Database name must be provided to get database.")
            response = self.request(
                "get_db",
                path_params={},
                query_params={"name": database},
            )
            if response.status_code != 200:
                raise ResponseStatusException(f"Failed to get db. db:{database}", response)

            response_content = response.json()

            if (response_content.get("databases") and len(response_content["databases"]) == 1):
                db = response_content["databases"][0]
                return {
                    "id": db["id"],
                    "name": db["name"],
                    "created_by": db.get("created_by"),
                    "created_on": ms_to_timestamp(db.get("created_on")),
                    "deleted_by": db.get("deleted_by"),
                    "deleted_on": ms_to_timestamp(db.get("deleted_on")),
                    "state": db["state"],
                }
            else:
                return None
    def create_graph(self, name: str):
        with debugging.span("create_model", dbname=name):
            return self._create_database(name, "")

    def delete_graph(self, name: str, force=False, language: str = "rel"):
        prop_hdrs = debugging.gen_current_propagation_headers()
        if self.config.get("use_graph_index", USE_GRAPH_INDEX):
            keep_database = not force and self.config.get("reuse_model", True)
            with debugging.span("release_index", name=name, keep_database=keep_database, language=language):
                response = self.request(
                    "release_index",
                    payload={
                        "model_name": name,
                        "keep_database": keep_database,
                        "language": language,
                        "user_agent": get_pyrel_version(self.generation),
                    },
                    headers=prop_hdrs,
                )
                if (
                    response.status_code != 200
                    and not (
                        response.status_code == 404
                        and "database not found" in response.json().get("message", "")
                    )
                ):
                    raise ResponseStatusException(f"Failed to release index. Model: {name} ", response)
        else:
            with debugging.span("delete_model", name=name):
                self._delete_database(name, headers=prop_hdrs)

    def clone_graph(self, target_name: str, source_name: str, nowait_durable=True, force=False):
        if force and self.get_graph(target_name):
            self.delete_graph(target_name)
        with debugging.span("clone_model", target_name=target_name, source_name=source_name):
            return self._create_database(target_name, source_name)

    def _delete_database(self, name: str, headers: Dict = {}):
        with debugging.span("_delete_database", dbname=name):
            response = self.request(
                "delete_db",
                path_params={"db_name": name},
                query_params={},
                headers=headers,
            )
            if response.status_code != 200:
                raise ResponseStatusException(f"Failed to delete db. db:{name} ", response)

    def _create_database(self, name: str, source_name: str):
        with debugging.span("_create_database", dbname=name):
            payload = {
                "name": name,
                "source_name": source_name,
            }
            response = self.request(
                "create_db", payload=payload, headers={}, query_params={},
            )
            if response.status_code != 200:
                raise ResponseStatusException(f"Failed to create db. db:{name}", response)

    #--------------------------------------------------
    # Engines
    #--------------------------------------------------

    def list_engines(self, state: str | None = None):
        response = self.request("list_engines")
        if response.status_code != 200:
            raise ResponseStatusException(
                "Failed to retrieve engines.", response
            )
        response_content = response.json()
        if not response_content:
            return []
        engines = [
            {
                "name": engine["name"],
                "id": engine["id"],
                "size": engine["size"],
                "state": engine["status"],  # callers are expecting 'state'
                "created_by": engine["created_by"],
                "created_on": engine["created_on"],
                "updated_on": engine["updated_on"],
            }
            for engine in response_content.get("engines", [])
            if state is None or engine.get("status") == state
        ]
        return sorted(engines, key=lambda x: x["name"])

    def get_engine(self, name: str):
        response = self.request("get_engine", path_params={"engine_name": name, "engine_type": "logic"}, skip_auto_create=True)
        if response.status_code == 404:  # a missing engine returns 404
            return None
        elif response.status_code != 200:
            raise ResponseStatusException(
                f"Failed to retrieve engine {name}.", response
            )
        engine = response.json()
        if not engine:
            return None
        engine_state: EngineState = {
            "name": engine["name"],
            "id": engine["id"],
            "size": engine["size"],
            "state": engine["status"],  # callers are expecting 'state'
            "created_by": engine["created_by"],
            "created_on": engine["created_on"],
            "updated_on": engine["updated_on"],
            "version": engine["version"],
            "auto_suspend": engine["auto_suspend_mins"],
            "suspends_at": engine["suspends_at"],
        }
        return engine_state
    def _create_engine(
        self,
        name: str,
        size: str | None = None,
        auto_suspend_mins: int | None = None,
        is_async: bool = False,
        headers: Dict[str, str] | None = None
    ):
        # only async engine creation is supported via direct access
        if not is_async:
            return super()._create_engine(name, size, auto_suspend_mins, is_async, headers=headers)
        payload: Dict[str, Any] = {
            "name": name,
        }
        if auto_suspend_mins is not None:
            payload["auto_suspend_mins"] = auto_suspend_mins
        if size is not None:
            payload["size"] = size
        response = self.request(
            "create_engine",
            payload=payload,
            path_params={"engine_type": "logic"},
            headers=headers,
            skip_auto_create=True,
        )
        if response.status_code != 200:
            raise ResponseStatusException(
                f"Failed to create engine {name} with size {size}.", response
            )

    def delete_engine(self, name: str, force: bool = False, headers={}):
        response = self.request(
            "delete_engine",
            path_params={"engine_name": name, "engine_type": "logic"},
            headers=headers,
            skip_auto_create=True,
        )
        if response.status_code != 200:
            raise ResponseStatusException(
                f"Failed to delete engine {name}.", response
            )

    def suspend_engine(self, name: str):
        response = self.request(
            "suspend_engine",
            path_params={"engine_name": name, "engine_type": "logic"},
            skip_auto_create=True,
        )
        if response.status_code != 200:
            raise ResponseStatusException(
                f"Failed to suspend engine {name}.", response
            )

    def resume_engine_async(self, name: str, headers={}):
        response = self.request(
            "resume_engine",
            path_params={"engine_name": name, "engine_type": "logic"},
            headers=headers,
            skip_auto_create=True,
        )
        if response.status_code != 200:
            raise ResponseStatusException(
                f"Failed to resume engine {name}.", response
            )
        return {}