relationalai 0.13.0.dev0-py3-none-any.whl → 0.13.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (838)
  1. frontend/debugger/dist/.gitignore +2 -0
  2. frontend/debugger/dist/assets/favicon-Dy0ZgA6N.png +0 -0
  3. frontend/debugger/dist/assets/index-Cssla-O7.js +208 -0
  4. frontend/debugger/dist/assets/index-DlHsYx1V.css +9 -0
  5. frontend/debugger/dist/index.html +17 -0
  6. relationalai/__init__.py +256 -1
  7. relationalai/clients/__init__.py +18 -0
  8. relationalai/clients/client.py +947 -0
  9. relationalai/clients/config.py +673 -0
  10. relationalai/clients/direct_access_client.py +118 -0
  11. relationalai/clients/exec_txn_poller.py +91 -0
  12. relationalai/clients/hash_util.py +31 -0
  13. relationalai/clients/local.py +586 -0
  14. relationalai/clients/profile_polling.py +73 -0
  15. relationalai/clients/resources/__init__.py +8 -0
  16. relationalai/clients/resources/azure/azure.py +502 -0
  17. relationalai/clients/resources/snowflake/__init__.py +20 -0
  18. relationalai/clients/resources/snowflake/cli_resources.py +98 -0
  19. relationalai/clients/resources/snowflake/direct_access_resources.py +734 -0
  20. relationalai/clients/resources/snowflake/engine_service.py +381 -0
  21. relationalai/clients/resources/snowflake/engine_state_handlers.py +315 -0
  22. relationalai/clients/resources/snowflake/error_handlers.py +240 -0
  23. relationalai/clients/resources/snowflake/export_procedure.py.jinja +249 -0
  24. relationalai/clients/resources/snowflake/resources_factory.py +99 -0
  25. relationalai/clients/resources/snowflake/snowflake.py +3185 -0
  26. relationalai/clients/resources/snowflake/use_index_poller.py +1019 -0
  27. relationalai/clients/resources/snowflake/use_index_resources.py +188 -0
  28. relationalai/clients/resources/snowflake/util.py +387 -0
  29. relationalai/clients/result_helpers.py +420 -0
  30. relationalai/clients/types.py +118 -0
  31. relationalai/clients/util.py +356 -0
  32. relationalai/debugging.py +389 -0
  33. relationalai/dsl.py +1749 -0
  34. relationalai/early_access/builder/__init__.py +30 -0
  35. relationalai/early_access/builder/builder/__init__.py +35 -0
  36. relationalai/early_access/builder/snowflake/__init__.py +12 -0
  37. relationalai/early_access/builder/std/__init__.py +25 -0
  38. relationalai/early_access/builder/std/decimals/__init__.py +12 -0
  39. relationalai/early_access/builder/std/integers/__init__.py +12 -0
  40. relationalai/early_access/builder/std/math/__init__.py +12 -0
  41. relationalai/early_access/builder/std/strings/__init__.py +14 -0
  42. relationalai/early_access/devtools/__init__.py +12 -0
  43. relationalai/early_access/devtools/benchmark_lqp/__init__.py +12 -0
  44. relationalai/early_access/devtools/extract_lqp/__init__.py +12 -0
  45. relationalai/early_access/dsl/adapters/orm/adapter_qb.py +427 -0
  46. relationalai/early_access/dsl/adapters/orm/parser.py +636 -0
  47. relationalai/early_access/dsl/adapters/owl/adapter.py +176 -0
  48. relationalai/early_access/dsl/adapters/owl/parser.py +160 -0
  49. relationalai/early_access/dsl/bindings/common.py +402 -0
  50. relationalai/early_access/dsl/bindings/csv.py +170 -0
  51. relationalai/early_access/dsl/bindings/legacy/binding_models.py +143 -0
  52. relationalai/early_access/dsl/bindings/snowflake.py +64 -0
  53. relationalai/early_access/dsl/codegen/binder.py +411 -0
  54. relationalai/early_access/dsl/codegen/common.py +79 -0
  55. relationalai/early_access/dsl/codegen/helpers.py +23 -0
  56. relationalai/early_access/dsl/codegen/relations.py +700 -0
  57. relationalai/early_access/dsl/codegen/weaver.py +417 -0
  58. relationalai/early_access/dsl/core/builders/__init__.py +47 -0
  59. relationalai/early_access/dsl/core/builders/logic.py +19 -0
  60. relationalai/early_access/dsl/core/builders/scalar_constraint.py +11 -0
  61. relationalai/early_access/dsl/core/constraints/predicate/atomic.py +455 -0
  62. relationalai/early_access/dsl/core/constraints/predicate/universal.py +73 -0
  63. relationalai/early_access/dsl/core/constraints/scalar.py +310 -0
  64. relationalai/early_access/dsl/core/context.py +13 -0
  65. relationalai/early_access/dsl/core/cset.py +132 -0
  66. relationalai/early_access/dsl/core/exprs/__init__.py +116 -0
  67. relationalai/early_access/dsl/core/exprs/relational.py +18 -0
  68. relationalai/early_access/dsl/core/exprs/scalar.py +412 -0
  69. relationalai/early_access/dsl/core/instances.py +44 -0
  70. relationalai/early_access/dsl/core/logic/__init__.py +193 -0
  71. relationalai/early_access/dsl/core/logic/aggregation.py +98 -0
  72. relationalai/early_access/dsl/core/logic/exists.py +223 -0
  73. relationalai/early_access/dsl/core/logic/helper.py +163 -0
  74. relationalai/early_access/dsl/core/namespaces.py +32 -0
  75. relationalai/early_access/dsl/core/relations.py +276 -0
  76. relationalai/early_access/dsl/core/rules.py +112 -0
  77. relationalai/early_access/dsl/core/std/__init__.py +45 -0
  78. relationalai/early_access/dsl/core/temporal/recall.py +6 -0
  79. relationalai/early_access/dsl/core/types/__init__.py +270 -0
  80. relationalai/early_access/dsl/core/types/concepts.py +128 -0
  81. relationalai/early_access/dsl/core/types/constrained/__init__.py +267 -0
  82. relationalai/early_access/dsl/core/types/constrained/nominal.py +143 -0
  83. relationalai/early_access/dsl/core/types/constrained/subtype.py +124 -0
  84. relationalai/early_access/dsl/core/types/standard.py +92 -0
  85. relationalai/early_access/dsl/core/types/unconstrained.py +50 -0
  86. relationalai/early_access/dsl/core/types/variables.py +203 -0
  87. relationalai/early_access/dsl/ir/compiler.py +318 -0
  88. relationalai/early_access/dsl/ir/executor.py +260 -0
  89. relationalai/early_access/dsl/ontologies/constraints.py +88 -0
  90. relationalai/early_access/dsl/ontologies/export.py +30 -0
  91. relationalai/early_access/dsl/ontologies/models.py +453 -0
  92. relationalai/early_access/dsl/ontologies/python_printer.py +303 -0
  93. relationalai/early_access/dsl/ontologies/readings.py +60 -0
  94. relationalai/early_access/dsl/ontologies/relationships.py +322 -0
  95. relationalai/early_access/dsl/ontologies/roles.py +87 -0
  96. relationalai/early_access/dsl/ontologies/subtyping.py +55 -0
  97. relationalai/early_access/dsl/orm/constraints.py +438 -0
  98. relationalai/early_access/dsl/orm/measures/dimensions.py +200 -0
  99. relationalai/early_access/dsl/orm/measures/initializer.py +16 -0
  100. relationalai/early_access/dsl/orm/measures/measure_rules.py +275 -0
  101. relationalai/early_access/dsl/orm/measures/measures.py +299 -0
  102. relationalai/early_access/dsl/orm/measures/role_exprs.py +268 -0
  103. relationalai/early_access/dsl/orm/models.py +256 -0
  104. relationalai/early_access/dsl/orm/object_oriented_printer.py +344 -0
  105. relationalai/early_access/dsl/orm/printer.py +469 -0
  106. relationalai/early_access/dsl/orm/reasoners.py +480 -0
  107. relationalai/early_access/dsl/orm/relations.py +19 -0
  108. relationalai/early_access/dsl/orm/relationships.py +251 -0
  109. relationalai/early_access/dsl/orm/types.py +42 -0
  110. relationalai/early_access/dsl/orm/utils.py +79 -0
  111. relationalai/early_access/dsl/orm/verb.py +204 -0
  112. relationalai/early_access/dsl/physical_metadata/tables.py +133 -0
  113. relationalai/early_access/dsl/relations.py +170 -0
  114. relationalai/early_access/dsl/rulesets.py +69 -0
  115. relationalai/early_access/dsl/schemas/__init__.py +450 -0
  116. relationalai/early_access/dsl/schemas/builder.py +48 -0
  117. relationalai/early_access/dsl/schemas/comp_names.py +51 -0
  118. relationalai/early_access/dsl/schemas/components.py +203 -0
  119. relationalai/early_access/dsl/schemas/contexts.py +156 -0
  120. relationalai/early_access/dsl/schemas/exprs.py +89 -0
  121. relationalai/early_access/dsl/schemas/fragments.py +464 -0
  122. relationalai/early_access/dsl/serialization.py +79 -0
  123. relationalai/early_access/dsl/serialize/exporter.py +163 -0
  124. relationalai/early_access/dsl/snow/api.py +105 -0
  125. relationalai/early_access/dsl/snow/common.py +76 -0
  126. relationalai/early_access/dsl/state_mgmt/__init__.py +129 -0
  127. relationalai/early_access/dsl/state_mgmt/state_charts.py +125 -0
  128. relationalai/early_access/dsl/state_mgmt/transitions.py +130 -0
  129. relationalai/early_access/dsl/types/__init__.py +40 -0
  130. relationalai/early_access/dsl/types/concepts.py +12 -0
  131. relationalai/early_access/dsl/types/entities.py +135 -0
  132. relationalai/early_access/dsl/types/values.py +17 -0
  133. relationalai/early_access/dsl/utils.py +102 -0
  134. relationalai/early_access/graphs/__init__.py +13 -0
  135. relationalai/early_access/lqp/__init__.py +12 -0
  136. relationalai/early_access/lqp/compiler/__init__.py +12 -0
  137. relationalai/early_access/lqp/constructors/__init__.py +18 -0
  138. relationalai/early_access/lqp/executor/__init__.py +12 -0
  139. relationalai/early_access/lqp/ir/__init__.py +12 -0
  140. relationalai/early_access/lqp/passes/__init__.py +12 -0
  141. relationalai/early_access/lqp/pragmas/__init__.py +12 -0
  142. relationalai/early_access/lqp/primitives/__init__.py +12 -0
  143. relationalai/early_access/lqp/types/__init__.py +12 -0
  144. relationalai/early_access/lqp/utils/__init__.py +12 -0
  145. relationalai/early_access/lqp/validators/__init__.py +12 -0
  146. relationalai/early_access/metamodel/__init__.py +58 -0
  147. relationalai/early_access/metamodel/builtins/__init__.py +12 -0
  148. relationalai/early_access/metamodel/compiler/__init__.py +12 -0
  149. relationalai/early_access/metamodel/dependency/__init__.py +12 -0
  150. relationalai/early_access/metamodel/factory/__init__.py +17 -0
  151. relationalai/early_access/metamodel/helpers/__init__.py +12 -0
  152. relationalai/early_access/metamodel/ir/__init__.py +14 -0
  153. relationalai/early_access/metamodel/rewrite/__init__.py +7 -0
  154. relationalai/early_access/metamodel/typer/__init__.py +3 -0
  155. relationalai/early_access/metamodel/typer/typer/__init__.py +12 -0
  156. relationalai/early_access/metamodel/types/__init__.py +15 -0
  157. relationalai/early_access/metamodel/util/__init__.py +15 -0
  158. relationalai/early_access/metamodel/visitor/__init__.py +12 -0
  159. relationalai/early_access/rel/__init__.py +12 -0
  160. relationalai/early_access/rel/executor/__init__.py +12 -0
  161. relationalai/early_access/rel/rel_utils/__init__.py +12 -0
  162. relationalai/early_access/rel/rewrite/__init__.py +7 -0
  163. relationalai/early_access/solvers/__init__.py +19 -0
  164. relationalai/early_access/sql/__init__.py +11 -0
  165. relationalai/early_access/sql/executor/__init__.py +3 -0
  166. relationalai/early_access/sql/rewrite/__init__.py +3 -0
  167. relationalai/early_access/tests/logging/__init__.py +12 -0
  168. relationalai/early_access/tests/test_snapshot_base/__init__.py +12 -0
  169. relationalai/early_access/tests/utils/__init__.py +12 -0
  170. relationalai/environments/__init__.py +35 -0
  171. relationalai/environments/base.py +381 -0
  172. relationalai/environments/colab.py +14 -0
  173. relationalai/environments/generic.py +71 -0
  174. relationalai/environments/ipython.py +68 -0
  175. relationalai/environments/jupyter.py +9 -0
  176. relationalai/environments/snowbook.py +169 -0
  177. relationalai/errors.py +2496 -0
  178. relationalai/experimental/SF.py +38 -0
  179. relationalai/experimental/inspect.py +47 -0
  180. relationalai/experimental/pathfinder/__init__.py +158 -0
  181. relationalai/experimental/pathfinder/api.py +160 -0
  182. relationalai/experimental/pathfinder/automaton.py +584 -0
  183. relationalai/experimental/pathfinder/bridge.py +226 -0
  184. relationalai/experimental/pathfinder/compiler.py +416 -0
  185. relationalai/experimental/pathfinder/datalog.py +214 -0
  186. relationalai/experimental/pathfinder/diagnostics.py +56 -0
  187. relationalai/experimental/pathfinder/filter.py +236 -0
  188. relationalai/experimental/pathfinder/glushkov.py +439 -0
  189. relationalai/experimental/pathfinder/options.py +265 -0
  190. relationalai/experimental/pathfinder/pathfinder-v0.7.0.rel +1951 -0
  191. relationalai/experimental/pathfinder/rpq.py +344 -0
  192. relationalai/experimental/pathfinder/transition.py +200 -0
  193. relationalai/experimental/pathfinder/utils.py +26 -0
  194. relationalai/experimental/paths/README.md +107 -0
  195. relationalai/experimental/paths/api.py +143 -0
  196. relationalai/experimental/paths/benchmarks/grid_graph.py +37 -0
  197. relationalai/experimental/paths/code_organization.md +2 -0
  198. relationalai/experimental/paths/examples/Movies.ipynb +16328 -0
  199. relationalai/experimental/paths/examples/basic_example.py +40 -0
  200. relationalai/experimental/paths/examples/minimal_engine_warmup.py +3 -0
  201. relationalai/experimental/paths/examples/movie_example.py +77 -0
  202. relationalai/experimental/paths/examples/movies_data/actedin.csv +193 -0
  203. relationalai/experimental/paths/examples/movies_data/directed.csv +45 -0
  204. relationalai/experimental/paths/examples/movies_data/follows.csv +7 -0
  205. relationalai/experimental/paths/examples/movies_data/movies.csv +39 -0
  206. relationalai/experimental/paths/examples/movies_data/person.csv +134 -0
  207. relationalai/experimental/paths/examples/movies_data/produced.csv +16 -0
  208. relationalai/experimental/paths/examples/movies_data/ratings.csv +10 -0
  209. relationalai/experimental/paths/examples/movies_data/wrote.csv +11 -0
  210. relationalai/experimental/paths/examples/paths_benchmark.py +115 -0
  211. relationalai/experimental/paths/examples/paths_example.py +116 -0
  212. relationalai/experimental/paths/examples/pattern_to_automaton.py +28 -0
  213. relationalai/experimental/paths/find_paths_via_automaton.py +85 -0
  214. relationalai/experimental/paths/graph.py +185 -0
  215. relationalai/experimental/paths/path_algorithms/find_paths.py +280 -0
  216. relationalai/experimental/paths/path_algorithms/one_sided_ball_repetition.py +26 -0
  217. relationalai/experimental/paths/path_algorithms/one_sided_ball_upto.py +111 -0
  218. relationalai/experimental/paths/path_algorithms/single.py +59 -0
  219. relationalai/experimental/paths/path_algorithms/two_sided_balls_repetition.py +39 -0
  220. relationalai/experimental/paths/path_algorithms/two_sided_balls_upto.py +103 -0
  221. relationalai/experimental/paths/path_algorithms/usp-old.py +130 -0
  222. relationalai/experimental/paths/path_algorithms/usp-tuple.py +183 -0
  223. relationalai/experimental/paths/path_algorithms/usp.py +150 -0
  224. relationalai/experimental/paths/product_graph.py +93 -0
  225. relationalai/experimental/paths/rpq/automaton.py +584 -0
  226. relationalai/experimental/paths/rpq/diagnostics.py +56 -0
  227. relationalai/experimental/paths/rpq/rpq.py +378 -0
  228. relationalai/experimental/paths/tests/tests_limit_sp_max_length.py +90 -0
  229. relationalai/experimental/paths/tests/tests_limit_sp_multiple.py +119 -0
  230. relationalai/experimental/paths/tests/tests_limit_sp_single.py +104 -0
  231. relationalai/experimental/paths/tests/tests_limit_walks_multiple.py +113 -0
  232. relationalai/experimental/paths/tests/tests_limit_walks_single.py +149 -0
  233. relationalai/experimental/paths/tests/tests_one_sided_ball_repetition_multiple.py +70 -0
  234. relationalai/experimental/paths/tests/tests_one_sided_ball_repetition_single.py +64 -0
  235. relationalai/experimental/paths/tests/tests_one_sided_ball_upto_multiple.py +115 -0
  236. relationalai/experimental/paths/tests/tests_one_sided_ball_upto_single.py +75 -0
  237. relationalai/experimental/paths/tests/tests_single_paths.py +152 -0
  238. relationalai/experimental/paths/tests/tests_single_walks.py +208 -0
  239. relationalai/experimental/paths/tests/tests_single_walks_undirected.py +297 -0
  240. relationalai/experimental/paths/tests/tests_two_sided_balls_repetition_multiple.py +107 -0
  241. relationalai/experimental/paths/tests/tests_two_sided_balls_repetition_single.py +76 -0
  242. relationalai/experimental/paths/tests/tests_two_sided_balls_upto_multiple.py +76 -0
  243. relationalai/experimental/paths/tests/tests_two_sided_balls_upto_single.py +110 -0
  244. relationalai/experimental/paths/tests/tests_usp_nsp_multiple.py +229 -0
  245. relationalai/experimental/paths/tests/tests_usp_nsp_single.py +108 -0
  246. relationalai/experimental/paths/tree_agg.py +168 -0
  247. relationalai/experimental/paths/utilities/iterators.py +27 -0
  248. relationalai/experimental/paths/utilities/prefix_sum.py +91 -0
  249. relationalai/experimental/solvers.py +1087 -0
  250. relationalai/loaders/csv.py +195 -0
  251. relationalai/loaders/loader.py +177 -0
  252. relationalai/loaders/types.py +23 -0
  253. relationalai/rel_emitter.py +373 -0
  254. relationalai/rel_utils.py +185 -0
  255. relationalai/semantics/__init__.py +22 -146
  256. relationalai/semantics/designs/query_builder/identify_by.md +106 -0
  257. relationalai/semantics/devtools/benchmark_lqp.py +535 -0
  258. relationalai/semantics/devtools/compilation_manager.py +294 -0
  259. relationalai/semantics/devtools/extract_lqp.py +110 -0
  260. relationalai/semantics/internal/internal.py +3785 -0
  261. relationalai/semantics/internal/snowflake.py +325 -0
  262. relationalai/semantics/lqp/README.md +34 -0
  263. relationalai/semantics/lqp/builtins.py +16 -0
  264. relationalai/semantics/lqp/compiler.py +22 -0
  265. relationalai/semantics/lqp/constructors.py +68 -0
  266. relationalai/semantics/lqp/executor.py +469 -0
  267. relationalai/semantics/lqp/intrinsics.py +24 -0
  268. relationalai/semantics/lqp/model2lqp.py +877 -0
  269. relationalai/semantics/lqp/passes.py +680 -0
  270. relationalai/semantics/lqp/primitives.py +252 -0
  271. relationalai/semantics/lqp/result_helpers.py +202 -0
  272. relationalai/semantics/lqp/rewrite/annotate_constraints.py +57 -0
  273. relationalai/semantics/lqp/rewrite/cdc.py +216 -0
  274. relationalai/semantics/lqp/rewrite/extract_common.py +338 -0
  275. relationalai/semantics/lqp/rewrite/extract_keys.py +512 -0
  276. relationalai/semantics/lqp/rewrite/function_annotations.py +114 -0
  277. relationalai/semantics/lqp/rewrite/functional_dependencies.py +314 -0
  278. relationalai/semantics/lqp/rewrite/quantify_vars.py +296 -0
  279. relationalai/semantics/lqp/rewrite/splinter.py +76 -0
  280. relationalai/semantics/lqp/types.py +101 -0
  281. relationalai/semantics/lqp/utils.py +160 -0
  282. relationalai/semantics/lqp/validators.py +57 -0
  283. relationalai/semantics/metamodel/__init__.py +40 -6
  284. relationalai/semantics/metamodel/builtins.py +771 -205
  285. relationalai/semantics/metamodel/compiler.py +133 -0
  286. relationalai/semantics/metamodel/dependency.py +862 -0
  287. relationalai/semantics/metamodel/executor.py +61 -0
  288. relationalai/semantics/metamodel/factory.py +287 -0
  289. relationalai/semantics/metamodel/helpers.py +361 -0
  290. relationalai/semantics/metamodel/rewrite/discharge_constraints.py +39 -0
  291. relationalai/semantics/metamodel/rewrite/dnf_union_splitter.py +210 -0
  292. relationalai/semantics/metamodel/rewrite/extract_nested_logicals.py +78 -0
  293. relationalai/semantics/metamodel/rewrite/flatten.py +554 -0
  294. relationalai/semantics/metamodel/rewrite/format_outputs.py +165 -0
  295. relationalai/semantics/metamodel/typer/checker.py +353 -0
  296. relationalai/semantics/metamodel/typer/typer.py +1399 -0
  297. relationalai/semantics/metamodel/util.py +506 -0
  298. relationalai/semantics/reasoners/__init__.py +10 -0
  299. relationalai/semantics/reasoners/graph/README.md +620 -0
  300. relationalai/semantics/reasoners/graph/__init__.py +37 -0
  301. relationalai/semantics/reasoners/graph/core.py +9019 -0
  302. relationalai/semantics/reasoners/graph/design/beyond_demand_transform.md +797 -0
  303. relationalai/semantics/reasoners/graph/tests/README.md +21 -0
  304. relationalai/semantics/reasoners/optimization/__init__.py +68 -0
  305. relationalai/semantics/reasoners/optimization/common.py +88 -0
  306. relationalai/semantics/reasoners/optimization/solvers_dev.py +568 -0
  307. relationalai/semantics/reasoners/optimization/solvers_pb.py +1414 -0
  308. relationalai/semantics/rel/builtins.py +40 -0
  309. relationalai/semantics/rel/compiler.py +989 -0
  310. relationalai/semantics/rel/executor.py +362 -0
  311. relationalai/semantics/rel/rel.py +482 -0
  312. relationalai/semantics/rel/rel_utils.py +276 -0
  313. relationalai/semantics/snowflake/__init__.py +3 -0
  314. relationalai/semantics/sql/compiler.py +2503 -0
  315. relationalai/semantics/sql/executor/duck_db.py +52 -0
  316. relationalai/semantics/sql/executor/result_helpers.py +64 -0
  317. relationalai/semantics/sql/executor/snowflake.py +149 -0
  318. relationalai/semantics/sql/rewrite/denormalize.py +222 -0
  319. relationalai/semantics/sql/rewrite/double_negation.py +49 -0
  320. relationalai/semantics/sql/rewrite/recursive_union.py +127 -0
  321. relationalai/semantics/sql/rewrite/sort_output_query.py +246 -0
  322. relationalai/semantics/sql/sql.py +504 -0
  323. relationalai/semantics/std/__init__.py +40 -60
  324. relationalai/semantics/std/constraints.py +43 -37
  325. relationalai/semantics/std/datetime.py +135 -246
  326. relationalai/semantics/std/decimals.py +52 -45
  327. relationalai/semantics/std/floats.py +5 -13
  328. relationalai/semantics/std/integers.py +11 -26
  329. relationalai/semantics/std/math.py +112 -183
  330. relationalai/semantics/std/pragmas.py +11 -0
  331. relationalai/semantics/std/re.py +62 -80
  332. relationalai/semantics/std/std.py +14 -0
  333. relationalai/semantics/std/strings.py +60 -117
  334. relationalai/semantics/tests/test_snapshot_abstract.py +143 -0
  335. relationalai/semantics/tests/test_snapshot_base.py +9 -0
  336. relationalai/semantics/tests/utils.py +46 -0
  337. relationalai/std/__init__.py +70 -0
  338. relationalai/tools/cli.py +2089 -0
  339. relationalai/tools/cli_controls.py +1826 -0
  340. relationalai/tools/cli_helpers.py +802 -0
  341. relationalai/tools/debugger.py +183 -289
  342. relationalai/tools/debugger_client.py +109 -0
  343. relationalai/tools/debugger_server.py +302 -0
  344. relationalai/tools/dev.py +685 -0
  345. relationalai/tools/notes +7 -0
  346. relationalai/tools/qb_debugger.py +425 -0
  347. relationalai/util/clean_up_databases.py +95 -0
  348. relationalai/util/format.py +106 -48
  349. relationalai/util/list_databases.py +9 -0
  350. relationalai/util/otel_configuration.py +26 -0
  351. relationalai/util/otel_handler.py +484 -0
  352. relationalai/util/snowflake_handler.py +88 -0
  353. relationalai/util/span_format_test.py +43 -0
  354. relationalai/util/span_tracker.py +207 -0
  355. relationalai/util/spans_file_handler.py +72 -0
  356. relationalai/util/tracing_handler.py +34 -0
  357. relationalai-0.13.2.dist-info/METADATA +74 -0
  358. relationalai-0.13.2.dist-info/RECORD +460 -0
  359. relationalai-0.13.2.dist-info/WHEEL +4 -0
  360. relationalai-0.13.2.dist-info/entry_points.txt +3 -0
  361. relationalai-0.13.2.dist-info/licenses/LICENSE +202 -0
  362. relationalai_test_util/__init__.py +4 -0
  363. relationalai_test_util/fixtures.py +233 -0
  364. relationalai_test_util/snapshot.py +252 -0
  365. relationalai_test_util/traceback.py +118 -0
  366. relationalai/config/__init__.py +0 -56
  367. relationalai/config/config.py +0 -289
  368. relationalai/config/config_fields.py +0 -86
  369. relationalai/config/connections/__init__.py +0 -46
  370. relationalai/config/connections/base.py +0 -23
  371. relationalai/config/connections/duckdb.py +0 -29
  372. relationalai/config/connections/snowflake.py +0 -243
  373. relationalai/config/external/__init__.py +0 -17
  374. relationalai/config/external/dbt_converter.py +0 -101
  375. relationalai/config/external/dbt_models.py +0 -93
  376. relationalai/config/external/snowflake_converter.py +0 -41
  377. relationalai/config/external/snowflake_models.py +0 -85
  378. relationalai/config/external/utils.py +0 -19
  379. relationalai/semantics/backends/lqp/annotations.py +0 -11
  380. relationalai/semantics/backends/sql/sql_compiler.py +0 -327
  381. relationalai/semantics/frontend/base.py +0 -1707
  382. relationalai/semantics/frontend/core.py +0 -179
  383. relationalai/semantics/frontend/front_compiler.py +0 -1313
  384. relationalai/semantics/frontend/pprint.py +0 -408
  385. relationalai/semantics/metamodel/metamodel.py +0 -437
  386. relationalai/semantics/metamodel/metamodel_analyzer.py +0 -519
  387. relationalai/semantics/metamodel/metamodel_compiler.py +0 -0
  388. relationalai/semantics/metamodel/pprint.py +0 -412
  389. relationalai/semantics/metamodel/rewriter.py +0 -266
  390. relationalai/semantics/metamodel/typer.py +0 -1378
  391. relationalai/semantics/std/aggregates.py +0 -149
  392. relationalai/semantics/std/common.py +0 -44
  393. relationalai/semantics/std/numbers.py +0 -86
  394. relationalai/shims/executor.py +0 -147
  395. relationalai/shims/helpers.py +0 -126
  396. relationalai/shims/hoister.py +0 -221
  397. relationalai/shims/mm2v0.py +0 -1290
  398. relationalai/tools/cli/__init__.py +0 -6
  399. relationalai/tools/cli/cli.py +0 -90
  400. relationalai/tools/cli/components/__init__.py +0 -5
  401. relationalai/tools/cli/components/progress_reader.py +0 -1524
  402. relationalai/tools/cli/components/utils.py +0 -58
  403. relationalai/tools/cli/config_template.py +0 -45
  404. relationalai/tools/cli/dev.py +0 -19
  405. relationalai/tools/typer_debugger.py +0 -93
  406. relationalai/util/dataclasses.py +0 -43
  407. relationalai/util/docutils.py +0 -40
  408. relationalai/util/error.py +0 -199
  409. relationalai/util/naming.py +0 -145
  410. relationalai/util/python.py +0 -35
  411. relationalai/util/runtime.py +0 -156
  412. relationalai/util/schema.py +0 -197
  413. relationalai/util/source.py +0 -185
  414. relationalai/util/structures.py +0 -163
  415. relationalai/util/tracing.py +0 -261
  416. relationalai-0.13.0.dev0.dist-info/METADATA +0 -46
  417. relationalai-0.13.0.dev0.dist-info/RECORD +0 -488
  418. relationalai-0.13.0.dev0.dist-info/WHEEL +0 -5
  419. relationalai-0.13.0.dev0.dist-info/entry_points.txt +0 -3
  420. relationalai-0.13.0.dev0.dist-info/top_level.txt +0 -2
  421. v0/relationalai/__init__.py +0 -216
  422. v0/relationalai/clients/__init__.py +0 -5
  423. v0/relationalai/clients/azure.py +0 -477
  424. v0/relationalai/clients/client.py +0 -912
  425. v0/relationalai/clients/config.py +0 -673
  426. v0/relationalai/clients/direct_access_client.py +0 -118
  427. v0/relationalai/clients/hash_util.py +0 -31
  428. v0/relationalai/clients/local.py +0 -571
  429. v0/relationalai/clients/profile_polling.py +0 -73
  430. v0/relationalai/clients/result_helpers.py +0 -420
  431. v0/relationalai/clients/snowflake.py +0 -3869
  432. v0/relationalai/clients/types.py +0 -113
  433. v0/relationalai/clients/use_index_poller.py +0 -980
  434. v0/relationalai/clients/util.py +0 -356
  435. v0/relationalai/debugging.py +0 -389
  436. v0/relationalai/dsl.py +0 -1749
  437. v0/relationalai/early_access/builder/__init__.py +0 -30
  438. v0/relationalai/early_access/builder/builder/__init__.py +0 -35
  439. v0/relationalai/early_access/builder/snowflake/__init__.py +0 -12
  440. v0/relationalai/early_access/builder/std/__init__.py +0 -25
  441. v0/relationalai/early_access/builder/std/decimals/__init__.py +0 -12
  442. v0/relationalai/early_access/builder/std/integers/__init__.py +0 -12
  443. v0/relationalai/early_access/builder/std/math/__init__.py +0 -12
  444. v0/relationalai/early_access/builder/std/strings/__init__.py +0 -14
  445. v0/relationalai/early_access/devtools/__init__.py +0 -12
  446. v0/relationalai/early_access/devtools/benchmark_lqp/__init__.py +0 -12
  447. v0/relationalai/early_access/devtools/extract_lqp/__init__.py +0 -12
  448. v0/relationalai/early_access/dsl/adapters/orm/adapter_qb.py +0 -427
  449. v0/relationalai/early_access/dsl/adapters/orm/parser.py +0 -636
  450. v0/relationalai/early_access/dsl/adapters/owl/adapter.py +0 -176
  451. v0/relationalai/early_access/dsl/adapters/owl/parser.py +0 -160
  452. v0/relationalai/early_access/dsl/bindings/common.py +0 -402
  453. v0/relationalai/early_access/dsl/bindings/csv.py +0 -170
  454. v0/relationalai/early_access/dsl/bindings/legacy/binding_models.py +0 -143
  455. v0/relationalai/early_access/dsl/bindings/snowflake.py +0 -64
  456. v0/relationalai/early_access/dsl/codegen/binder.py +0 -411
  457. v0/relationalai/early_access/dsl/codegen/common.py +0 -79
  458. v0/relationalai/early_access/dsl/codegen/helpers.py +0 -23
  459. v0/relationalai/early_access/dsl/codegen/relations.py +0 -700
  460. v0/relationalai/early_access/dsl/codegen/weaver.py +0 -417
  461. v0/relationalai/early_access/dsl/core/builders/__init__.py +0 -47
  462. v0/relationalai/early_access/dsl/core/builders/logic.py +0 -19
  463. v0/relationalai/early_access/dsl/core/builders/scalar_constraint.py +0 -11
  464. v0/relationalai/early_access/dsl/core/constraints/predicate/atomic.py +0 -455
  465. v0/relationalai/early_access/dsl/core/constraints/predicate/universal.py +0 -73
  466. v0/relationalai/early_access/dsl/core/constraints/scalar.py +0 -310
  467. v0/relationalai/early_access/dsl/core/context.py +0 -13
  468. v0/relationalai/early_access/dsl/core/cset.py +0 -132
  469. v0/relationalai/early_access/dsl/core/exprs/__init__.py +0 -116
  470. v0/relationalai/early_access/dsl/core/exprs/relational.py +0 -18
  471. v0/relationalai/early_access/dsl/core/exprs/scalar.py +0 -412
  472. v0/relationalai/early_access/dsl/core/instances.py +0 -44
  473. v0/relationalai/early_access/dsl/core/logic/__init__.py +0 -193
  474. v0/relationalai/early_access/dsl/core/logic/aggregation.py +0 -98
  475. v0/relationalai/early_access/dsl/core/logic/exists.py +0 -223
  476. v0/relationalai/early_access/dsl/core/logic/helper.py +0 -163
  477. v0/relationalai/early_access/dsl/core/namespaces.py +0 -32
  478. v0/relationalai/early_access/dsl/core/relations.py +0 -276
  479. v0/relationalai/early_access/dsl/core/rules.py +0 -112
  480. v0/relationalai/early_access/dsl/core/std/__init__.py +0 -45
  481. v0/relationalai/early_access/dsl/core/temporal/recall.py +0 -6
  482. v0/relationalai/early_access/dsl/core/types/__init__.py +0 -270
  483. v0/relationalai/early_access/dsl/core/types/concepts.py +0 -128
  484. v0/relationalai/early_access/dsl/core/types/constrained/__init__.py +0 -267
  485. v0/relationalai/early_access/dsl/core/types/constrained/nominal.py +0 -143
  486. v0/relationalai/early_access/dsl/core/types/constrained/subtype.py +0 -124
  487. v0/relationalai/early_access/dsl/core/types/standard.py +0 -92
  488. v0/relationalai/early_access/dsl/core/types/unconstrained.py +0 -50
  489. v0/relationalai/early_access/dsl/core/types/variables.py +0 -203
  490. v0/relationalai/early_access/dsl/ir/compiler.py +0 -318
  491. v0/relationalai/early_access/dsl/ir/executor.py +0 -260
  492. v0/relationalai/early_access/dsl/ontologies/constraints.py +0 -88
  493. v0/relationalai/early_access/dsl/ontologies/export.py +0 -30
  494. v0/relationalai/early_access/dsl/ontologies/models.py +0 -453
  495. v0/relationalai/early_access/dsl/ontologies/python_printer.py +0 -303
  496. v0/relationalai/early_access/dsl/ontologies/readings.py +0 -60
  497. v0/relationalai/early_access/dsl/ontologies/relationships.py +0 -322
  498. v0/relationalai/early_access/dsl/ontologies/roles.py +0 -87
  499. v0/relationalai/early_access/dsl/ontologies/subtyping.py +0 -55
  500. v0/relationalai/early_access/dsl/orm/constraints.py +0 -438
  501. v0/relationalai/early_access/dsl/orm/measures/dimensions.py +0 -200
  502. v0/relationalai/early_access/dsl/orm/measures/initializer.py +0 -16
  503. v0/relationalai/early_access/dsl/orm/measures/measure_rules.py +0 -275
  504. v0/relationalai/early_access/dsl/orm/measures/measures.py +0 -299
  505. v0/relationalai/early_access/dsl/orm/measures/role_exprs.py +0 -268
  506. v0/relationalai/early_access/dsl/orm/models.py +0 -256
  507. v0/relationalai/early_access/dsl/orm/object_oriented_printer.py +0 -344
  508. v0/relationalai/early_access/dsl/orm/printer.py +0 -469
  509. v0/relationalai/early_access/dsl/orm/reasoners.py +0 -480
  510. v0/relationalai/early_access/dsl/orm/relations.py +0 -19
  511. v0/relationalai/early_access/dsl/orm/relationships.py +0 -251
  512. v0/relationalai/early_access/dsl/orm/types.py +0 -42
  513. v0/relationalai/early_access/dsl/orm/utils.py +0 -79
  514. v0/relationalai/early_access/dsl/orm/verb.py +0 -204
  515. v0/relationalai/early_access/dsl/physical_metadata/tables.py +0 -133
  516. v0/relationalai/early_access/dsl/relations.py +0 -170
  517. v0/relationalai/early_access/dsl/rulesets.py +0 -69
  518. v0/relationalai/early_access/dsl/schemas/__init__.py +0 -450
  519. v0/relationalai/early_access/dsl/schemas/builder.py +0 -48
  520. v0/relationalai/early_access/dsl/schemas/comp_names.py +0 -51
  521. v0/relationalai/early_access/dsl/schemas/components.py +0 -203
  522. v0/relationalai/early_access/dsl/schemas/contexts.py +0 -156
  523. v0/relationalai/early_access/dsl/schemas/exprs.py +0 -89
  524. v0/relationalai/early_access/dsl/schemas/fragments.py +0 -464
  525. v0/relationalai/early_access/dsl/serialization.py +0 -79
  526. v0/relationalai/early_access/dsl/serialize/exporter.py +0 -163
  527. v0/relationalai/early_access/dsl/snow/api.py +0 -104
  528. v0/relationalai/early_access/dsl/snow/common.py +0 -76
  529. v0/relationalai/early_access/dsl/state_mgmt/__init__.py +0 -129
  530. v0/relationalai/early_access/dsl/state_mgmt/state_charts.py +0 -125
  531. v0/relationalai/early_access/dsl/state_mgmt/transitions.py +0 -130
  532. v0/relationalai/early_access/dsl/types/__init__.py +0 -40
  533. v0/relationalai/early_access/dsl/types/concepts.py +0 -12
  534. v0/relationalai/early_access/dsl/types/entities.py +0 -135
  535. v0/relationalai/early_access/dsl/types/values.py +0 -17
  536. v0/relationalai/early_access/dsl/utils.py +0 -102
  537. v0/relationalai/early_access/graphs/__init__.py +0 -13
  538. v0/relationalai/early_access/lqp/__init__.py +0 -12
  539. v0/relationalai/early_access/lqp/compiler/__init__.py +0 -12
  540. v0/relationalai/early_access/lqp/constructors/__init__.py +0 -18
  541. v0/relationalai/early_access/lqp/executor/__init__.py +0 -12
  542. v0/relationalai/early_access/lqp/ir/__init__.py +0 -12
  543. v0/relationalai/early_access/lqp/passes/__init__.py +0 -12
  544. v0/relationalai/early_access/lqp/pragmas/__init__.py +0 -12
  545. v0/relationalai/early_access/lqp/primitives/__init__.py +0 -12
  546. v0/relationalai/early_access/lqp/types/__init__.py +0 -12
  547. v0/relationalai/early_access/lqp/utils/__init__.py +0 -12
  548. v0/relationalai/early_access/lqp/validators/__init__.py +0 -12
  549. v0/relationalai/early_access/metamodel/__init__.py +0 -58
  550. v0/relationalai/early_access/metamodel/builtins/__init__.py +0 -12
  551. v0/relationalai/early_access/metamodel/compiler/__init__.py +0 -12
  552. v0/relationalai/early_access/metamodel/dependency/__init__.py +0 -12
  553. v0/relationalai/early_access/metamodel/factory/__init__.py +0 -17
  554. v0/relationalai/early_access/metamodel/helpers/__init__.py +0 -12
  555. v0/relationalai/early_access/metamodel/ir/__init__.py +0 -14
  556. v0/relationalai/early_access/metamodel/rewrite/__init__.py +0 -7
  557. v0/relationalai/early_access/metamodel/typer/__init__.py +0 -3
  558. v0/relationalai/early_access/metamodel/typer/typer/__init__.py +0 -12
  559. v0/relationalai/early_access/metamodel/types/__init__.py +0 -15
  560. v0/relationalai/early_access/metamodel/util/__init__.py +0 -15
  561. v0/relationalai/early_access/metamodel/visitor/__init__.py +0 -12
  562. v0/relationalai/early_access/rel/__init__.py +0 -12
  563. v0/relationalai/early_access/rel/executor/__init__.py +0 -12
  564. v0/relationalai/early_access/rel/rel_utils/__init__.py +0 -12
  565. v0/relationalai/early_access/rel/rewrite/__init__.py +0 -7
  566. v0/relationalai/early_access/solvers/__init__.py +0 -19
  567. v0/relationalai/early_access/sql/__init__.py +0 -11
  568. v0/relationalai/early_access/sql/executor/__init__.py +0 -3
  569. v0/relationalai/early_access/sql/rewrite/__init__.py +0 -3
  570. v0/relationalai/early_access/tests/logging/__init__.py +0 -12
  571. v0/relationalai/early_access/tests/test_snapshot_base/__init__.py +0 -12
  572. v0/relationalai/early_access/tests/utils/__init__.py +0 -12
  573. v0/relationalai/environments/__init__.py +0 -35
  574. v0/relationalai/environments/base.py +0 -381
  575. v0/relationalai/environments/colab.py +0 -14
  576. v0/relationalai/environments/generic.py +0 -71
  577. v0/relationalai/environments/ipython.py +0 -68
  578. v0/relationalai/environments/jupyter.py +0 -9
  579. v0/relationalai/environments/snowbook.py +0 -169
  580. v0/relationalai/errors.py +0 -2455
  581. v0/relationalai/experimental/SF.py +0 -38
  582. v0/relationalai/experimental/inspect.py +0 -47
  583. v0/relationalai/experimental/pathfinder/__init__.py +0 -158
  584. v0/relationalai/experimental/pathfinder/api.py +0 -160
  585. v0/relationalai/experimental/pathfinder/automaton.py +0 -584
  586. v0/relationalai/experimental/pathfinder/bridge.py +0 -226
  587. v0/relationalai/experimental/pathfinder/compiler.py +0 -416
  588. v0/relationalai/experimental/pathfinder/datalog.py +0 -214
  589. v0/relationalai/experimental/pathfinder/diagnostics.py +0 -56
  590. v0/relationalai/experimental/pathfinder/filter.py +0 -236
  591. v0/relationalai/experimental/pathfinder/glushkov.py +0 -439
  592. v0/relationalai/experimental/pathfinder/options.py +0 -265
  593. v0/relationalai/experimental/pathfinder/rpq.py +0 -344
  594. v0/relationalai/experimental/pathfinder/transition.py +0 -200
  595. v0/relationalai/experimental/pathfinder/utils.py +0 -26
  596. v0/relationalai/experimental/paths/api.py +0 -143
  597. v0/relationalai/experimental/paths/benchmarks/grid_graph.py +0 -37
  598. v0/relationalai/experimental/paths/examples/basic_example.py +0 -40
  599. v0/relationalai/experimental/paths/examples/minimal_engine_warmup.py +0 -3
  600. v0/relationalai/experimental/paths/examples/movie_example.py +0 -77
  601. v0/relationalai/experimental/paths/examples/paths_benchmark.py +0 -115
  602. v0/relationalai/experimental/paths/examples/paths_example.py +0 -116
  603. v0/relationalai/experimental/paths/examples/pattern_to_automaton.py +0 -28
  604. v0/relationalai/experimental/paths/find_paths_via_automaton.py +0 -85
  605. v0/relationalai/experimental/paths/graph.py +0 -185
  606. v0/relationalai/experimental/paths/path_algorithms/find_paths.py +0 -280
  607. v0/relationalai/experimental/paths/path_algorithms/one_sided_ball_repetition.py +0 -26
  608. v0/relationalai/experimental/paths/path_algorithms/one_sided_ball_upto.py +0 -111
  609. v0/relationalai/experimental/paths/path_algorithms/single.py +0 -59
  610. v0/relationalai/experimental/paths/path_algorithms/two_sided_balls_repetition.py +0 -39
  611. v0/relationalai/experimental/paths/path_algorithms/two_sided_balls_upto.py +0 -103
  612. v0/relationalai/experimental/paths/path_algorithms/usp-old.py +0 -130
  613. v0/relationalai/experimental/paths/path_algorithms/usp-tuple.py +0 -183
  614. v0/relationalai/experimental/paths/path_algorithms/usp.py +0 -150
  615. v0/relationalai/experimental/paths/product_graph.py +0 -93
  616. v0/relationalai/experimental/paths/rpq/automaton.py +0 -584
  617. v0/relationalai/experimental/paths/rpq/diagnostics.py +0 -56
  618. v0/relationalai/experimental/paths/rpq/rpq.py +0 -378
  619. v0/relationalai/experimental/paths/tests/tests_limit_sp_max_length.py +0 -90
  620. v0/relationalai/experimental/paths/tests/tests_limit_sp_multiple.py +0 -119
  621. v0/relationalai/experimental/paths/tests/tests_limit_sp_single.py +0 -104
  622. v0/relationalai/experimental/paths/tests/tests_limit_walks_multiple.py +0 -113
  623. v0/relationalai/experimental/paths/tests/tests_limit_walks_single.py +0 -149
  624. v0/relationalai/experimental/paths/tests/tests_one_sided_ball_repetition_multiple.py +0 -70
  625. v0/relationalai/experimental/paths/tests/tests_one_sided_ball_repetition_single.py +0 -64
  626. v0/relationalai/experimental/paths/tests/tests_one_sided_ball_upto_multiple.py +0 -115
  627. v0/relationalai/experimental/paths/tests/tests_one_sided_ball_upto_single.py +0 -75
  628. v0/relationalai/experimental/paths/tests/tests_single_paths.py +0 -152
  629. v0/relationalai/experimental/paths/tests/tests_single_walks.py +0 -208
  630. v0/relationalai/experimental/paths/tests/tests_single_walks_undirected.py +0 -297
  631. v0/relationalai/experimental/paths/tests/tests_two_sided_balls_repetition_multiple.py +0 -107
  632. v0/relationalai/experimental/paths/tests/tests_two_sided_balls_repetition_single.py +0 -76
  633. v0/relationalai/experimental/paths/tests/tests_two_sided_balls_upto_multiple.py +0 -76
  634. v0/relationalai/experimental/paths/tests/tests_two_sided_balls_upto_single.py +0 -110
  635. v0/relationalai/experimental/paths/tests/tests_usp_nsp_multiple.py +0 -229
  636. v0/relationalai/experimental/paths/tests/tests_usp_nsp_single.py +0 -108
  637. v0/relationalai/experimental/paths/tree_agg.py +0 -168
  638. v0/relationalai/experimental/paths/utilities/iterators.py +0 -27
  639. v0/relationalai/experimental/paths/utilities/prefix_sum.py +0 -91
  640. v0/relationalai/experimental/solvers.py +0 -1087
  641. v0/relationalai/loaders/csv.py +0 -195
  642. v0/relationalai/loaders/loader.py +0 -177
  643. v0/relationalai/loaders/types.py +0 -23
  644. v0/relationalai/rel_emitter.py +0 -373
  645. v0/relationalai/rel_utils.py +0 -185
  646. v0/relationalai/semantics/__init__.py +0 -29
  647. v0/relationalai/semantics/devtools/benchmark_lqp.py +0 -536
  648. v0/relationalai/semantics/devtools/compilation_manager.py +0 -294
  649. v0/relationalai/semantics/devtools/extract_lqp.py +0 -110
  650. v0/relationalai/semantics/internal/internal.py +0 -3785
  651. v0/relationalai/semantics/internal/snowflake.py +0 -324
  652. v0/relationalai/semantics/lqp/builtins.py +0 -16
  653. v0/relationalai/semantics/lqp/compiler.py +0 -22
  654. v0/relationalai/semantics/lqp/constructors.py +0 -68
  655. v0/relationalai/semantics/lqp/executor.py +0 -469
  656. v0/relationalai/semantics/lqp/intrinsics.py +0 -24
  657. v0/relationalai/semantics/lqp/model2lqp.py +0 -839
  658. v0/relationalai/semantics/lqp/passes.py +0 -680
  659. v0/relationalai/semantics/lqp/primitives.py +0 -252
  660. v0/relationalai/semantics/lqp/result_helpers.py +0 -202
  661. v0/relationalai/semantics/lqp/rewrite/annotate_constraints.py +0 -57
  662. v0/relationalai/semantics/lqp/rewrite/cdc.py +0 -216
  663. v0/relationalai/semantics/lqp/rewrite/extract_common.py +0 -338
  664. v0/relationalai/semantics/lqp/rewrite/extract_keys.py +0 -449
  665. v0/relationalai/semantics/lqp/rewrite/function_annotations.py +0 -114
  666. v0/relationalai/semantics/lqp/rewrite/functional_dependencies.py +0 -314
  667. v0/relationalai/semantics/lqp/rewrite/quantify_vars.py +0 -296
  668. v0/relationalai/semantics/lqp/rewrite/splinter.py +0 -76
  669. v0/relationalai/semantics/lqp/types.py +0 -101
  670. v0/relationalai/semantics/lqp/utils.py +0 -160
  671. v0/relationalai/semantics/lqp/validators.py +0 -57
  672. v0/relationalai/semantics/metamodel/__init__.py +0 -40
  673. v0/relationalai/semantics/metamodel/builtins.py +0 -774
  674. v0/relationalai/semantics/metamodel/compiler.py +0 -133
  675. v0/relationalai/semantics/metamodel/dependency.py +0 -862
  676. v0/relationalai/semantics/metamodel/executor.py +0 -61
  677. v0/relationalai/semantics/metamodel/factory.py +0 -287
  678. v0/relationalai/semantics/metamodel/helpers.py +0 -361
  679. v0/relationalai/semantics/metamodel/rewrite/discharge_constraints.py +0 -39
  680. v0/relationalai/semantics/metamodel/rewrite/dnf_union_splitter.py +0 -210
  681. v0/relationalai/semantics/metamodel/rewrite/extract_nested_logicals.py +0 -78
  682. v0/relationalai/semantics/metamodel/rewrite/flatten.py +0 -549
  683. v0/relationalai/semantics/metamodel/rewrite/format_outputs.py +0 -165
  684. v0/relationalai/semantics/metamodel/typer/checker.py +0 -353
  685. v0/relationalai/semantics/metamodel/typer/typer.py +0 -1395
  686. v0/relationalai/semantics/metamodel/util.py +0 -505
  687. v0/relationalai/semantics/reasoners/__init__.py +0 -10
  688. v0/relationalai/semantics/reasoners/graph/__init__.py +0 -37
  689. v0/relationalai/semantics/reasoners/graph/core.py +0 -9020
  690. v0/relationalai/semantics/reasoners/optimization/__init__.py +0 -68
  691. v0/relationalai/semantics/reasoners/optimization/common.py +0 -88
  692. v0/relationalai/semantics/reasoners/optimization/solvers_dev.py +0 -568
  693. v0/relationalai/semantics/reasoners/optimization/solvers_pb.py +0 -1163
  694. v0/relationalai/semantics/rel/builtins.py +0 -40
  695. v0/relationalai/semantics/rel/compiler.py +0 -989
  696. v0/relationalai/semantics/rel/executor.py +0 -359
  697. v0/relationalai/semantics/rel/rel.py +0 -482
  698. v0/relationalai/semantics/rel/rel_utils.py +0 -276
  699. v0/relationalai/semantics/snowflake/__init__.py +0 -3
  700. v0/relationalai/semantics/sql/compiler.py +0 -2503
  701. v0/relationalai/semantics/sql/executor/duck_db.py +0 -52
  702. v0/relationalai/semantics/sql/executor/result_helpers.py +0 -64
  703. v0/relationalai/semantics/sql/executor/snowflake.py +0 -145
  704. v0/relationalai/semantics/sql/rewrite/denormalize.py +0 -222
  705. v0/relationalai/semantics/sql/rewrite/double_negation.py +0 -49
  706. v0/relationalai/semantics/sql/rewrite/recursive_union.py +0 -127
  707. v0/relationalai/semantics/sql/rewrite/sort_output_query.py +0 -246
  708. v0/relationalai/semantics/sql/sql.py +0 -504
  709. v0/relationalai/semantics/std/__init__.py +0 -54
  710. v0/relationalai/semantics/std/constraints.py +0 -43
  711. v0/relationalai/semantics/std/datetime.py +0 -363
  712. v0/relationalai/semantics/std/decimals.py +0 -62
  713. v0/relationalai/semantics/std/floats.py +0 -7
  714. v0/relationalai/semantics/std/integers.py +0 -22
  715. v0/relationalai/semantics/std/math.py +0 -141
  716. v0/relationalai/semantics/std/pragmas.py +0 -11
  717. v0/relationalai/semantics/std/re.py +0 -83
  718. v0/relationalai/semantics/std/std.py +0 -14
  719. v0/relationalai/semantics/std/strings.py +0 -63
  720. v0/relationalai/semantics/tests/__init__.py +0 -0
  721. v0/relationalai/semantics/tests/test_snapshot_abstract.py +0 -143
  722. v0/relationalai/semantics/tests/test_snapshot_base.py +0 -9
  723. v0/relationalai/semantics/tests/utils.py +0 -46
  724. v0/relationalai/std/__init__.py +0 -70
  725. v0/relationalai/tools/__init__.py +0 -0
  726. v0/relationalai/tools/cli.py +0 -1940
  727. v0/relationalai/tools/cli_controls.py +0 -1826
  728. v0/relationalai/tools/cli_helpers.py +0 -390
  729. v0/relationalai/tools/debugger.py +0 -183
  730. v0/relationalai/tools/debugger_client.py +0 -109
  731. v0/relationalai/tools/debugger_server.py +0 -302
  732. v0/relationalai/tools/dev.py +0 -685
  733. v0/relationalai/tools/qb_debugger.py +0 -425
  734. v0/relationalai/util/clean_up_databases.py +0 -95
  735. v0/relationalai/util/format.py +0 -123
  736. v0/relationalai/util/list_databases.py +0 -9
  737. v0/relationalai/util/otel_configuration.py +0 -25
  738. v0/relationalai/util/otel_handler.py +0 -484
  739. v0/relationalai/util/snowflake_handler.py +0 -88
  740. v0/relationalai/util/span_format_test.py +0 -43
  741. v0/relationalai/util/span_tracker.py +0 -207
  742. v0/relationalai/util/spans_file_handler.py +0 -72
  743. v0/relationalai/util/tracing_handler.py +0 -34
  744. /relationalai/{semantics/frontend → analysis}/__init__.py +0 -0
  745. {v0/relationalai → relationalai}/analysis/mechanistic.py +0 -0
  746. {v0/relationalai → relationalai}/analysis/whynot.py +0 -0
  747. /relationalai/{shims → auth}/__init__.py +0 -0
  748. {v0/relationalai → relationalai}/auth/jwt_generator.py +0 -0
  749. {v0/relationalai → relationalai}/auth/oauth_callback_server.py +0 -0
  750. {v0/relationalai → relationalai}/auth/token_handler.py +0 -0
  751. {v0/relationalai → relationalai}/auth/util.py +0 -0
  752. {v0/relationalai/clients → relationalai/clients/resources/snowflake}/cache_store.py +0 -0
  753. {v0/relationalai → relationalai}/compiler.py +0 -0
  754. {v0/relationalai → relationalai}/dependencies.py +0 -0
  755. {v0/relationalai → relationalai}/docutils.py +0 -0
  756. {v0/relationalai/analysis → relationalai/early_access}/__init__.py +0 -0
  757. {v0/relationalai → relationalai}/early_access/dsl/__init__.py +0 -0
  758. {v0/relationalai/auth → relationalai/early_access/dsl/adapters}/__init__.py +0 -0
  759. {v0/relationalai/early_access → relationalai/early_access/dsl/adapters/orm}/__init__.py +0 -0
  760. {v0/relationalai → relationalai}/early_access/dsl/adapters/orm/model.py +0 -0
  761. {v0/relationalai/early_access/dsl/adapters → relationalai/early_access/dsl/adapters/owl}/__init__.py +0 -0
  762. {v0/relationalai → relationalai}/early_access/dsl/adapters/owl/model.py +0 -0
  763. {v0/relationalai/early_access/dsl/adapters/orm → relationalai/early_access/dsl/bindings}/__init__.py +0 -0
  764. {v0/relationalai/early_access/dsl/adapters/owl → relationalai/early_access/dsl/bindings/legacy}/__init__.py +0 -0
  765. {v0/relationalai/early_access/dsl/bindings → relationalai/early_access/dsl/codegen}/__init__.py +0 -0
  766. {v0/relationalai → relationalai}/early_access/dsl/constants.py +0 -0
  767. {v0/relationalai → relationalai}/early_access/dsl/core/__init__.py +0 -0
  768. {v0/relationalai → relationalai}/early_access/dsl/core/constraints/__init__.py +0 -0
  769. {v0/relationalai → relationalai}/early_access/dsl/core/constraints/predicate/__init__.py +0 -0
  770. {v0/relationalai → relationalai}/early_access/dsl/core/stack.py +0 -0
  771. {v0/relationalai/early_access/dsl/bindings/legacy → relationalai/early_access/dsl/core/temporal}/__init__.py +0 -0
  772. {v0/relationalai → relationalai}/early_access/dsl/core/utils.py +0 -0
  773. {v0/relationalai/early_access/dsl/codegen → relationalai/early_access/dsl/ir}/__init__.py +0 -0
  774. {v0/relationalai/early_access/dsl/core/temporal → relationalai/early_access/dsl/ontologies}/__init__.py +0 -0
  775. {v0/relationalai → relationalai}/early_access/dsl/ontologies/raw_source.py +0 -0
  776. {v0/relationalai/early_access/dsl/ir → relationalai/early_access/dsl/orm}/__init__.py +0 -0
  777. {v0/relationalai/early_access/dsl/ontologies → relationalai/early_access/dsl/orm/measures}/__init__.py +0 -0
  778. {v0/relationalai → relationalai}/early_access/dsl/orm/reasoner_errors.py +0 -0
  779. {v0/relationalai/early_access/dsl/orm → relationalai/early_access/dsl/physical_metadata}/__init__.py +0 -0
  780. {v0/relationalai/early_access/dsl/orm/measures → relationalai/early_access/dsl/serialize}/__init__.py +0 -0
  781. {v0/relationalai → relationalai}/early_access/dsl/serialize/binding_model.py +0 -0
  782. {v0/relationalai → relationalai}/early_access/dsl/serialize/model.py +0 -0
  783. {v0/relationalai/early_access/dsl/physical_metadata → relationalai/early_access/dsl/snow}/__init__.py +0 -0
  784. {v0/relationalai → relationalai}/early_access/tests/__init__.py +0 -0
  785. {v0/relationalai → relationalai}/environments/ci.py +0 -0
  786. {v0/relationalai → relationalai}/environments/hex.py +0 -0
  787. {v0/relationalai → relationalai}/environments/terminal.py +0 -0
  788. {v0/relationalai → relationalai}/experimental/__init__.py +0 -0
  789. {v0/relationalai → relationalai}/experimental/graphs.py +0 -0
  790. {v0/relationalai → relationalai}/experimental/paths/__init__.py +0 -0
  791. {v0/relationalai → relationalai}/experimental/paths/benchmarks/__init__.py +0 -0
  792. {v0/relationalai → relationalai}/experimental/paths/path_algorithms/__init__.py +0 -0
  793. {v0/relationalai → relationalai}/experimental/paths/rpq/__init__.py +0 -0
  794. {v0/relationalai → relationalai}/experimental/paths/rpq/filter.py +0 -0
  795. {v0/relationalai → relationalai}/experimental/paths/rpq/glushkov.py +0 -0
  796. {v0/relationalai → relationalai}/experimental/paths/rpq/transition.py +0 -0
  797. {v0/relationalai → relationalai}/experimental/paths/utilities/__init__.py +0 -0
  798. {v0/relationalai → relationalai}/experimental/paths/utilities/utilities.py +0 -0
  799. {v0/relationalai/early_access/dsl/serialize → relationalai/loaders}/__init__.py +0 -0
  800. {v0/relationalai → relationalai}/metagen.py +0 -0
  801. {v0/relationalai → relationalai}/metamodel.py +0 -0
  802. {v0/relationalai → relationalai}/rel.py +0 -0
  803. {v0/relationalai → relationalai}/semantics/devtools/__init__.py +0 -0
  804. {v0/relationalai → relationalai}/semantics/internal/__init__.py +0 -0
  805. {v0/relationalai → relationalai}/semantics/internal/annotations.py +0 -0
  806. {v0/relationalai → relationalai}/semantics/lqp/__init__.py +0 -0
  807. {v0/relationalai → relationalai}/semantics/lqp/ir.py +0 -0
  808. {v0/relationalai → relationalai}/semantics/lqp/pragmas.py +0 -0
  809. {v0/relationalai → relationalai}/semantics/lqp/rewrite/__init__.py +0 -0
  810. {v0/relationalai → relationalai}/semantics/metamodel/dataflow.py +0 -0
  811. {v0/relationalai → relationalai}/semantics/metamodel/ir.py +0 -0
  812. {v0/relationalai → relationalai}/semantics/metamodel/rewrite/__init__.py +0 -0
  813. {v0/relationalai → relationalai}/semantics/metamodel/typer/__init__.py +0 -0
  814. {v0/relationalai → relationalai}/semantics/metamodel/types.py +0 -0
  815. {v0/relationalai → relationalai}/semantics/metamodel/visitor.py +0 -0
  816. {v0/relationalai → relationalai}/semantics/reasoners/experimental/__init__.py +0 -0
  817. {v0/relationalai → relationalai}/semantics/rel/__init__.py +0 -0
  818. {v0/relationalai → relationalai}/semantics/sql/__init__.py +0 -0
  819. {v0/relationalai → relationalai}/semantics/sql/executor/__init__.py +0 -0
  820. {v0/relationalai → relationalai}/semantics/sql/rewrite/__init__.py +0 -0
  821. {v0/relationalai/early_access/dsl/snow → relationalai/semantics/tests}/__init__.py +0 -0
  822. {v0/relationalai → relationalai}/semantics/tests/logging.py +0 -0
  823. {v0/relationalai → relationalai}/std/aggregates.py +0 -0
  824. {v0/relationalai → relationalai}/std/dates.py +0 -0
  825. {v0/relationalai → relationalai}/std/graphs.py +0 -0
  826. {v0/relationalai → relationalai}/std/inspect.py +0 -0
  827. {v0/relationalai → relationalai}/std/math.py +0 -0
  828. {v0/relationalai → relationalai}/std/re.py +0 -0
  829. {v0/relationalai → relationalai}/std/strings.py +0 -0
  830. {v0/relationalai/loaders → relationalai/tools}/__init__.py +0 -0
  831. {v0/relationalai → relationalai}/tools/cleanup_snapshots.py +0 -0
  832. {v0/relationalai → relationalai}/tools/constants.py +0 -0
  833. {v0/relationalai → relationalai}/tools/query_utils.py +0 -0
  834. {v0/relationalai → relationalai}/tools/snapshot_viewer.py +0 -0
  835. {v0/relationalai → relationalai}/util/__init__.py +0 -0
  836. {v0/relationalai → relationalai}/util/constants.py +0 -0
  837. {v0/relationalai → relationalai}/util/graph.py +0 -0
  838. {v0/relationalai → relationalai}/util/timeout.py +0 -0
v0/relationalai/clients/snowflake.py
@@ -1,3869 +0,0 @@
- # pyright: reportUnusedExpression=false
- from __future__ import annotations
- import base64
- import decimal
- import importlib.resources
- import io
- from numbers import Number
- import re
- import json
- import time
- import textwrap
- import ast
- import uuid
- import warnings
- import atexit
- import hashlib
-
-
- from v0.relationalai.auth.token_handler import TokenHandler
- from v0.relationalai.clients.use_index_poller import DirectUseIndexPoller, UseIndexPoller
- import snowflake.snowpark
-
- from v0.relationalai.rel_utils import sanitize_identifier, to_fqn_relation_name
- from v0.relationalai.tools.constants import FIELD_PLACEHOLDER, RAI_APP_NAME, SNOWFLAKE_AUTHS, USE_GRAPH_INDEX, USE_DIRECT_ACCESS, DEFAULT_QUERY_TIMEOUT_MINS, WAIT_FOR_STREAM_SYNC, Generation
- from .. import std
- from collections import defaultdict
- import requests
- import snowflake.connector
- import pyarrow as pa
-
- from snowflake.snowpark import Session
- from snowflake.snowpark.context import get_active_session
- from . import result_helpers
- from .. import debugging
- from typing import Any, Dict, Iterable, Optional, Tuple, List, Literal, Union, cast
-
- from pandas import DataFrame
-
- from ..tools.cli_controls import Spinner
- from ..clients.types import AvailableModel, EngineState, Import, ImportSource, ImportSourceTable, ImportsStatus, SourceInfo, TransactionAsyncResponse
- from ..clients.config import Config, ConfigStore, ENDPOINT_FILE
- from ..clients.client import Client, ExportParams, ProviderBase, ResourcesBase
- from ..clients.direct_access_client import DirectAccessClient
- from ..clients.util import IdentityParser, escape_for_f_string, get_pyrel_version, get_with_retries, poll_with_specified_overhead, safe_json_loads, sanitize_module_name, scrub_exception, wrap_with_request_id, ms_to_timestamp, normalize_datetime
- from ..environments import runtime_env, HexEnvironment, SnowbookEnvironment
- from .. import dsl, rel, metamodel as m
- from ..errors import DuoSecurityFailed, EngineProvisioningFailed, EngineNameValidationException, EngineNotFoundException, EnginePending, EngineSizeMismatchWarning, EngineResumeFailed, Errors, InvalidAliasError, InvalidEngineSizeError, InvalidSourceTypeWarning, RAIAbortedTransactionError, RAIException, HexSessionException, SnowflakeAppMissingException, SnowflakeChangeTrackingNotEnabledException, SnowflakeDatabaseException, SnowflakeImportMissingException, SnowflakeInvalidSource, SnowflakeMissingConfigValuesException, SnowflakeProxyAPIDeprecationWarning, SnowflakeProxySourceError, SnowflakeRaiAppNotStarted, ModelNotFoundException, UnknownSourceWarning, ResponseStatusException, RowsDroppedFromTargetTableWarning, QueryTimeoutExceededException
- from concurrent.futures import ThreadPoolExecutor
- from datetime import datetime, date, timedelta
- from snowflake.snowpark.types import StringType, StructField, StructType
-
- # warehouse-based snowflake notebooks currently don't have hazmat
- crypto_disabled = False
- try:
-     from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
-     from cryptography.hazmat.backends import default_backend
-     from cryptography.hazmat.primitives import padding
- except (ModuleNotFoundError, ImportError):
-     crypto_disabled = True
-
- #--------------------------------------------------
- # Constants
- #--------------------------------------------------
-
- VALID_POOL_STATUS = ["ACTIVE", "IDLE", "SUSPENDED"]
- # transaction list and get return different fields (duration vs timings)
- LIST_TXN_SQL_FIELDS = ["id", "database_name", "engine_name", "state", "abort_reason", "read_only", "created_by", "created_on", "finished_at", "duration"]
- GET_TXN_SQL_FIELDS = ["id", "database", "engine", "state", "abort_reason", "read_only", "created_by", "created_on", "finished_at", "timings"]
- IMPORT_STREAM_FIELDS = ["ID", "CREATED_AT", "CREATED_BY", "STATUS", "REFERENCE_NAME", "REFERENCE_ALIAS", "FQ_OBJECT_NAME", "RAI_DATABASE",
-                         "RAI_RELATION", "DATA_SYNC_STATUS", "PENDING_BATCHES_COUNT", "NEXT_BATCH_STATUS", "NEXT_BATCH_UNLOADED_TIMESTAMP",
-                         "NEXT_BATCH_DETAILS", "LAST_BATCH_DETAILS", "LAST_BATCH_UNLOADED_TIMESTAMP", "CDC_STATUS"]
- VALID_ENGINE_STATES = ["READY", "PENDING"]
-
- # Cloud-specific engine sizes
- INTERNAL_ENGINE_SIZES = ["XS", "S", "M", "L"]
- ENGINE_SIZES_AWS = ["HIGHMEM_X64_S", "HIGHMEM_X64_M", "HIGHMEM_X64_L"]
- ENGINE_SIZES_AZURE = ["HIGHMEM_X64_S", "HIGHMEM_X64_M", "HIGHMEM_X64_L"]
-
- FIELD_MAP = {
-     "database_name": "database",
-     "engine_name": "engine",
- }
- VALID_IMPORT_STATES = ["PENDING", "PROCESSING", "QUARANTINED", "LOADED"]
- ENGINE_ERRORS = ["engine is suspended", "create/resume", "engine not found", "no engines found", "engine was deleted"]
- ENGINE_NOT_READY_MSGS = ["engine is in pending", "engine is provisioning"]
- DATABASE_ERRORS = ["database not found"]
- PYREL_ROOT_DB = 'pyrel_root_db'
-
- TERMINAL_TXN_STATES = ["COMPLETED", "ABORTED"]
-
- DUO_TEXT = "duo security"
-
- TXN_ABORT_REASON_TIMEOUT = "transaction timeout"
-
- #--------------------------------------------------
- # Helpers
- #--------------------------------------------------
-
- def process_jinja_template(template: str, indent_spaces = 0, **substitutions) -> str:
-     """Process a Jinja-like template.
-
-     Supports:
-     - Variable substitution {{ var }}
-     - Conditional blocks {% if condition %} ... {% endif %}
-     - For loops {% for item in items %} ... {% endfor %}
-     - Comments {# ... #}
-     - Whitespace control with {%- and -%}
-
-     Args:
-         template: The template string
-         indent_spaces: Number of spaces to indent the result
-         **substitutions: Variable substitutions
-     """
-
-     def evaluate_condition(condition: str, context: dict) -> bool:
-         """Safely evaluate a condition string using the context."""
-         # Replace variables with their values
-         for k, v in context.items():
-             if isinstance(v, str):
-                 condition = condition.replace(k, f"'{v}'")
-             else:
-                 condition = condition.replace(k, str(v))
-         try:
-             return bool(eval(condition, {"__builtins__": {}}, {}))
-         except Exception:
-             return False
-
-     def process_expression(expr: str, context: dict) -> str:
-         """Process a {{ expression }} block."""
-         expr = expr.strip()
-         if expr in context:
-             return str(context[expr])
-         return ""
-
-     def process_block(lines: List[str], context: dict, indent: int = 0) -> List[str]:
-         """Process a block of template lines recursively."""
-         result = []
-         i = 0
-         while i < len(lines):
-             line = lines[i]
-
-             # Handle comments
-             line = re.sub(r'{#.*?#}', '', line)
-
-             # Handle if blocks
-             if_match = re.search(r'{%\s*if\s+(.+?)\s*%}', line)
-             if if_match:
-                 condition = if_match.group(1)
-                 if_block = []
-                 else_block = []
-                 i += 1
-                 nesting = 1
-                 in_else_block = False
-                 while i < len(lines) and nesting > 0:
-                     if re.search(r'{%\s*if\s+', lines[i]):
-                         nesting += 1
-                     elif re.search(r'{%\s*endif\s*%}', lines[i]):
-                         nesting -= 1
-                     elif nesting == 1 and re.search(r'{%\s*else\s*%}', lines[i]):
-                         in_else_block = True
-                         i += 1
-                         continue
-
-                     if nesting > 0:
-                         if in_else_block:
-                             else_block.append(lines[i])
-                         else:
-                             if_block.append(lines[i])
-                     i += 1
-                 if evaluate_condition(condition, context):
-                     result.extend(process_block(if_block, context, indent))
-                 else:
-                     result.extend(process_block(else_block, context, indent))
-                 continue
-
-             # Handle for loops
-             for_match = re.search(r'{%\s*for\s+(\w+)\s+in\s+(\w+)\s*%}', line)
-             if for_match:
-                 var_name, iterable_name = for_match.groups()
-                 for_block = []
-                 i += 1
-                 nesting = 1
-                 while i < len(lines) and nesting > 0:
-                     if re.search(r'{%\s*for\s+', lines[i]):
-                         nesting += 1
-                     elif re.search(r'{%\s*endfor\s*%}', lines[i]):
-                         nesting -= 1
-                     if nesting > 0:
-                         for_block.append(lines[i])
-                     i += 1
-                 if iterable_name in context and isinstance(context[iterable_name], (list, tuple)):
-                     for item in context[iterable_name]:
-                         loop_context = dict(context)
-                         loop_context[var_name] = item
-                         result.extend(process_block(for_block, loop_context, indent))
-                 continue
-
-             # Handle variable substitution
-             line = re.sub(r'{{\s*(\w+)\s*}}', lambda m: process_expression(m.group(1), context), line)
-
-             # Handle whitespace control
-             line = re.sub(r'{%-', '{%', line)
-             line = re.sub(r'-%}', '%}', line)
-
-             # Add line with proper indentation, preserving blank lines
-             if line.strip():
-                 result.append(" " * (indent_spaces + indent) + line)
-             else:
-                 result.append("")
-
-             i += 1
-
-         return result
-
-     # Split template into lines and process
-     lines = template.split('\n')
-     processed_lines = process_block(lines, substitutions)
-
-     return '\n'.join(processed_lines)
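A minimal usage sketch of the template processor above (hypothetical template and values, not from the package). Note that the block handling scans line by line, so {% if %}/{% endif %} and {% for %}/{% endfor %} tags must each sit on their own line:

    template = "\n".join([
        "{% if debug %}",
        "-- debug build",
        "{% endif %}",
        "{% for t in tables %}",
        "CREATE TABLE {{ t }};",
        "{% endfor %}",
    ])
    print(process_jinja_template(template, indent_spaces=2, debug=True, tables=["A", "B"]))
    # Prints (each line indented by two spaces):
    #   -- debug build
    #   CREATE TABLE A;
    #   CREATE TABLE B;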
-
- def type_to_sql(type) -> str:
-     if type is str:
-         return "VARCHAR"
-     if type is int:
-         return "NUMBER"
-     if type is Number:
-         return "DECIMAL(38, 15)"
-     if type is float:
-         return "FLOAT"
-     if type is decimal.Decimal:
-         return "DECIMAL(38, 15)"
-     if type is bool:
-         return "BOOLEAN"
-     if type is dict:
-         return "VARIANT"
-     if type is list:
-         return "ARRAY"
-     if type is bytes:
-         return "BINARY"
-     if type is datetime:
-         return "TIMESTAMP"
-     if type is date:
-         return "DATE"
-     if isinstance(type, dsl.Type):
-         return "VARCHAR"
-     raise ValueError(f"Unknown type {type}")
-
- def type_to_snowpark(type) -> str:
-     if type is str:
-         return "StringType()"
-     if type is int:
-         return "IntegerType()"
-     if type is float:
-         return "FloatType()"
-     if type is Number:
-         return "DecimalType(38, 15)"
-     if type is decimal.Decimal:
-         return "DecimalType(38, 15)"
-     if type is bool:
-         return "BooleanType()"
-     if type is dict:
-         return "MapType()"
-     if type is list:
-         return "ArrayType()"
-     if type is bytes:
-         return "BinaryType()"
-     if type is datetime:
-         return "TimestampType()"
-     if type is date:
-         return "DateType()"
-     if isinstance(type, dsl.Type):
-         return "StringType()"
-     raise ValueError(f"Unknown type {type}")
-
- def _sanitize_user_name(user: str) -> str:
-     # Extract the part before the '@'
-     sanitized_user = user.split('@')[0]
-     # Replace any character that is not a letter, number, or underscore with '_'
-     sanitized_user = re.sub(r'[^a-zA-Z0-9_]', '_', sanitized_user)
-     return sanitized_user
-
- def _is_engine_issue(response_message: str) -> bool:
-     return any(kw in response_message.lower() for kw in ENGINE_ERRORS + ENGINE_NOT_READY_MSGS)
-
- def _is_database_issue(response_message: str) -> bool:
-     return any(kw in response_message.lower() for kw in DATABASE_ERRORS)
-
-
- #--------------------------------------------------
- # Resources
- #--------------------------------------------------
-
- APP_NAME = "___RAI_APP___"
-
- class Resources(ResourcesBase):
-     def __init__(
-         self,
-         profile: str | None = None,
-         config: Config | None = None,
-         connection: Session | None = None,
-         dry_run: bool = False,
-         reset_session: bool = False,
-         generation: Generation | None = None,
-         language: str = "rel",
-     ):
-         super().__init__(profile, config=config)
-         self._token_handler: TokenHandler | None = None
-         self._session = connection
-         self.generation = generation
-         if self._session is None and not dry_run:
-             try:
-                 # we may still be constructing the config, so this can fail now,
-                 # if so we'll create later
-                 self._session = self.get_sf_session(reset_session)
-             except Exception:
-                 pass
-         self._pending_transactions: list[str] = []
-         self._ns_cache = {}
-         # self.sources contains fully qualified Snowflake table/view names
-         self.sources: set[str] = set()
-         self._sproc_models = None
-         self.database = ""
-         self.language = language
-         atexit.register(self.cancel_pending_transactions)
-
-     @property
-     def token_handler(self) -> TokenHandler:
-         if not self._token_handler:
-             self._token_handler = TokenHandler.from_config(self.config)
-         return self._token_handler
-
-     def is_erp_running(self, app_name: str) -> bool:
333
- """Check if the ERP is running. The app.service_status() returns single row/column containing an array of JSON service status objects."""
334
- query = f"CALL {app_name}.app.service_status();"
335
- try:
336
- result = self._exec(query)
337
- # The result is a list of dictionaries, each with a "STATUS" key
338
- # The column name containing the result is "SERVICE_STATUS"
339
- services_status = json.loads(result[0]["SERVICE_STATUS"])
340
- # Find the dictionary with "name" of "main" and check if its "status" is "READY"
341
- for service in services_status:
342
- if service.get("name") == "main" and service.get("status") == "READY":
343
- return True
344
- return False
345
- except Exception:
346
- return False
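For reference, is_erp_running returns True for a SERVICE_STATUS payload shaped like the following (shape inferred from the parsing logic above; the values are hypothetical), since the entry named "main" is READY:

    [
        {"name": "main", "status": "READY"},
        {"name": "worker", "status": "PENDING"}
    ]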
-
-     def get_sf_session(self, reset_session: bool = False):
-         if self._session:
-             return self._session
-
-         if isinstance(runtime_env, HexEnvironment):
-             raise HexSessionException()
-         if isinstance(runtime_env, SnowbookEnvironment):
-             return get_active_session()
-         else:
-             # if there's already been a session created, try using that
-             # if reset_session is true always try to get the new session
-             if not reset_session:
-                 try:
-                     return get_active_session()
-                 except Exception:
-                     pass
-
-             # otherwise, create a new session
-             missing_keys = []
-             connection_parameters = {}
-
-             authenticator = self.config.get('authenticator', None)
-             passcode = self.config.get("passcode", "")
-             private_key_file = self.config.get("private_key_file", "")
-
-             # If the authenticator is not set, we need to set it based on the provided parameters
-             if authenticator is None:
-                 if private_key_file != "":
-                     authenticator = "snowflake_jwt"
-                 elif passcode != "":
-                     authenticator = "username_password_mfa"
-                 else:
-                     authenticator = "snowflake"
-                 # set the default authenticator in the config so we can skip it when we check for missing keys
-                 self.config.set("authenticator", authenticator)
-
-             if authenticator in SNOWFLAKE_AUTHS:
-                 required_keys = {
-                     key for key, value in SNOWFLAKE_AUTHS[authenticator].items() if value.get("required", True)
-                 }
-                 for key in required_keys:
-                     if self.config.get(key, None) is None:
-                         default = SNOWFLAKE_AUTHS[authenticator][key].get("value", None)
-                         if default is None or default == FIELD_PLACEHOLDER:
-                             # No default value and no value in the config, add to missing keys
-                             missing_keys.append(key)
-                         else:
-                             # Set the default value in the config from the auth defaults
-                             self.config.set(key, default)
-                 if missing_keys:
-                     profile = getattr(self.config, 'profile', None)
-                     config_file_path = getattr(self.config, 'file_path', None)
-                     raise SnowflakeMissingConfigValuesException(missing_keys, profile, config_file_path)
-                 for key in SNOWFLAKE_AUTHS[authenticator]:
-                     connection_parameters[key] = self.config.get(key, None)
-             else:
-                 raise ValueError(f'Authenticator "{authenticator}" not supported')
-
-             return self._build_snowflake_session(connection_parameters)
-
-     def _build_snowflake_session(self, connection_parameters: Dict[str, Any]) -> Session:
-         try:
-             tmp = {
-                 "client_session_keep_alive": True,
-                 "client_session_keep_alive_heartbeat_frequency": 60 * 5,
-             }
-             tmp.update(connection_parameters)
-             connection_parameters = tmp
-             # the authenticator value (e.g. for programmatic access tokens) needs to be upper-cased to work
-             connection_parameters["authenticator"] = connection_parameters["authenticator"].upper()
-             if "authenticator" in connection_parameters and connection_parameters["authenticator"] == "OAUTH_AUTHORIZATION_CODE":
-                 # we are replicating OAUTH_AUTHORIZATION_CODE by first retrieving the token
-                 # and then authenticating with the token via the OAUTH authenticator
-                 connection_parameters["token"] = self.token_handler.get_session_login_token()
-                 connection_parameters["authenticator"] = "OAUTH"
-             return Session.builder.configs(connection_parameters).create()
-         except snowflake.connector.errors.Error as e:
-             raise SnowflakeDatabaseException(e)
-         except Exception as e:
-             raise e
-
-     def _exec_sql(self, code: str, params: List[Any] | None, raw=False):
-         assert self._session is not None
-         sess_results = self._session.sql(
-             code.replace(APP_NAME, self.get_app_name()),
-             params
-         )
-         if raw:
-             return sess_results
-         return sess_results.collect()
-
-     def _exec(
-         self,
-         code: str,
-         params: List[Any] | Any | None = None,
-         raw: bool = False,
-         help: bool = True,
-         skip_engine_db_error_retry: bool = False
-     ) -> Any:
-         # print(f"\n--- sql---\n{code}\n--- end sql---\n")
-         if not self._session:
-             self._session = self.get_sf_session()
-
-         try:
-             if params is not None and not isinstance(params, list):
-                 params = cast(List[Any], [params])
-             return self._exec_sql(code, params, raw=raw)
-         except Exception as e:
-             if not help:
-                 raise e
-             orig_message = str(e).lower()
-             rai_app = self.config.get("rai_app_name", "")
-             current_role = self.config.get("role")
-             engine = self.get_default_engine_name()
-             engine_size = self.config.get_default_engine_size()
-             assert isinstance(rai_app, str), f"rai_app_name must be a string, not {type(rai_app)}"
-             assert isinstance(engine, str), f"engine must be a string, not {type(engine)}"
-             print("\n")
-             if DUO_TEXT in orig_message:
-                 raise DuoSecurityFailed(e)
-             if re.search(f"database '{rai_app}' does not exist or not authorized.".lower(), orig_message):
-                 exception = SnowflakeAppMissingException(rai_app, current_role)
-                 raise exception from None
-             # skip initializing the index if the query is a user transaction. exec_raw/exec_lqp will handle that case with the correct request headers.
-             if (_is_engine_issue(orig_message) or _is_database_issue(orig_message)) and not skip_engine_db_error_retry:
-                 try:
-                     self._poll_use_index(
-                         app_name=self.get_app_name(),
-                         sources=self.sources,
-                         model=self.database,
-                         engine_name=engine,
-                         engine_size=engine_size
-                     )
-                     return self._exec(code, params, raw=raw, help=help)
-                 except EngineNameValidationException as e:
-                     raise EngineNameValidationException(engine) from e
-                 except Exception as e:
-                     raise EngineProvisioningFailed(engine, e) from e
-             elif re.search(r"javascript execution error", orig_message):
-                 match = re.search(r"\"message\":\"(.*)\"", orig_message)
-                 if match:
-                     message = match.group(1)
-                     if "engine is in pending" in message or "engine is provisioning" in message:
-                         raise EnginePending(engine)
-                     else:
-                         raise RAIException(message) from None
-
-             if re.search(r"the relationalai service has not been started.", orig_message):
-                 app_name = self.config.get("rai_app_name", "")
-                 assert isinstance(app_name, str), f"rai_app_name must be a string, not {type(app_name)}"
-                 raise SnowflakeRaiAppNotStarted(app_name)
-
-             if re.search(r"state:\s*aborted", orig_message):
-                 txn_id_match = re.search(r"id:\s*([0-9a-f\-]+)", orig_message)
-                 if txn_id_match:
-                     txn_id = txn_id_match.group(1)
-                     problems = self.get_transaction_problems(txn_id)
-                     if problems:
-                         for problem in problems:
-                             if isinstance(problem, dict):
-                                 type_field = problem.get('TYPE')
-                                 message_field = problem.get('MESSAGE')
-                                 report_field = problem.get('REPORT')
-                             else:
-                                 type_field = problem.TYPE
-                                 message_field = problem.MESSAGE
-                                 report_field = problem.REPORT
-
-                             raise RAIAbortedTransactionError(type_field, message_field, report_field)
-                 raise RAIException(str(e))
-             raise RAIException(str(e))
-
-
-     def reset(self):
-         self._session = None
-
-     #--------------------------------------------------
-     # Check direct access is enabled
-     #--------------------------------------------------
-
-     def is_direct_access_enabled(self) -> bool:
-         try:
-             feature_enabled = self._exec(
-                 f"call {APP_NAME}.APP.DIRECT_INGRESS_ENABLED();"
-             )
-             if not feature_enabled:
-                 return False
-
-             # Even if the feature is enabled, customers still need to reactivate ERP to ensure the endpoint is available.
-             endpoint = self._exec(
-                 f"call {APP_NAME}.APP.SERVICE_ENDPOINT(true);"
-             )
-             if not endpoint or endpoint[0][0] is None:
-                 return False
-
-             return feature_enabled[0][0]
-         except Exception as e:
-             raise Exception(f"Unable to determine if direct access is enabled. Detailed error: {e}") from e
-
-     #--------------------------------------------------
-     # Snowflake Account Flags
-     #--------------------------------------------------
-
-     def is_account_flag_set(self, flag: str) -> bool:
-         results = self._exec(
-             f"SHOW PARAMETERS LIKE '%{flag}%' IN ACCOUNT;"
-         )
-         if not results:
-             return False
-         return results[0]["value"] == "true"
-
-     #--------------------------------------------------
-     # Databases
-     #--------------------------------------------------
-
-     def get_database(self, database: str):
-         try:
-             results = self._exec(
-                 f"call {APP_NAME}.api.get_database('{database}');"
-             )
-         except Exception as e:
-             if "Database does not exist" in str(e):
-                 return None
-             raise e
-
-         if not results:
-             return None
-         db = results[0]
-         if not db:
-             return None
-         return {
-             "id": db["ID"],
-             "name": db["NAME"],
-             "created_by": db["CREATED_BY"],
-             "created_on": db["CREATED_ON"],
-             "deleted_by": db["DELETED_BY"],
-             "deleted_on": db["DELETED_ON"],
-             "state": db["STATE"],
-         }
-
-     def get_installed_packages(self, database: str) -> Dict | None:
-         query = f"call {APP_NAME}.api.get_installed_package_versions('{database}');"
-         try:
-             results = self._exec(query)
-         except Exception as e:
-             if "Database does not exist" in str(e):
-                 return None
-             # fallback to None for old sql-lib versions
-             if "Unknown user-defined function" in str(e):
-                 return None
-             raise e
-
-         if not results:
-             return None
-
-         row = results[0]
-         if not row:
-             return None
-
-         return safe_json_loads(row["PACKAGE_VERSIONS"])
-
-     #--------------------------------------------------
-     # Engines
-     #--------------------------------------------------
-
-     def get_engine_sizes(self, cloud_provider: str|None=None):
-         sizes = []
-         if cloud_provider is None:
-             cloud_provider = self.get_cloud_provider()
-         if cloud_provider == 'azure':
-             sizes = ENGINE_SIZES_AZURE
-         else:
-             sizes = ENGINE_SIZES_AWS
-         if self.config.show_all_engine_sizes():
-             return INTERNAL_ENGINE_SIZES + sizes
-         else:
-             return sizes
-
-     def list_engines(self, state: str | None = None):
-         where_clause = f"WHERE STATUS = '{state.upper()}'" if state else ""
-         statement = f"SELECT NAME, ID, SIZE, STATUS, CREATED_BY, CREATED_ON, UPDATED_ON FROM {APP_NAME}.api.engines {where_clause} ORDER BY NAME ASC;"
-         results = self._exec(statement)
-         if not results:
-             return []
-         return [
-             {
-                 "name": row["NAME"],
-                 "id": row["ID"],
-                 "size": row["SIZE"],
-                 "state": row["STATUS"],  # callers are expecting 'state'
-                 "created_by": row["CREATED_BY"],
-                 "created_on": row["CREATED_ON"],
-                 "updated_on": row["UPDATED_ON"],
-             }
-             for row in results
-         ]
-
-     def get_engine(self, name: str):
-         results = self._exec(
-             f"SELECT NAME, ID, SIZE, STATUS, CREATED_BY, CREATED_ON, UPDATED_ON, VERSION, AUTO_SUSPEND_MINS, SUSPENDS_AT FROM {APP_NAME}.api.engines WHERE NAME='{name}';"
-         )
-         if not results:
-             return None
-         engine = results[0]
-         if not engine:
-             return None
-         engine_state: EngineState = {
-             "name": engine["NAME"],
-             "id": engine["ID"],
-             "size": engine["SIZE"],
-             "state": engine["STATUS"],  # callers are expecting 'state'
-             "created_by": engine["CREATED_BY"],
-             "created_on": engine["CREATED_ON"],
-             "updated_on": engine["UPDATED_ON"],
-             "version": engine["VERSION"],
-             "auto_suspend": engine["AUTO_SUSPEND_MINS"],
-             "suspends_at": engine["SUSPENDS_AT"]
-         }
-         return engine_state
-
-     def get_default_engine_name(self) -> str:
-         if self.config.get("engine_name", None) is not None:
-             profile = self.config.profile
-             raise InvalidAliasError(f"""
-                 'engine_name' is not a valid config option.
-                 If you meant to use a specific engine, use 'engine' instead.
-                 Otherwise, remove it from your '{profile}' configuration profile.
-             """)
-         engine = self.config.get("engine", None)
-         if not engine and self.config.get("user", None):
-             engine = _sanitize_user_name(str(self.config.get("user")))
-         if not engine:
-             engine = self.get_user_based_engine_name()
-         self.config.set("engine", engine)
-         return engine
-
-     def is_valid_engine_state(self, name:str):
-         return name in VALID_ENGINE_STATES
-
-     def _create_engine(
-         self,
-         name: str,
-         size: str | None = None,
-         auto_suspend_mins: int | None = None,
-         is_async: bool = False,
-         headers: Dict | None = None,
-     ):
-         api = "create_engine_async" if is_async else "create_engine"
-         if size is None:
-             size = self.config.get_default_engine_size()
-         # if auto_suspend_mins is None, get the default value from the config
-         if auto_suspend_mins is None:
-             auto_suspend_mins = self.config.get_default_auto_suspend_mins()
-         try:
-             headers = debugging.gen_current_propagation_headers()
-             with debugging.span(api, name=name, size=size, auto_suspend_mins=auto_suspend_mins):
-                 # check in case the config default is missing
-                 if auto_suspend_mins is None:
-                     self._exec(f"call {APP_NAME}.api.{api}('{name}', '{size}', null, {headers});")
-                 else:
-                     self._exec(f"call {APP_NAME}.api.{api}('{name}', '{size}', PARSE_JSON('{{\"auto_suspend_mins\": {auto_suspend_mins}}}'), {headers});")
-         except Exception as e:
-             raise EngineProvisioningFailed(name, e) from e
-
-     def create_engine(self, name:str, size:str|None=None, auto_suspend_mins:int|None=None, headers: Dict | None = None):
-         self._create_engine(name, size, auto_suspend_mins, headers=headers)
-
-     def create_engine_async(self, name:str, size:str|None=None, auto_suspend_mins:int|None=None):
-         self._create_engine(name, size, auto_suspend_mins, True)
-
-     def delete_engine(self, name:str, force:bool = False, headers: Dict | None = None):
-         request_headers = debugging.add_current_propagation_headers(headers)
-         self._exec(f"call {APP_NAME}.api.delete_engine('{name}', {force},{request_headers});")
-
-     def suspend_engine(self, name:str):
-         self._exec(f"call {APP_NAME}.api.suspend_engine('{name}');")
-
-     def resume_engine(self, name:str, headers: Dict | None = None) -> Dict:
-         request_headers = debugging.add_current_propagation_headers(headers)
-         self._exec(f"call {APP_NAME}.api.resume_engine('{name}',{request_headers});")
-         # returning empty dict to match the expected return type
-         return {}
-
-     def resume_engine_async(self, name:str, headers: Dict | None = None) -> Dict:
-         if headers is None:
-             headers = {}
-         self._exec(f"call {APP_NAME}.api.resume_engine_async('{name}',{headers});")
-         # returning empty dict to match the expected return type
-         return {}
-
-     def alter_engine_pool(self, size:str|None=None, mins:int|None=None, maxs:int|None=None):
-         """Alter engine pool node limits for Snowflake."""
-         self._exec(f"call {APP_NAME}.api.alter_engine_pool_node_limits('{size}', {mins}, {maxs});")
-
-     #--------------------------------------------------
-     # Graphs
-     #--------------------------------------------------
-
-     def list_graphs(self) -> List[AvailableModel]:
-         with debugging.span("list_models"):
-             query = textwrap.dedent(f"""
-                 SELECT NAME, ID, CREATED_BY, CREATED_ON, STATE, DELETED_BY, DELETED_ON
-                 FROM {APP_NAME}.api.databases
-                 WHERE state <> 'DELETED'
-                 ORDER BY NAME ASC;
-             """)
-             results = self._exec(query)
-             if not results:
-                 return []
-             return [
-                 {
-                     "name": row["NAME"],
-                     "id": row["ID"],
-                     "created_by": row["CREATED_BY"],
-                     "created_on": row["CREATED_ON"],
-                     "state": row["STATE"],
-                     "deleted_by": row["DELETED_BY"],
-                     "deleted_on": row["DELETED_ON"],
-                 }
-                 for row in results
-             ]
-
-     def get_graph(self, name: str):
-         res = self.get_database(name)
-         if res and res.get("state") != "DELETED":
-             return res
-
-     def create_graph(self, name: str):
-         with debugging.span("create_model", name=name):
-             self._exec(f"call {APP_NAME}.api.create_database('{name}', false, {debugging.gen_current_propagation_headers()});")
-
-     def delete_graph(self, name:str, force=False, language:str="rel"):
-         prop_hdrs = debugging.gen_current_propagation_headers()
-         if self.config.get("use_graph_index", USE_GRAPH_INDEX):
-             keep_database = not force and self.config.get("reuse_model", True)
-             with debugging.span("release_index", name=name, keep_database=keep_database, language=language):
-                 # TODO add headers to release_index
-                 response = self._exec(f"call {APP_NAME}.api.release_index('{name}', OBJECT_CONSTRUCT('keep_database', {keep_database}, 'language', '{language}', 'user_agent', '{get_pyrel_version(self.generation)}'));")
-                 if response:
-                     result = next(iter(response))
-                     obj = json.loads(result["RELEASE_INDEX"])
-                     error = obj.get('error', None)
-                     if error and "Model database not found" not in error:
-                         raise Exception(f"Error releasing index: {error}")
-                 else:
-                     raise Exception("There was no response from the release index call.")
-         else:
-             with debugging.span("delete_model", name=name):
-                 self._exec(f"call {APP_NAME}.api.delete_database('{name}', false, {prop_hdrs});")
-
-     def clone_graph(self, target_name:str, source_name:str, nowait_durable=True, force=False):
-         if force and self.get_graph(target_name):
-             self.delete_graph(target_name)
-         with debugging.span("clone_model", target_name=target_name, source_name=source_name):
-             # not a mistake: the clone_database argument order is indeed target then source:
-             headers = debugging.gen_current_propagation_headers()
-             self._exec(f"call {APP_NAME}.api.clone_database('{target_name}', '{source_name}', {nowait_durable}, {headers});")
-
-     def _poll_use_index(
-         self,
-         app_name: str,
-         sources: Iterable[str],
-         model: str,
-         engine_name: str,
-         engine_size: str | None = None,
-         program_span_id: str | None = None,
-         headers: Dict | None = None,
-     ):
-         return UseIndexPoller(
-             self,
-             app_name,
-             sources,
-             model,
-             engine_name,
-             engine_size,
-             self.language,
-             program_span_id,
-             headers,
-             self.generation
-         ).poll()
-
-     def maybe_poll_use_index(
-         self,
-         app_name: str,
-         sources: Iterable[str],
-         model: str,
-         engine_name: str,
-         engine_size: str | None = None,
-         program_span_id: str | None = None,
-         headers: Dict | None = None,
-     ):
-         """Only call poll() if there are sources to process and cache is not valid."""
-         sources_list = list(sources)
-         self.database = model
-         if sources_list:
-             poller = UseIndexPoller(
-                 self,
-                 app_name,
-                 sources_list,
-                 model,
-                 engine_name,
-                 engine_size,
-                 self.language,
-                 program_span_id,
-                 headers,
-                 self.generation
-             )
-             # If cache is valid (data freshness has not expired), skip polling
-             if poller.cache.is_valid():
-                 cached_sources = len(poller.cache.sources)
-                 total_sources = len(sources_list)
-                 cached_timestamp = poller.cache._metadata.get("cachedIndices", {}).get(poller.cache.key, {}).get("last_use_index_update_on", "")
-
-                 message = f"Using cached data for {cached_sources}/{total_sources} data streams"
-                 if cached_timestamp:
-                     print(f"\n{message} (cached at {cached_timestamp})\n")
-                 else:
-                     print(f"\n{message}\n")
-             else:
-                 return poller.poll()
-
-     #--------------------------------------------------
-     # Models
-     #--------------------------------------------------
-
-     def list_models(self, database: str, engine: str):
-         pass
-
-     def create_models(self, database: str, engine: str | None, models:List[Tuple[str, str]]) -> List[Any]:
-         rel_code = self.create_models_code(models)
-         self.exec_raw(database, engine, rel_code, readonly=False)
-         # TODO: handle SPCS errors once they're figured out
-         return []
-
-     def delete_model(self, database:str, engine:str | None, name:str):
-         self.exec_raw(database, engine, f"def delete[:rel, :catalog, :model, \"{name}\"]: rel[:catalog, :model, \"{name}\"]", readonly=False)
-
-     def create_models_code(self, models:List[Tuple[str, str]]) -> str:
-         lines = []
-         for (name, code) in models:
-             name = name.replace("\"", "\\\"")
-             assert "\"\"\"\"\"\"\"" not in code, "Code literals must use fewer than 7 quotes."
-
-             lines.append(textwrap.dedent(f"""
-                 def delete[:rel, :catalog, :model, "{name}"]: rel[:catalog, :model, "{name}"]
-                 def insert[:rel, :catalog, :model, "{name}"]: raw\"\"\"\"\"\"\"
-             """) + code + "\n\"\"\"\"\"\"\"")
-         rel_code = "\n\n".join(lines)
-         return rel_code
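As a sketch, for a single entry ("my/model", "def output = 1") (hypothetical model name and body), create_models_code above produces Rel of roughly this shape; the 7-quote raw-string delimiter guards against quotes appearing inside the model body:

    def delete[:rel, :catalog, :model, "my/model"]: rel[:catalog, :model, "my/model"]
    def insert[:rel, :catalog, :model, "my/model"]: raw"""""""
    def output = 1
    """""""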
-
-     #--------------------------------------------------
-     # Exports
-     #--------------------------------------------------
-
-     def list_exports(self, database: str, engine: str):
-         return []
-
-     def format_sproc_name(self, name: str, type:Any) -> str:
-         if type is datetime:
-             return f"{name}.astimezone(ZoneInfo('UTC')).isoformat(timespec='milliseconds')"
-         else:
-             return name
-
-     def get_export_code(self, params: ExportParams, all_installs):
-         sql_inputs = ", ".join([f"{name} {type_to_sql(type)}" for (name, _, type) in params.inputs])
-         input_names = [name for (name, *_) in params.inputs]
-         has_return_hint = params.out_fields and isinstance(params.out_fields[0], tuple)
-         if has_return_hint:
-             sql_out = ", ".join([f"\"{name}\" {type_to_sql(type)}" for (name, type) in params.out_fields])
-             sql_out_names = ", ".join([f"('{name}', '{type_to_sql(type)}')" for (ix, (name, type)) in enumerate(params.out_fields)])
-             py_outs = ", ".join([f"StructField(\"{name}\", {type_to_snowpark(type)})" for (name, type) in params.out_fields])
-         else:
-             sql_out = ""
-             sql_out_names = ", ".join([f"'{name}'" for name in params.out_fields])
-             py_outs = ", ".join([f"StructField(\"{name}\", {type_to_snowpark(str)})" for name in params.out_fields])
-         py_inputs = ", ".join([name for (name, *_) in params.inputs])
-         safe_rel = escape_for_f_string(params.code).strip()
-         clean_inputs = []
-         for (name, var, type) in params.inputs:
-             if type is str:
-                 clean_inputs.append(f"{name} = '\"' + escape({name}) + '\"'")
-             # Replace `var` with `name` and keep the following non-word character unchanged
-             pattern = re.compile(re.escape(var) + r'(\W)')
-             value = self.format_sproc_name(name, type)
-             safe_rel = re.sub(pattern, rf"{{{value}}}\1", safe_rel)
-         if py_inputs:
-             py_inputs = f", {py_inputs}"
-         clean_inputs = ("\n").join(clean_inputs)
-         assert __package__ is not None, "Package name must be set"
-         file = "export_procedure.py.jinja"
-         with importlib.resources.open_text(
-             __package__, file
-         ) as f:
-             template = f.read()
-         def quote(s: str, f = False) -> str:
-             return '"' + s + '"' if not f else 'f"' + s + '"'
-
-         wait_for_stream_sync = self.config.get("wait_for_stream_sync", WAIT_FOR_STREAM_SYNC)
-         # 1. Check the sources for staled sources
-         # 2. Get the object references for the sources
-         # TODO: this could be optimized to do it in the run time of the stored procedure
-         # instead of doing it here. It will make it more reliable when sources are
-         # modified after the stored procedure is created.
-         checked_sources = self._check_source_updates(self.sources)
-         source_obj_references = self._get_source_references(checked_sources)
-
-         # Escape double quotes in the source object references
-         escaped_source_obj_references = [source.replace('"', '\\"') for source in source_obj_references]
-         escaped_proc_database = params.proc_database.replace('"', '\\"')
-
-         normalized_func_name = IdentityParser(params.func_name).identity
-         assert normalized_func_name is not None, "Function name must be set"
-         skip_invalid_data = params.skip_invalid_data
-         python_code = process_jinja_template(
-             template,
-             func_name=quote(normalized_func_name),
-             database=quote(params.root_database),
-             proc_database=quote(escaped_proc_database),
-             engine=quote(params.engine),
-             rel_code=quote(safe_rel, f=True),
-             APP_NAME=quote(APP_NAME),
-             input_names=input_names,
-             outputs=sql_out,
-             sql_out_names=sql_out_names,
-             clean_inputs=clean_inputs,
-             py_inputs=py_inputs,
-             py_outs=py_outs,
-             skip_invalid_data=skip_invalid_data,
-             source_references=", ".join(escaped_source_obj_references),
-             install_code=all_installs.replace("\\", "\\\\").replace("\n", "\\n"),
-             has_return_hint=has_return_hint,
-             wait_for_stream_sync=wait_for_stream_sync,
-         ).strip()
-         return_clause = f"TABLE({sql_out})" if sql_out else "STRING"
-         destination_input = "" if sql_out else "save_as_table STRING DEFAULT NULL,"
-         module_name = sanitize_module_name(normalized_func_name)
-         stage = f"@{self.get_app_name()}.app_state.stored_proc_code_stage"
-         file_loc = f"{stage}/{module_name}.py"
-         python_code = python_code.replace(APP_NAME, self.get_app_name())
-
-         hash = hashlib.sha256()
-         hash.update(python_code.encode('utf-8'))
-         code_hash = hash.hexdigest()
-         print(code_hash)
-
-         sql_code = textwrap.dedent(f"""
-             CREATE OR REPLACE PROCEDURE {normalized_func_name}({sql_inputs}{sql_inputs and ',' or ''} {destination_input} engine STRING DEFAULT NULL)
-             RETURNS {return_clause}
-             LANGUAGE PYTHON
-             RUNTIME_VERSION = '3.10'
-             IMPORTS = ('{file_loc}')
-             PACKAGES = ('snowflake-snowpark-python')
-             HANDLER = 'checked_handle'
-             EXECUTE AS CALLER
-             AS
-             $$
-             import {module_name}
-             import inspect, hashlib, os, sys
-             def checked_handle(*args, **kwargs):
-                 import_dir = sys._xoptions["snowflake_import_directory"]
-                 wheel_path = os.path.join(import_dir, '{module_name}.py')
-                 h = hashlib.sha256()
-                 with open(wheel_path, 'rb') as f:
-                     for chunk in iter(lambda: f.read(1<<20), b''):
-                         h.update(chunk)
-                 code_hash = h.hexdigest()
-                 if code_hash != '{code_hash}':
-                     raise RuntimeError("Code hash mismatch. The code has been modified since it was uploaded.")
-                 # Call the handle function with the provided arguments
-                 return {module_name}.handle(*args, **kwargs)
-
-             $$;
-         """)
-         # print(f"\n--- python---\n{python_code}\n--- end python---\n")
-         # This check helps catch invalid code early and for dry runs:
-         try:
-             ast.parse(python_code)
-         except SyntaxError:
-             raise ValueError(f"Internal error: invalid Python code generated:\n{python_code}")
-         return (sql_code, python_code, file_loc)
-
-     def get_sproc_models(self, params: ExportParams):
-         if self._sproc_models is not None:
-             return self._sproc_models
-
-         with debugging.span("get_sproc_models"):
-             code = """
-             def output(name, model):
-                 rel(:catalog, :model, name, model)
-                 and not starts_with(name, "rel/")
-                 and not starts_with(name, "pkg/rel")
-                 and not starts_with(name, "pkg/std")
-                 and starts_with(name, "pkg/")
-             """
-             res = self.exec_raw(params.model_database, params.engine, code, readonly=True, nowait_durable=True)
-             df, errors = result_helpers.format_results(res, None, ["name", "model"])
-             models = []
-             for row in df.itertuples():
-                 models.append((row.name, row.model))
-             self._sproc_models = models
-             return models
-
-     def create_export(self, params: ExportParams):
-         with debugging.span("create_export") as span:
-             if params.dry_run:
-                 (sql_code, python_code, file_loc) = self.get_export_code(params, params.install_code)
-                 span["sql"] = sql_code
-                 return
-
-             start = time.perf_counter()
-             use_graph_index = self.config.get("use_graph_index", USE_GRAPH_INDEX)
-             # for the non graph index case we need to create the cloned proc database
-             if not use_graph_index:
-                 raise RAIException(
-                     "To ensure permissions are properly accounted for, stored procedures require using the graph index. "
-                     "Set use_graph_index=True in your config to proceed."
-                 )
-
-             models = self.get_sproc_models(params)
-             lib_installs = self.create_models_code(models)
-             all_installs = lib_installs + "\n\n" + params.install_code
-
-             (sql_code, python_code, file_loc) = self.get_export_code(params, all_installs)
-
-             span["sql"] = sql_code
-             assert self._session
-
-             with debugging.span("upload_sproc_code"):
-                 code_bytes = python_code.encode('utf-8')
-                 code_stream = io.BytesIO(code_bytes)
-                 self._session.file.put_stream(code_stream, file_loc, auto_compress=False, overwrite=True)
-
-             with debugging.span("sql_install"):
-                 self._exec(sql_code)
-
-             debugging.time("export", time.perf_counter() - start, DataFrame(), code=sql_code.replace(APP_NAME, self.get_app_name()))
-
-
-     def create_export_table(self, database: str, engine: str, table: str, relation: str, columns: Dict[str, str], code: str, refresh: str|None=None):
-         print("Snowflake doesn't support creating export tables yet. Try creating the table manually first.")
-         pass
-
-     def delete_export(self, database: str, engine: str, name: str):
-         pass
-
-     #--------------------------------------------------
-     # Imports
-     #--------------------------------------------------
-
-     def is_valid_import_state(self, state:str):
-         return state in VALID_IMPORT_STATES
-
-     def imports_to_dicts(self, results):
-         parsed_results = [
-             {field.lower(): row[field] for field in IMPORT_STREAM_FIELDS}
-             for row in results
-         ]
-         return parsed_results
-
-     def change_stream_status(self, stream_id: str, model:str, suspend: bool):
-         if stream_id and model:
-             if suspend:
-                 self._exec(f"CALL {APP_NAME}.api.suspend_data_stream('{stream_id}', '{model}');")
-             else:
-                 self._exec(f"CALL {APP_NAME}.api.resume_data_stream('{stream_id}', '{model}');")
-
-     def change_imports_status(self, suspend: bool):
-         if suspend:
-             self._exec(f"CALL {APP_NAME}.app.suspend_cdc();")
-         else:
-             self._exec(f"CALL {APP_NAME}.app.resume_cdc();")
-
-     def get_imports_status(self) -> ImportsStatus|None:
-         # NOTE: We expect there to only ever be one result?
-         results = self._exec(f"CALL {APP_NAME}.app.cdc_status();")
-         if results:
-             result = next(iter(results))
-             engine = result['CDC_ENGINE_NAME']
-             engine_status = result['CDC_ENGINE_STATUS']
-             engine_size = result['CDC_ENGINE_SIZE']
-             task_status = result['CDC_TASK_STATUS']
-             info = result['CDC_TASK_INFO']
-             enabled = result['CDC_ENABLED']
-             return {"engine": engine, "engine_size": engine_size, "engine_status": engine_status, "status": task_status, "enabled": enabled, "info": info}
-         return None
-
-     def set_imports_engine_size(self, size:str):
-         try:
-             self._exec(f"CALL {APP_NAME}.app.alter_cdc_engine_size('{size}');")
-         except Exception as e:
-             raise e
-
-     def list_imports(
-         self,
-         id:str|None = None,
-         name:str|None = None,
-         model:str|None = None,
-         status:str|None = None,
-         creator:str|None = None,
-     ) -> list[Import]:
-         where = []
-         if id and isinstance(id, str):
-             where.append(f"LOWER(ID) = '{id.lower()}'")
-         if name and isinstance(name, str):
-             where.append(f"LOWER(FQ_OBJECT_NAME) = '{name.lower()}'")
-         if model and isinstance(model, str):
-             where.append(f"LOWER(RAI_DATABASE) = '{model.lower()}'")
-         if creator and isinstance(creator, str):
-             where.append(f"LOWER(CREATED_BY) = '{creator.lower()}'")
-         if status and isinstance(status, str):
-             where.append(f"LOWER(batch_status) = '{status.lower()}'")
-         where_clause = " AND ".join(where)
-
-         # This is roughly inspired by the native app code because we don't have a way to
-         # get the status of multiple streams at once and doing them individually is way
-         # too slow. We use window functions to get the status of the stream and the batch
-         # details.
-         statement = f"""
-             SELECT
-                 ID,
-                 RAI_DATABASE,
-                 FQ_OBJECT_NAME,
-                 CREATED_AT,
-                 CREATED_BY,
-                 CASE
-                     WHEN nextBatch.quarantined > 0 THEN 'quarantined'
-                     ELSE nextBatch.status
-                 END as batch_status,
-                 nextBatch.processing_errors,
-                 nextBatch.batches
-             FROM {APP_NAME}.api.data_streams as ds
-             LEFT JOIN (
-                 SELECT DISTINCT
-                     data_stream_id,
-                     -- Get status from the progress record using window functions
-                     FIRST_VALUE(status) OVER (
-                         PARTITION BY data_stream_id
-                         ORDER BY
-                             CASE WHEN unloaded IS NOT NULL THEN 1 ELSE 0 END DESC,
-                             unloaded ASC
-                     ) as status,
-                     -- Get batch_details from the same record
-                     FIRST_VALUE(batch_details) OVER (
-                         PARTITION BY data_stream_id
-                         ORDER BY
-                             CASE WHEN unloaded IS NOT NULL THEN 1 ELSE 0 END DESC,
-                             unloaded ASC
-                     ) as batch_details,
-                     -- Aggregate the other fields
-                     FIRST_VALUE(processing_details:processingErrors) OVER (
-                         PARTITION BY data_stream_id
-                         ORDER BY
-                             CASE WHEN unloaded IS NOT NULL THEN 1 ELSE 0 END DESC,
-                             unloaded ASC
-                     ) as processing_errors,
-                     MIN(unloaded) OVER (PARTITION BY data_stream_id) as unloaded,
-                     COUNT(*) OVER (PARTITION BY data_stream_id) as batches,
-                     COUNT_IF(status = 'quarantined') OVER (PARTITION BY data_stream_id) as quarantined
-                 FROM {APP_NAME}.api.data_stream_batches
-             ) nextBatch
-             ON ds.id = nextBatch.data_stream_id
-             {f"where {where_clause}" if where_clause else ""}
-             ORDER BY FQ_OBJECT_NAME ASC;
-         """
-         results = self._exec(statement)
-         items = []
-         if results:
-             for stream in results:
-                 (id, db, name, created_at, created_by, status, processing_errors, batches) = stream
-                 if status and isinstance(status, str):
-                     status = status.upper()
-                 if processing_errors:
-                     if status in ["QUARANTINED", "PENDING"]:
-                         start = processing_errors.rfind("Error")
-                         if start != -1:
-                             processing_errors = processing_errors[start:-1]
-                     else:
-                         processing_errors = None
-                 items.append(cast(Import, {
-                     "id": id,
-                     "model": db,
-                     "name": name,
-                     "created": created_at,
-                     "creator": created_by,
-                     "status": status.upper() if status else None,
-                     "errors": processing_errors if processing_errors != "[]" else None,
-                     "batches": f"{batches}" if batches else "",
-                 }))
-         return items
-
-     def poll_imports(self, sources:List[str], model:str):
-         source_set = self._create_source_set(sources)
-         def check_imports():
-             imports = [
-                 import_
-                 for import_ in self.list_imports(model=model)
-                 if import_["name"] in source_set
-             ]
-             # loop through printing status for each in the format (index): (name) - (status)
-             statuses = [import_["status"] for import_ in imports]
-             if all(status == "LOADED" for status in statuses):
-                 return True
-             if any(status == "QUARANTINED" for status in statuses):
-                 failed_imports = [import_["name"] for import_ in imports if import_["status"] == "QUARANTINED"]
-                 raise RAIException("Imports failed: " + ", ".join(failed_imports)) from None
-             # this check is necessary in case some of the tables are empty;
-             # such tables may be synced even though their status is None:
-             def synced(import_):
-                 if import_["status"] == "LOADED":
-                     return True
-                 if import_["status"] is None:
-                     import_full_status = self.get_import_stream(import_["name"], model)
-                     if import_full_status and import_full_status[0]["data_sync_status"] == "SYNCED":
-                         return True
-                 return False
-             if all(synced(import_) for import_ in imports):
-                 return True
-         poll_with_specified_overhead(check_imports, overhead_rate=0.1, max_delay=10)
-
-     def _create_source_set(self, sources: List[str]) -> set:
-         return {
-             source.upper() if not IdentityParser(source).has_double_quoted_identifier else IdentityParser(source).identity
-             for source in sources
-         }
-
-     def get_import_stream(self, name:str|None, model:str|None):
-         results = self._exec(f"CALL {APP_NAME}.api.get_data_stream('{name}', '{model}');")
-         if not results:
-             return None
-         return self.imports_to_dicts(results)
-
-     def create_import_stream(self, source:ImportSource, model:str, rate = 1, options: dict|None = None):
-         assert isinstance(source, ImportSourceTable), "Snowflake integration only supports loading from SF Tables. Try loading your data as a table via the Snowflake interface first."
-         object = source.fqn
-
-         # Parse only to the schema level
-         schemaParser = IdentityParser(f"{source.database}.{source.schema}")
-
-         if object.lower() in [x["name"].lower() for x in self.list_imports(model=model)]:
-             return
-
-         query = f"SHOW OBJECTS LIKE '{source.table}' IN {schemaParser.identity}"
-
-         info = self._exec(query)
-         if not info:
-             raise ValueError(f"Object {source.table} not found in schema {schemaParser.identity}")
-         else:
-             data = info[0]
-             if not data:
-                 raise ValueError(f"Object {source.table} not found in {schemaParser.identity}")
-             # (time, name, db_name, schema_name, kind, *rest)
-             kind = data["kind"]
-
-         relation_name = to_fqn_relation_name(object)
-
-         command = f"""call {APP_NAME}.api.create_data_stream(
-             {APP_NAME}.api.object_reference('{kind}', '{object}'),
-             '{model}',
-             '{relation_name}');"""
-
-         def create_stream(tracking_just_changed=False):
-             try:
-                 self._exec(command)
-             except Exception as e:
-                 if "ensure that CHANGE_TRACKING is enabled on the source object" in str(e):
-                     if self.config.get("ensure_change_tracking", False) and not tracking_just_changed:
-                         try:
-                             self._exec(f"ALTER {kind} {object} SET CHANGE_TRACKING = TRUE;")
-                             create_stream(tracking_just_changed=True)
-                         except Exception:
-                             pass
-                     else:
-                         print("\n")
-                         exception = SnowflakeChangeTrackingNotEnabledException((object, kind))
-                         raise exception from None
-                 elif "Database does not exist" in str(e):
-                     print("\n")
-                     raise ModelNotFoundException(model) from None
-                 raise e
-
-         create_stream()
-
-     def create_import_snapshot(self, source:ImportSource, model:str, options: dict|None = None):
-         raise Exception("Snowflake integration doesn't support snapshot imports yet")
-
-     def delete_import(self, import_name:str, model:str, force = False):
-         engine = self.get_default_engine_name()
-         rel_name = to_fqn_relation_name(import_name)
-         try:
-             self._exec(f"""call {APP_NAME}.api.delete_data_stream(
-                 '{import_name}',
-                 '{model}'
-             );""")
-         except RAIException as err:
-             if "streams do not exist" not in str(err) or not force:
-                 raise
-
-         # if force is true, we delete the leftover relation to free up the name (in case the user re-creates the stream)
-         if force:
-             self.exec_raw(model, engine, f"""
-                 declare ::{rel_name}
-                 def delete[:\"{rel_name}\"]: {{ {rel_name} }}
-             """, readonly=False, bypass_index=True)
-
-     #--------------------------------------------------
-     # Exec Async
-     #--------------------------------------------------
-
-     def _check_exec_async_status(self, txn_id: str, headers: Dict | None = None):
-         """Check whether the given transaction has completed."""
-         if headers is None:
-             headers = {}
-
-         with debugging.span("check_status"):
-             response = self._exec(f"CALL {APP_NAME}.api.get_transaction('{txn_id}',{headers});")
-             assert response, f"No results from get_transaction('{txn_id}')"
-
-         response_row = next(iter(response)).asDict()
-         status: str = response_row['STATE']
-
-         # remove the transaction from the pending list if it's completed or aborted
-         if status in ["COMPLETED", "ABORTED"]:
-             if txn_id in self._pending_transactions:
-                 self._pending_transactions.remove(txn_id)
-
-         if status == "ABORTED" and response_row.get("ABORT_REASON", "") == TXN_ABORT_REASON_TIMEOUT:
-             config_file_path = getattr(self.config, 'file_path', None)
-             # todo: use the timeout returned alongside the transaction as soon as it's exposed
-             timeout_mins = int(self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS) or DEFAULT_QUERY_TIMEOUT_MINS)
-             raise QueryTimeoutExceededException(
-                 timeout_mins=timeout_mins,
-                 query_id=txn_id,
-                 config_file_path=config_file_path,
-             )
-
-         # @TODO: Find some way to tunnel the ABORT_REASON out. Azure doesn't have this, but it's handy
-         return status == "COMPLETED" or status == "ABORTED"
-
-     def decrypt_stream(self, key: bytes, iv: bytes, src: bytes) -> bytes:
-         """Decrypt the provided stream with PKCS#5 padding handling."""
-
-         if crypto_disabled:
-             if isinstance(runtime_env, SnowbookEnvironment) and runtime_env.runner == "warehouse":
-                 raise Exception("Please open the navigation-bar dropdown labeled *Packages* and select `cryptography` under the *Anaconda Packages* section, and then re-run your query.")
-             else:
-                 raise Exception("library `cryptography.hazmat` missing; please install")
-
-         # `type:ignore`s are because of the conditional import, which
-         # we have because warehouse-based snowflake notebooks don't support
-         # the crypto library we're using.
-         cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())  # type: ignore
-         decryptor = cipher.decryptor()
-
-         # Decrypt the data
-         decrypted_padded_data = decryptor.update(src) + decryptor.finalize()
-
-         # Unpad the decrypted data using PKCS#5
-         unpadder = padding.PKCS7(128).unpadder()  # type: ignore # Use 128 directly for AES
-         unpadded_data = unpadder.update(decrypted_padded_data) + unpadder.finalize()
-
-         return unpadded_data
-
-     def _decrypt_artifact(self, data: bytes, encryption_material: str) -> bytes:
-         """Decrypts the artifact data using provided encryption material."""
-         encryption_material_parts = encryption_material.split("|")
-         assert len(encryption_material_parts) == 3, "Invalid encryption material"
-
-         algorithm, key_base64, iv_base64 = encryption_material_parts
-         assert algorithm == "AES_128_CBC", f"Unsupported encryption algorithm {algorithm}"
-
-         key = base64.standard_b64decode(key_base64)
-         iv = base64.standard_b64decode(iv_base64)
-
-         return self.decrypt_stream(key, iv, data)
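For context, a counterpart sketch of how an artifact would have to be encrypted for decrypt_stream above to recover it, assuming AES-128-CBC with PKCS#7 padding; this helper and its names are illustrative, not part of the package:

    import os
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import padding
    from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes

    def encrypt_stream(key: bytes, iv: bytes, src: bytes) -> bytes:
        # Pad to the 128-bit AES block size, then encrypt in CBC mode
        padder = padding.PKCS7(128).padder()
        padded = padder.update(src) + padder.finalize()
        encryptor = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend()).encryptor()
        return encryptor.update(padded) + encryptor.finalize()

    key, iv = os.urandom(16), os.urandom(16)
    ciphertext = encrypt_stream(key, iv, b"hello")
    # Resources.decrypt_stream(key, iv, ciphertext) would return b"hello"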
1422
-
1423
- def _list_exec_async_artifacts(self, txn_id: str, headers: Dict | None = None) -> Dict[str, Dict]:
1424
- """Grab the list of artifacts produced in the transaction and the URLs to retrieve their contents."""
1425
- if headers is None:
1426
- headers = {}
1427
- with debugging.span("list_results"):
1428
- response = self._exec(
1429
- f"CALL {APP_NAME}.api.get_own_transaction_artifacts('{txn_id}',{headers});"
1430
- )
1431
- assert response, f"No results from get_own_transaction_artifacts('{txn_id}')"
1432
- return {row["FILENAME"]: row for row in response}
1433
-
1434
- def _fetch_exec_async_artifacts(
1435
- self, artifact_info: Dict[str, Dict[str, Any]]
1436
- ) -> Dict[str, Any]:
1437
- """Grab the contents of the given artifacts from SF in parallel using threads."""
1438
-
1439
- with requests.Session() as session:
1440
- def _fetch_data(name_info):
1441
- filename, metadata = name_info
1442
-
1443
- try:
1444
- # Extract the presigned URL and encryption material from metadata
1445
- url_key = self.get_url_key(metadata)
1446
- presigned_url = metadata[url_key]
1447
- encryption_material = metadata["ENCRYPTION_MATERIAL"]
1448
-
1449
- response = get_with_retries(session, presigned_url, config=self.config)
1450
- response.raise_for_status() # Throw if something goes wrong
1451
-
1452
- decrypted = self._maybe_decrypt(response.content, encryption_material)
1453
- return (filename, decrypted)
1454
-
1455
- except requests.RequestException as e:
1456
- raise scrub_exception(wrap_with_request_id(e))
1457
-
1458
- # Create a list of tuples for the map function
1459
- name_info_pairs = list(artifact_info.items())
1460
-
1461
- with ThreadPoolExecutor(max_workers=5) as executor:
1462
- results = executor.map(_fetch_data, name_info_pairs)
1463
-
1464
- return {name: data for (name, data) in results}
1465
-
1466
- def _maybe_decrypt(self, content: bytes, encryption_material: str) -> bytes:
1467
- # Decrypt if encryption material is present
1468
- if encryption_material:
1469
-            # an empty payload means the original file was empty; there is no padding to strip
1470
- if len(content) == 0:
1471
- return b""
1472
-
1473
- return self._decrypt_artifact(content, encryption_material)
1474
-
1475
- # otherwise, return content directly
1476
- return content
1477
-
1478
- def _parse_exec_async_results(self, arrow_files: List[Tuple[str, bytes]]):
1479
- """Mimics the logic in _parse_arrow_results of railib/api.py#L303 without requiring a wrapping multipart form."""
1480
- results = []
1481
-
1482
- for file_name, file_content in arrow_files:
1483
- with pa.ipc.open_stream(file_content) as reader:
1484
- schema = reader.schema
1485
- batches = [batch for batch in reader]
1486
- table = pa.Table.from_batches(batches=batches, schema=schema)
1487
- results.append({"relationId": file_name, "table": table})
1488
-
1489
- return results
1490
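
A sketch of producing and re-reading the bare Arrow IPC stream format that `_parse_exec_async_results` consumes (a plain record-batch stream, with no multipart wrapper):

    import pyarrow as pa

    table = pa.table({"v": [1, 2, 3]})
    sink = pa.BufferOutputStream()
    with pa.ipc.new_stream(sink, table.schema) as writer:
        writer.write_table(table)
    payload = sink.getvalue().to_pybytes()   # what an .arrow artifact would hold

    with pa.ipc.open_stream(payload) as reader:
        roundtrip = reader.read_all()        # equivalent to stitching the batches above
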
-
1491
- def _download_results(
1492
- self, artifact_info: Dict[str, Dict], txn_id: str, state: str
1493
- ) -> TransactionAsyncResponse:
1494
- with debugging.span("download_results"):
1495
- # Fetch artifacts
1496
- artifacts = self._fetch_exec_async_artifacts(artifact_info)
1497
-
1498
- # Directly use meta_json as it is fetched
1499
- meta_json_bytes = artifacts["metadata.json"]
1500
-
1501
- # Decode the bytes and parse the JSON
1502
- meta_json_str = meta_json_bytes.decode('utf-8')
1503
- meta_json = json.loads(meta_json_str) # Parse the JSON string
1504
-
1505
- # Use the metadata to map arrow files to the relations they contain
1506
- try:
1507
- arrow_files_to_relations = {
1508
- artifact["filename"]: artifact["relationId"]
1509
- for artifact in meta_json
1510
- }
1511
- except KeyError:
1512
- # TODO: Remove this fallback mechanism later once several engine versions are updated
1513
- arrow_files_to_relations = {
1514
- f"{ix}.arrow": artifact["relationId"]
1515
- for ix, artifact in enumerate(meta_json)
1516
- }
1517
-
1518
- # Hydrate the arrow files into tables
1519
- results = self._parse_exec_async_results(
1520
- [
1521
- (arrow_files_to_relations[name], content)
1522
- for name, content in artifacts.items()
1523
- if name.endswith(".arrow")
1524
- ]
1525
- )
1526
-
1527
- # Create and return the response
1528
- rsp = TransactionAsyncResponse()
1529
- rsp.transaction = {
1530
- "id": txn_id,
1531
- "state": state,
1532
- "response_format_version": None,
1533
- }
1534
- rsp.metadata = meta_json
1535
- rsp.problems = artifacts.get(
1536
- "problems.json"
1537
- ) # Safely access possible missing keys
1538
- rsp.results = results
1539
- return rsp
1540
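
For reference, a hypothetical `metadata.json` payload in the shape consumed above; the integer-indexed fallback covers engines that predate the `filename` field (relation names are illustrative):

    meta_json_example = [
        {"filename": "0.arrow", "relationId": "output_relation_a"},
        {"filename": "1.arrow", "relationId": "output_relation_b"},
    ]
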
-
1541
- def get_transaction_problems(self, txn_id: str) -> List[Dict[str, Any]]:
1542
- with debugging.span("get_own_transaction_problems"):
1543
- response = self._exec(
1544
- f"select * from table({APP_NAME}.api.get_own_transaction_problems('{txn_id}'));"
1545
- )
1546
- if not response:
1547
- return []
1548
- return response
1549
-
1550
- def get_url_key(self, metadata) -> str:
1551
- # In Azure, there is only one type of URL, which is used for both internal and
1552
- # external access; always use that one
1553
- if self.is_azure(metadata['PRESIGNED_URL']):
1554
- return 'PRESIGNED_URL'
1555
-
1556
- configured = self.config.get("download_url_type", None)
1557
- if configured == "internal":
1558
- return 'PRESIGNED_URL_AP'
1559
- elif configured == "external":
1560
- return "PRESIGNED_URL"
1561
-
1562
- if self.is_container_runtime():
1563
- return 'PRESIGNED_URL_AP'
1564
-
1565
- return 'PRESIGNED_URL'
1566
-
1567
- def is_azure(self, url) -> bool:
1568
- return "blob.core.windows.net" in url
1569
-
1570
- def is_container_runtime(self) -> bool:
1571
- return isinstance(runtime_env, SnowbookEnvironment) and runtime_env.runner == "container"
1572
-
1573
- def _exec_rai_app(
1574
- self,
1575
- database: str,
1576
- engine: str | None,
1577
- raw_code: str,
1578
- inputs: Dict,
1579
- readonly=True,
1580
- nowait_durable=False,
1581
- request_headers: Dict | None = None,
1582
- bypass_index=False,
1583
- language: str = "rel",
1584
- query_timeout_mins: int | None = None,
1585
- ):
1586
- assert language == "rel" or language == "lqp", "Only 'rel' and 'lqp' languages are supported"
1587
- if query_timeout_mins is None and (timeout_value := self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS)) is not None:
1588
- query_timeout_mins = int(timeout_value)
1589
-        # The behavior of exec_async_v2 depends on the shape of its input.
1590
-        # With the new object format, the procedure derives the 'rai' database
1591
-        # by hashing the model and username; the legacy positional form uses
1592
-        # the passed database value directly. We therefore must call the
1593
-        # positional exec_async_v2 when the graph index is not in use, to
1594
-        # ensure the correct database is targeted.
1595
- use_graph_index = self.config.get("use_graph_index", USE_GRAPH_INDEX)
1596
- if use_graph_index and not bypass_index:
1597
- payload = {
1598
- 'database': database,
1599
- 'engine': engine,
1600
- 'inputs': inputs,
1601
- 'readonly': readonly,
1602
- 'nowait_durable': nowait_durable,
1603
- 'language': language,
1604
- 'headers': request_headers
1605
- }
1606
- if query_timeout_mins is not None:
1607
- payload["timeout_mins"] = query_timeout_mins
1608
- sql_string = f"CALL {APP_NAME}.api.exec_async_v2(?, {payload});"
1609
- else:
1610
- if query_timeout_mins is not None:
1611
- sql_string = f"CALL {APP_NAME}.api.exec_async_v2('{database}','{engine}', ?, {inputs}, {readonly}, {nowait_durable}, '{language}', {query_timeout_mins}, {request_headers});"
1612
- else:
1613
- sql_string = f"CALL {APP_NAME}.api.exec_async_v2('{database}','{engine}', ?, {inputs}, {readonly}, {nowait_durable}, '{language}', {request_headers});"
1614
-        # Don't let exec set up GI on failure; exec_raw and exec_lqp will do that and add the correct headers.
1615
- response = self._exec(
1616
- sql_string,
1617
- raw_code,
1618
- skip_engine_db_error_retry=True,
1619
- )
1620
- if not response:
1621
- raise Exception("Failed to create transaction")
1622
- return response
1623
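
A sketch of the two call shapes assembled above, with illustrative values; `app_name` stands in for the resolved APP_NAME, and the booleans are interpolated the same way the f-strings above interpolate them:

    app_name = "relationalai"   # stand-in
    payload = {
        "database": "my_model_db",
        "engine": "my_engine",
        "inputs": {},
        "readonly": True,
        "nowait_durable": False,
        "language": "rel",
        "headers": {"user-agent": "pyrel/0.13"},
        "timeout_mins": 10,
    }
    # graph-index path: object argument, database resolved inside the app
    gi_sql = f"CALL {app_name}.api.exec_async_v2(?, {payload});"
    # legacy path: positional arguments, database used directly
    legacy_sql = f"CALL {app_name}.api.exec_async_v2('my_model_db','my_engine', ?, {{}}, True, False, 'rel', 10, {{}});"
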
-
1624
- def _exec_async_v2(
1625
- self,
1626
- database: str,
1627
- engine: str | None,
1628
- raw_code: str,
1629
- inputs: Dict | None = None,
1630
- readonly=True,
1631
- nowait_durable=False,
1632
- headers: Dict | None = None,
1633
- bypass_index=False,
1634
- language: str = "rel",
1635
- query_timeout_mins: int | None = None,
1636
- gi_setup_skipped: bool = False,
1637
- ):
1638
- if inputs is None:
1639
- inputs = {}
1640
- request_headers = debugging.add_current_propagation_headers(headers)
1641
- query_attrs_dict = json.loads(request_headers.get("X-Query-Attributes", "{}"))
1642
-
1643
- with debugging.span("transaction", **query_attrs_dict) as txn_span:
1644
- with debugging.span("create_v2", **query_attrs_dict) as create_span:
1645
- request_headers['user-agent'] = get_pyrel_version(self.generation)
1646
- request_headers['gi_setup_skipped'] = str(gi_setup_skipped)
1647
- request_headers['pyrel_program_id'] = debugging.get_program_span_id() or ""
1648
- response = self._exec_rai_app(
1649
- database=database,
1650
- engine=engine,
1651
- raw_code=raw_code,
1652
- inputs=inputs,
1653
- readonly=readonly,
1654
- nowait_durable=nowait_durable,
1655
- request_headers=request_headers,
1656
- bypass_index=bypass_index,
1657
- language=language,
1658
- query_timeout_mins=query_timeout_mins,
1659
- )
1660
-
1661
- artifact_info = {}
1662
- rows = list(iter(response))
1663
-
1664
- # process the first row since txn_id and state are the same for all rows
1665
- first_row = rows[0]
1666
- txn_id = first_row['ID']
1667
- state = first_row['STATE']
1668
- filename = first_row['FILENAME']
1669
-
1670
- txn_span["txn_id"] = txn_id
1671
- create_span["txn_id"] = txn_id
1672
- debugging.event("transaction_created", txn_span, txn_id=txn_id)
1673
-
1674
- # fast path: transaction already finished
1675
- if state in ["COMPLETED", "ABORTED"]:
1676
- if txn_id in self._pending_transactions:
1677
- self._pending_transactions.remove(txn_id)
1678
-
1679
- # Process rows to get the rest of the artifacts
1680
- for row in rows:
1681
- filename = row['FILENAME']
1682
- artifact_info[filename] = row
1683
-
1684
- # Slow path: transaction not done yet; start polling
1685
- else:
1686
- self._pending_transactions.append(txn_id)
1687
- with debugging.span("wait", txn_id=txn_id):
1688
- poll_with_specified_overhead(
1689
- lambda: self._check_exec_async_status(txn_id, headers=request_headers), 0.1
1690
- )
1691
- artifact_info = self._list_exec_async_artifacts(txn_id, headers=request_headers)
1692
-
1693
- with debugging.span("fetch"):
1694
- return self._download_results(artifact_info, txn_id, state)
1695
-
1696
- def get_user_based_engine_name(self):
1697
- if not self._session:
1698
- self._session = self.get_sf_session()
1699
- user_table = self._session.sql("select current_user()").collect()
1700
- user = user_table[0][0]
1701
- assert isinstance(user, str), f"current_user() must return a string, not {type(user)}"
1702
- return _sanitize_user_name(user)
1703
-
1704
- def is_engine_ready(self, engine_name: str):
1705
- engine = self.get_engine(engine_name)
1706
- return engine and engine["state"] == "READY"
1707
-
1708
- def auto_create_engine(self, name: str | None = None, size: str | None = None, headers: Dict | None = None):
1709
- from v0.relationalai.tools.cli_helpers import validate_engine_name
1710
- with debugging.span("auto_create_engine", active=self._active_engine) as span:
1711
- active = self._get_active_engine()
1712
- if active:
1713
- return active
1714
-
1715
- engine_name = name or self.get_default_engine_name()
1716
-
1717
- # Use the provided size or fall back to the config
1718
- if size:
1719
- engine_size = size
1720
- else:
1721
- engine_size = self.config.get("engine_size", None)
1722
-
1723
- # Validate engine size
1724
- if engine_size:
1725
- is_size_valid, sizes = self.validate_engine_size(engine_size)
1726
- if not is_size_valid:
1727
- raise Exception(f"Invalid engine size '{engine_size}'. Valid sizes are: {', '.join(sizes)}")
1728
-
1729
- # Validate engine name
1730
- is_name_valid, _ = validate_engine_name(engine_name)
1731
- if not is_name_valid:
1732
- raise EngineNameValidationException(engine_name)
1733
-
1734
- try:
1735
- engine = self.get_engine(engine_name)
1736
- if engine:
1737
- span.update(cast(dict, engine))
1738
-
1739
- # if engine is in the pending state, poll until its status changes
1740
- # if engine is gone, delete it and create new one
1741
- # if engine is in the ready state, return engine name
1742
- if engine:
1743
- if engine["state"] == "PENDING":
1744
- # if the user explicitly specified a size, warn if the pending engine size doesn't match it
1745
- if size is not None and engine["size"] != size:
1746
- EngineSizeMismatchWarning(engine_name, engine["size"], size)
1747
- # poll until engine is ready
1748
- with Spinner(
1749
- "Waiting for engine to be initialized",
1750
- "Engine ready",
1751
- ):
1752
- poll_with_specified_overhead(lambda: self.is_engine_ready(engine_name), overhead_rate=0.1, max_delay=0.5, timeout=900)
1753
-
1754
- elif engine["state"] == "SUSPENDED":
1755
- with Spinner(f"Resuming engine '{engine_name}'", f"Engine '{engine_name}' resumed", f"Failed to resume engine '{engine_name}'"):
1756
- try:
1757
- self.resume_engine_async(engine_name, headers=headers)
1758
- poll_with_specified_overhead(lambda: self.is_engine_ready(engine_name), overhead_rate=0.1, max_delay=0.5, timeout=900)
1759
- except Exception:
1760
- raise EngineResumeFailed(engine_name)
1761
- elif engine["state"] == "READY":
1762
- # if the user explicitly specified a size, warn if the ready engine size doesn't match it
1763
- if size is not None and engine["size"] != size:
1764
- EngineSizeMismatchWarning(engine_name, engine["size"], size)
1765
- self._set_active_engine(engine)
1766
- return engine_name
1767
- elif engine["state"] == "GONE":
1768
- try:
1769
- # "Gone" is abnormal condition when metadata and SF service don't match
1770
- # Therefore, we have to delete the engine and create a new one
1771
- # it could be case that engine is already deleted, so we have to catch the exception
1772
- self.delete_engine(engine_name, headers=headers)
1773
- # After deleting the engine, set it to None so that we can create a new engine
1774
- engine = None
1775
- except Exception as e:
1776
-                            # if the engine is already deleted, we will get an exception;
1777
-                            # we can ignore it and create a new engine
1778
- if isinstance(e, EngineNotFoundException):
1779
- engine = None
1780
- else:
1782
- raise EngineProvisioningFailed(engine_name, e) from e
1783
-
1784
- if not engine:
1785
- with Spinner(
1786
- f"Auto-creating engine {engine_name}",
1787
- f"Auto-created engine {engine_name}",
1788
- "Engine creation failed",
1789
- ):
1790
- self.create_engine(engine_name, size=engine_size, headers=headers)
1791
- except Exception as e:
1792
- print(e)
1793
- if DUO_TEXT in str(e).lower():
1794
- raise DuoSecurityFailed(e)
1795
- raise EngineProvisioningFailed(engine_name, e) from e
1796
- return engine_name
1797
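
A compact restatement of the engine-state handling above as a dispatch table (descriptive only; the real branches also warn on size mismatches and poll with timeouts):

    ENGINE_STATE_ACTIONS = {
        "PENDING":   "poll until READY",
        "SUSPENDED": "resume_engine_async, then poll until READY",
        "READY":     "set as the active engine and return its name",
        "GONE":      "delete the stale engine, then create a fresh one",
    }

    def describe_action_sketch(state: str) -> str:
        # Unknown or missing engine falls through to creation.
        return ENGINE_STATE_ACTIONS.get(state, "create a new engine")
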
-
1798
- def auto_create_engine_async(self, name: str | None = None):
1799
- active = self._get_active_engine()
1800
- if active and (active == name or name is None):
1801
- return # @NOTE: This method weirdly doesn't return engine name even though all the other ones do?
1802
-
1803
- with Spinner(
1804
- "Checking engine status",
1805
- leading_newline=True,
1806
- ) as spinner:
1807
- from v0.relationalai.tools.cli_helpers import validate_engine_name
1808
- with debugging.span("auto_create_engine_async", active=self._active_engine):
1809
- engine_name = name or self.get_default_engine_name()
1810
- engine_size = self.config.get("engine_size", None)
1811
- if engine_size:
1812
- is_size_valid, sizes = self.validate_engine_size(engine_size)
1813
- if not is_size_valid:
1814
- raise Exception(f"Invalid engine size in config: '{engine_size}'. Valid sizes are: {', '.join(sizes)}")
1815
- else:
1816
- engine_size = self.config.get_default_engine_size()
1817
-
1818
- is_name_valid, _ = validate_engine_name(engine_name)
1819
- if not is_name_valid:
1820
- raise EngineNameValidationException(engine_name)
1821
- try:
1822
- engine = self.get_engine(engine_name)
1823
- # if engine is gone, delete it and create new one
1824
- # in case of pending state, do nothing, it is use_index responsibility to wait for engine to be ready
1825
- if engine:
1826
- if engine["state"] == "PENDING":
1827
- spinner.update_messages(
1828
- {
1829
- "finished_message": f"Starting engine {engine_name}",
1830
- }
1831
- )
1832
- elif engine["state"] == "SUSPENDED":
1834
- spinner.update_messages(
1835
- {
1836
- "finished_message": f"Resuming engine {engine_name}",
1837
- }
1838
- )
1839
- try:
1840
- self.resume_engine_async(engine_name)
1841
- except Exception:
1842
- raise EngineResumeFailed(engine_name)
1843
- elif engine["state"] == "READY":
1844
- spinner.update_messages(
1845
- {
1846
- "finished_message": f"Engine {engine_name} initialized",
1847
- }
1848
- )
1849
- elif engine["state"] == "GONE":
1851
- spinner.update_messages(
1852
- {
1853
- "message": f"Restarting engine {engine_name}",
1854
- }
1855
- )
1856
- try:
1857
- # "Gone" is abnormal condition when metadata and SF service don't match
1858
- # Therefore, we have to delete the engine and create a new one
1859
- # it could be case that engine is already deleted, so we have to catch the exception
1860
- # set it to None so that we can create a new engine
1861
- engine = None
1862
- self.delete_engine(engine_name)
1863
- except Exception as e:
1864
-                                # if the engine is already deleted, we will get an exception;
1865
-                                # we can ignore it and create a new engine asynchronously
1866
- if isinstance(e, EngineNotFoundException):
1867
- engine = None
1868
- else:
1870
- print(e)
1871
- raise EngineProvisioningFailed(engine_name, e) from e
1872
-
1873
- if not engine:
1874
-                    self.create_engine_async(engine_name, size=engine_size)  # use the validated/defaulted size computed above
1875
- spinner.update_messages(
1876
- {
1877
- "finished_message": f"Starting engine {engine_name}...",
1878
- }
1879
- )
1880
- else:
1881
- self._set_active_engine(engine)
1882
-
1883
- except Exception as e:
1884
- spinner.update_messages(
1885
- {
1886
- "finished_message": f"Failed to create engine {engine_name}",
1887
- }
1888
- )
1889
- if DUO_TEXT in str(e).lower():
1890
- raise DuoSecurityFailed(e)
1891
- if isinstance(e, RAIException):
1892
- raise e
1893
- print(e)
1894
- raise EngineProvisioningFailed(engine_name, e) from e
1895
-
1896
- def validate_engine_size(self, size: str) -> Tuple[bool, List[str]]:
1897
- if size is not None:
1898
- sizes = self.get_engine_sizes()
1899
- if size not in sizes:
1900
- return False, sizes
1901
- return True, []
1902
-
1903
- #--------------------------------------------------
1904
- # Exec
1905
- #--------------------------------------------------
1906
-
1907
- def _exec_with_gi_retry(
1908
- self,
1909
- database: str,
1910
- engine: str | None,
1911
- raw_code: str,
1912
- inputs: Dict | None,
1913
- readonly: bool,
1914
- nowait_durable: bool,
1915
- headers: Dict | None,
1916
- bypass_index: bool,
1917
- language: str,
1918
- query_timeout_mins: int | None,
1919
- ):
1920
- """Execute with graph index retry logic.
1921
-
1922
- Attempts execution with gi_setup_skipped=True first. If an engine or database
1923
- issue occurs, polls use_index and retries with gi_setup_skipped=False.
1924
- """
1925
- try:
1926
- return self._exec_async_v2(
1927
- database, engine, raw_code, inputs, readonly, nowait_durable,
1928
- headers=headers, bypass_index=bypass_index, language=language,
1929
- query_timeout_mins=query_timeout_mins, gi_setup_skipped=True,
1930
- )
1931
- except Exception as e:
1932
- err_message = str(e).lower()
1933
- if _is_engine_issue(err_message) or _is_database_issue(err_message):
1934
- engine_name = engine or self.get_default_engine_name()
1935
- engine_size = self.config.get_default_engine_size()
1936
- self._poll_use_index(
1937
- app_name=self.get_app_name(),
1938
- sources=self.sources,
1939
- model=database,
1940
- engine_name=engine_name,
1941
- engine_size=engine_size,
1942
- headers=headers,
1943
- )
1944
-
1945
- return self._exec_async_v2(
1946
- database, engine, raw_code, inputs, readonly, nowait_durable,
1947
- headers=headers, bypass_index=bypass_index, language=language,
1948
- query_timeout_mins=query_timeout_mins, gi_setup_skipped=False,
1949
- )
1950
- else:
1951
- raise e
1952
-
1953
- def exec_lqp(
1954
- self,
1955
- database: str,
1956
- engine: str | None,
1957
- raw_code: bytes,
1958
- readonly=True,
1959
- *,
1960
- inputs: Dict | None = None,
1961
- nowait_durable=False,
1962
- headers: Dict | None = None,
1963
- bypass_index=False,
1964
- query_timeout_mins: int | None = None,
1965
- ):
1966
- raw_code_b64 = base64.b64encode(raw_code).decode("utf-8")
1967
- return self._exec_with_gi_retry(
1968
- database, engine, raw_code_b64, inputs, readonly, nowait_durable,
1969
- headers, bypass_index, 'lqp', query_timeout_mins
1970
- )
1971
-
1972
-
1973
- def exec_raw(
1974
- self,
1975
- database: str,
1976
- engine: str | None,
1977
- raw_code: str,
1978
- readonly=True,
1979
- *,
1980
- inputs: Dict | None = None,
1981
- nowait_durable=False,
1982
- headers: Dict | None = None,
1983
- bypass_index=False,
1984
- query_timeout_mins: int | None = None,
1985
- ):
1986
- raw_code = raw_code.replace("'", "\\'")
1987
- return self._exec_with_gi_retry(
1988
- database, engine, raw_code, inputs, readonly, nowait_durable,
1989
- headers, bypass_index, 'rel', query_timeout_mins
1990
- )
1991
-
1992
-
1993
- def format_results(self, results, task:m.Task|None=None) -> Tuple[DataFrame, List[Any]]:
1994
- return result_helpers.format_results(results, task)
1995
-
1996
- #--------------------------------------------------
1997
- # Exec format
1998
- #--------------------------------------------------
1999
-
2000
- def exec_format(
2001
- self,
2002
- database: str,
2003
- engine: str,
2004
- raw_code: str,
2005
- cols: List[str],
2006
- format: str,
2007
- inputs: Dict | None = None,
2008
- readonly=True,
2009
- nowait_durable=False,
2010
- skip_invalid_data=False,
2011
- headers: Dict | None = None,
2012
- query_timeout_mins: int | None = None,
2013
- ):
2014
- if inputs is None:
2015
- inputs = {}
2016
- if headers is None:
2017
- headers = {}
2018
- if 'user-agent' not in headers:
2019
- headers['user-agent'] = get_pyrel_version(self.generation)
2020
- if query_timeout_mins is None and (timeout_value := self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS)) is not None:
2021
- query_timeout_mins = int(timeout_value)
2022
- # TODO: add headers
2023
- start = time.perf_counter()
2024
- output_table = "out" + str(uuid.uuid4()).replace("-", "_")
2025
- temp_table = f"temp_{output_table}"
2026
- use_graph_index = self.config.get("use_graph_index", USE_GRAPH_INDEX)
2027
- txn_id = None
2028
- rejected_rows = None
2029
- col_names_map = None
2030
- artifacts = None
2031
- assert self._session
2032
- temp = self._session.createDataFrame([], StructType([StructField(name, StringType()) for name in cols]))
2033
- with debugging.span("transaction") as txn_span:
2034
- try:
2035
- # In the graph index case we need to use the new exec_into_table proc as it obfuscates the db name
2036
- with debugging.span("exec_format"):
2037
- if use_graph_index:
2038
-                        # we do not provide a default value for query_timeout_mins so that the default can be controlled at the app level
2039
- if query_timeout_mins is not None:
2040
- res = self._exec(f"call {APP_NAME}.api.exec_into_table(?, ?, ?, ?, ?, NULL, ?, {headers}, ?, ?);", [database, engine, raw_code, output_table, readonly, nowait_durable, skip_invalid_data, query_timeout_mins])
2041
- else:
2042
- res = self._exec(f"call {APP_NAME}.api.exec_into_table(?, ?, ?, ?, ?, NULL, ?, {headers}, ?);", [database, engine, raw_code, output_table, readonly, nowait_durable, skip_invalid_data])
2043
- txn_id = json.loads(res[0]["EXEC_INTO_TABLE"])["rai_transaction_id"]
2044
- rejected_rows = json.loads(res[0]["EXEC_INTO_TABLE"]).get("rejected_rows", [])
2045
- rejected_rows_count = json.loads(res[0]["EXEC_INTO_TABLE"]).get("rejected_rows_count", 0)
2046
- else:
2047
- if query_timeout_mins is not None:
2048
- res = self._exec(f"call {APP_NAME}.api.exec_into(?, ?, ?, ?, ?, {inputs}, ?, {headers}, ?, ?);", [database, engine, raw_code, output_table, readonly, nowait_durable, skip_invalid_data, query_timeout_mins])
2049
- else:
2050
- res = self._exec(f"call {APP_NAME}.api.exec_into(?, ?, ?, ?, ?, {inputs}, ?, {headers}, ?);", [database, engine, raw_code, output_table, readonly, nowait_durable, skip_invalid_data])
2051
- txn_id = json.loads(res[0]["EXEC_INTO"])["rai_transaction_id"]
2052
- rejected_rows = json.loads(res[0]["EXEC_INTO"]).get("rejected_rows", [])
2053
- rejected_rows_count = json.loads(res[0]["EXEC_INTO"]).get("rejected_rows_count", 0)
2054
- debugging.event("transaction_created", txn_span, txn_id=txn_id)
2055
- debugging.time("exec_format", time.perf_counter() - start, DataFrame())
2056
-
2057
- with debugging.span("temp_table_swap", txn_id=txn_id):
2058
- out_sample = self._exec(f"select * from {APP_NAME}.results.{output_table} limit 1;")
2059
- if out_sample:
2060
-                        keys = {k.lower() for k in out_sample[0].as_dict().keys()}
2061
- col_names_map = {}
2062
- for ix, name in enumerate(cols):
2063
- col_key = f"col{ix:03}"
2064
- if col_key in keys:
2065
- col_names_map[col_key] = IdentityParser(name).identity
2066
- else:
2067
- col_names_map[col_key] = name
2068
-
2069
- names = ", ".join([
2070
- f"{col_key} as {alias}" if col_key in keys else f"NULL as {alias}"
2071
- for col_key, alias in col_names_map.items()
2072
- ])
2073
- self._exec(f"CREATE TEMPORARY TABLE {APP_NAME}.results.{temp_table} AS SELECT {names} FROM {APP_NAME}.results.{output_table};")
2074
- self._exec(f"call {APP_NAME}.api.drop_result_table(?)", [output_table])
2075
- temp = cast(snowflake.snowpark.DataFrame, self._exec(f"select * from {APP_NAME}.results.{temp_table}", raw=True))
2076
- if rejected_rows:
2077
- debugging.warn(RowsDroppedFromTargetTableWarning(rejected_rows, rejected_rows_count, col_names_map))
2078
- except Exception as e:
2079
- msg = str(e).lower()
2080
- if "no columns returned" in msg or "columns of results could not be determined" in msg:
2081
- pass
2082
- else:
2083
- raise e
2084
- if txn_id:
2085
- artifact_info = self._list_exec_async_artifacts(txn_id)
2086
- with debugging.span("fetch"):
2087
- artifacts = self._download_results(artifact_info, txn_id, "ABORTED")
2088
- return (temp, artifacts)
2089
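
A sketch of the column re-aliasing performed in the temp-table swap above: result columns arrive as col000, col001, ... and are renamed back to the requested names, padding missing columns with NULL (illustrative; the real code also normalizes names via IdentityParser):

    cols = ["id", "name", "score"]      # requested output columns (made up)
    keys = {"col000", "col002"}         # columns actually present in the result
    select_list = ", ".join(
        f"col{ix:03} as {name}" if f"col{ix:03}" in keys else f"NULL as {name}"
        for ix, name in enumerate(cols)
    )
    # -> "col000 as id, NULL as name, col002 as score"
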
-
2090
- #--------------------------------------------------
2091
- # Custom model types
2092
- #--------------------------------------------------
2093
-
2094
- def _get_ns(self, model:dsl.Graph):
2095
- if model not in self._ns_cache:
2096
- self._ns_cache[model] = _Snowflake(model)
2097
- return self._ns_cache[model]
2098
-
2099
- def to_model_type(self, model:dsl.Graph, name: str, source:str):
2100
- parser = IdentityParser(source)
2101
- if not parser.is_complete:
2102
- raise SnowflakeInvalidSource(Errors.call_source(), source)
2103
- ns = self._get_ns(model)
2104
- # skip the last item in the list (the full identifier)
2105
- for part in parser.to_list()[:-1]:
2106
- ns = ns._safe_get(part)
2107
- assert parser.identity, f"Error parsing source in to_model_type: {source}"
2108
- self.sources.add(parser.identity)
2109
- return ns
2110
-
2111
- def _check_source_updates(self, sources: Iterable[str]):
2112
- if not sources:
2113
- return {}
2114
- app_name = self.get_app_name()
2115
-
2116
- source_types = dict[str, SourceInfo]()
2117
- partitioned_sources: dict[str, dict[str, list[dict[str, str]]]] = defaultdict(
2118
- lambda: defaultdict(list)
2119
- )
2120
- fqn_to_parts: dict[str, tuple[str, str, str]] = {}
2121
-
2122
- for source in sources:
2123
- parser = IdentityParser(source, True)
2124
- parsed = parser.to_list()
2125
- assert len(parsed) == 4, f"Invalid source: {source}"
2126
- db, schema, entity, identity = parsed
2127
- assert db and schema and entity and identity, f"Invalid source: {source}"
2128
- source_types[identity] = cast(
2129
- SourceInfo,
2130
- {
2131
- "type": None,
2132
- "state": "",
2133
- "columns_hash": None,
2134
- "table_created_at": None,
2135
- "stream_created_at": None,
2136
- "last_ddl": None,
2137
- },
2138
- )
2139
- partitioned_sources[db][schema].append({"entity": entity, "identity": identity})
2140
- fqn_to_parts[identity] = (db, schema, entity)
2141
-
2142
- if not partitioned_sources:
2143
- return source_types
2144
-
2145
- state_queries: list[str] = []
2146
- for db, schemas in partitioned_sources.items():
2147
- select_rows: list[str] = []
2148
- for schema, tables in schemas.items():
2149
- for table_info in tables:
2150
- select_rows.append(
2151
- "SELECT "
2152
- f"{IdentityParser.to_sql_value(db)} AS catalog_name, "
2153
- f"{IdentityParser.to_sql_value(schema)} AS schema_name, "
2154
- f"{IdentityParser.to_sql_value(table_info['entity'])} AS table_name"
2155
- )
2156
-
2157
- if not select_rows:
2158
- continue
2159
-
2160
- target_entities_clause = "\n UNION ALL\n ".join(select_rows)
2161
- # Main query:
2162
- # 1. Enumerate the target tables via target_entities.
2163
- # 2. Pull their metadata (last_altered, type) from INFORMATION_SCHEMA.TABLES.
2164
- # 3. Look up the most recent stream activity for those FQNs only.
2165
- # 4. Capture creation timestamps and use last_ddl vs created_at to classify each target,
2166
- # so we mark tables as stale when they were recreated even if column hashes still match.
2167
- state_queries.append(
2168
- f"""WITH target_entities AS (
2169
- {target_entities_clause}
2170
- ),
2171
- table_info AS (
2172
- SELECT
2173
- {app_name}.api.normalize_fq_ids(
2174
- ARRAY_CONSTRUCT(
2175
- CASE
2176
- WHEN t.table_catalog = UPPER(t.table_catalog) THEN t.table_catalog
2177
- ELSE '"' || t.table_catalog || '"'
2178
- END || '.' ||
2179
- CASE
2180
- WHEN t.table_schema = UPPER(t.table_schema) THEN t.table_schema
2181
- ELSE '"' || t.table_schema || '"'
2182
- END || '.' ||
2183
- CASE
2184
- WHEN t.table_name = UPPER(t.table_name) THEN t.table_name
2185
- ELSE '"' || t.table_name || '"'
2186
- END
2187
- )
2188
- )[0]:identifier::string AS fqn,
2189
- CONVERT_TIMEZONE('UTC', t.last_altered) AS last_ddl,
2190
- CONVERT_TIMEZONE('UTC', t.created) AS table_created_at,
2191
- t.table_type AS kind
2192
- FROM {db}.INFORMATION_SCHEMA.tables t
2193
- JOIN target_entities te
2194
- ON t.table_catalog = te.catalog_name
2195
- AND t.table_schema = te.schema_name
2196
- AND t.table_name = te.table_name
2197
- ),
2198
- stream_activity AS (
2199
- SELECT
2200
- sa.fqn,
2201
- MAX(sa.created_at) AS created_at
2202
- FROM (
2203
- SELECT
2204
- {app_name}.api.normalize_fq_ids(ARRAY_CONSTRUCT(fq_object_name))[0]:identifier::string AS fqn,
2205
- created_at
2206
- FROM {app_name}.api.data_streams
2207
- WHERE rai_database = '{PYREL_ROOT_DB}'
2208
- ) sa
2209
- JOIN table_info ti
2210
- ON sa.fqn = ti.fqn
2211
- GROUP BY sa.fqn
2212
- )
2213
- SELECT
2214
- ti.fqn,
2215
- ti.kind,
2216
- ti.last_ddl,
2217
- ti.table_created_at,
2218
- sa.created_at AS stream_created_at,
2219
- IFF(
2220
- DATEDIFF(second, sa.created_at::timestamp, ti.last_ddl::timestamp) > 0,
2221
- 'STALE',
2222
- 'CURRENT'
2223
- ) AS state
2224
- FROM table_info ti
2225
- LEFT JOIN stream_activity sa
2226
- ON sa.fqn = ti.fqn
2227
- """
2228
- )
2229
-
2230
- stale_fqns: list[str] = []
2231
- for state_query in state_queries:
2232
- for row in self._exec(state_query):
2233
- row_dict = row.as_dict() if hasattr(row, "as_dict") else dict(row)
2234
- row_fqn = row_dict["FQN"]
2235
- parser = IdentityParser(row_fqn, True)
2236
- fqn = parser.identity
2237
- assert fqn, f"Error parsing returned FQN: {row_fqn}"
2238
-
2239
- source_types[fqn]["type"] = (
2240
- "TABLE" if row_dict["KIND"] == "BASE TABLE" else row_dict["KIND"]
2241
- )
2242
- source_types[fqn]["state"] = row_dict["STATE"]
2243
- source_types[fqn]["last_ddl"] = normalize_datetime(row_dict.get("LAST_DDL"))
2244
- source_types[fqn]["table_created_at"] = normalize_datetime(row_dict.get("TABLE_CREATED_AT"))
2245
- source_types[fqn]["stream_created_at"] = normalize_datetime(row_dict.get("STREAM_CREATED_AT"))
2246
- if row_dict["STATE"] == "STALE":
2247
- stale_fqns.append(fqn)
2248
-
2249
- if not stale_fqns:
2250
- return source_types
2251
-
2252
- # We batch stale tables by database/schema so each Snowflake query can hash
2253
- # multiple objects at once instead of issuing one statement per table.
2254
- stale_partitioned: dict[str, dict[str, list[dict[str, str]]]] = defaultdict(
2255
- lambda: defaultdict(list)
2256
- )
2257
- for fqn in stale_fqns:
2258
- db, schema, table = fqn_to_parts[fqn]
2259
- stale_partitioned[db][schema].append({"table": table, "identity": fqn})
2260
-
2261
- # Build one hash query per database, grouping schemas/tables inside so we submit
2262
- # at most a handful of set-based statements to Snowflake.
2263
- for db, schemas in stale_partitioned.items():
2264
- column_select_rows: list[str] = []
2265
- for schema, tables in schemas.items():
2266
- for table_info in tables:
2267
- # Build the literal rows for this db/schema so we can join back
2268
- # against INFORMATION_SCHEMA.COLUMNS in a single statement.
2269
- column_select_rows.append(
2270
- "SELECT "
2271
- f"{IdentityParser.to_sql_value(db)} AS catalog_name, "
2272
- f"{IdentityParser.to_sql_value(schema)} AS schema_name, "
2273
- f"{IdentityParser.to_sql_value(table_info['table'])} AS table_name"
2274
- )
2275
-
2276
- if not column_select_rows:
2277
- continue
2278
-
2279
- target_entities_clause = "\n UNION ALL\n ".join(column_select_rows)
2280
- # Main query: compute deterministic column hashes for every stale table
2281
- # in this database/schema batch so we can compare schemas without a round trip per table.
2282
- column_query = f"""WITH target_entities AS (
2283
- {target_entities_clause}
2284
- ),
2285
- column_info AS (
2286
- SELECT
2287
- {app_name}.api.normalize_fq_ids(
2288
- ARRAY_CONSTRUCT(
2289
- CASE
2290
- WHEN c.table_catalog = UPPER(c.table_catalog) THEN c.table_catalog
2291
- ELSE '"' || c.table_catalog || '"'
2292
- END || '.' ||
2293
- CASE
2294
- WHEN c.table_schema = UPPER(c.table_schema) THEN c.table_schema
2295
- ELSE '"' || c.table_schema || '"'
2296
- END || '.' ||
2297
- CASE
2298
- WHEN c.table_name = UPPER(c.table_name) THEN c.table_name
2299
- ELSE '"' || c.table_name || '"'
2300
- END
2301
- )
2302
- )[0]:identifier::string AS fqn,
2303
- c.column_name,
2304
- CASE
2305
- WHEN c.numeric_precision IS NOT NULL AND c.numeric_scale IS NOT NULL
2306
- THEN c.data_type || '(' || c.numeric_precision || ',' || c.numeric_scale || ')'
2307
- WHEN c.datetime_precision IS NOT NULL
2308
- THEN c.data_type || '(0,' || c.datetime_precision || ')'
2309
- WHEN c.character_maximum_length IS NOT NULL
2310
- THEN c.data_type || '(' || c.character_maximum_length || ')'
2311
- ELSE c.data_type
2312
- END AS type_signature,
2313
- IFF(c.is_nullable = 'YES', 'YES', 'NO') AS nullable_flag
2314
- FROM {db}.INFORMATION_SCHEMA.COLUMNS c
2315
- JOIN target_entities te
2316
- ON c.table_catalog = te.catalog_name
2317
- AND c.table_schema = te.schema_name
2318
- AND c.table_name = te.table_name
2319
- )
2320
- SELECT
2321
- fqn,
2322
- HEX_ENCODE(
2323
- HASH_AGG(
2324
- HASH(
2325
- column_name,
2326
- type_signature,
2327
- nullable_flag
2328
- )
2329
- )
2330
- ) AS columns_hash
2331
- FROM column_info
2332
- GROUP BY fqn
2333
- """
2334
-
2335
- for row in self._exec(column_query):
2336
- row_fqn = row["FQN"]
2337
- parser = IdentityParser(row_fqn, True)
2338
- fqn = parser.identity
2339
- assert fqn, f"Error parsing returned FQN: {row_fqn}"
2340
- source_types[fqn]["columns_hash"] = row["COLUMNS_HASH"]
2341
-
2342
- return source_types
2343
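
A plain-Python restatement of the staleness rule the batched SQL above implements: a source is STALE when the table's last DDL postdates the creation of the stream serving it (illustrative, assuming comparable UTC timestamps; the real comparison runs inside Snowflake):

    from datetime import datetime

    def classify_sketch(last_ddl: datetime, stream_created_at: datetime | None) -> str:
        if stream_created_at is None:
            return "CURRENT"   # LEFT JOIN miss: no stream activity recorded
        return "STALE" if last_ddl > stream_created_at else "CURRENT"
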
-
2344
- def _get_source_references(self, source_info: dict[str, SourceInfo]):
2345
- app_name = self.get_app_name()
2346
- missing_sources = []
2347
- invalid_sources = {}
2348
- source_references = []
2349
- for source, info in source_info.items():
2350
- source_type = info.get("type")
2351
- if source_type is None:
2352
- missing_sources.append(source)
2353
- elif source_type not in ("TABLE", "VIEW"):
2354
- invalid_sources[source] = source_type
2355
- else:
2356
- source_references.append(f"{app_name}.api.object_reference('{source_type}', '{source}')")
2357
-
2358
- if missing_sources:
2359
- current_role = self.get_sf_session().get_current_role()
2360
- if current_role is None:
2361
- current_role = self.config.get("role", None)
2362
- debugging.warn(UnknownSourceWarning(missing_sources, current_role))
2363
-
2364
- if invalid_sources:
2365
- debugging.warn(InvalidSourceTypeWarning(invalid_sources))
2366
-
2367
- self.source_references = source_references
2368
- return source_references
2369
-
2370
- #--------------------------------------------------
2371
- # Transactions
2372
- #--------------------------------------------------
2373
- def txn_list_to_dicts(self, transactions):
2374
- dicts = []
2375
- for txn in transactions:
2376
-            mapped = {}
2377
-            txn_dict = txn.asDict()
2378
-            for key in txn_dict:
2379
-                mapped_key = FIELD_MAP.get(key.lower())
2380
-                if mapped_key:
2381
-                    mapped[mapped_key] = txn_dict[key]
2382
-                else:
2383
-                    mapped[key.lower()] = txn_dict[key]
2384
-            dicts.append(mapped)
2385
- return dicts
2386
-
2387
- def get_transaction(self, transaction_id):
2388
- results = self._exec(
2389
- f"CALL {APP_NAME}.api.get_transaction(?);", [transaction_id])
2390
- if not results:
2391
- return None
2392
-
2393
- results = self.txn_list_to_dicts(results)
2394
-
2395
- txn = {field: results[0][field] for field in GET_TXN_SQL_FIELDS}
2396
-
2397
- state = txn.get("state")
2398
- created_on = txn.get("created_on")
2399
- finished_at = txn.get("finished_at")
2400
- if created_on:
2401
- # Transaction is still running
2402
- if state not in TERMINAL_TXN_STATES:
2403
- tz_info = created_on.tzinfo
2404
- txn['duration'] = datetime.now(tz_info) - created_on
2405
- # Transaction is terminal
2406
- elif finished_at:
2407
- txn['duration'] = finished_at - created_on
2408
-            # Terminal state but finished_at is missing; fall back to a zero duration
2409
- else:
2410
- txn['duration'] = timedelta(0)
2411
- return txn
2412
-
2413
- def list_transactions(self, **kwargs):
2414
- id = kwargs.get("id", None)
2415
- state = kwargs.get("state", None)
2416
- engine = kwargs.get("engine", None)
2417
- limit = kwargs.get("limit", 100)
2418
- all_users = kwargs.get("all_users", False)
2419
- created_by = kwargs.get("created_by", None)
2420
- only_active = kwargs.get("only_active", False)
2421
- where_clause_arr = []
2422
-
2423
- if id:
2424
- where_clause_arr.append(f"id = '{id}'")
2425
- if state:
2426
- where_clause_arr.append(f"state = '{state.upper()}'")
2427
- if engine:
2428
- where_clause_arr.append(f"LOWER(engine_name) = '{engine.lower()}'")
2429
- else:
2430
- if only_active:
2431
- where_clause_arr.append("state in ('CREATED', 'RUNNING', 'PENDING')")
2432
- if not all_users and created_by is not None:
2433
- where_clause_arr.append(f"LOWER(created_by) = '{created_by.lower()}'")
2434
-
2435
- if len(where_clause_arr):
2436
- where_clause = f'WHERE {" AND ".join(where_clause_arr)}'
2437
- else:
2438
- where_clause = ""
2439
-
2440
- sql_fields = ", ".join(LIST_TXN_SQL_FIELDS)
2441
- query = f"SELECT {sql_fields} from {APP_NAME}.api.transactions {where_clause} ORDER BY created_on DESC LIMIT ?"
2442
- results = self._exec(query, [limit])
2443
- if not results:
2444
- return []
2445
- return self.txn_list_to_dicts(results)
2446
-
2447
- def cancel_transaction(self, transaction_id):
2448
- self._exec(f"CALL {APP_NAME}.api.cancel_own_transaction(?);", [transaction_id])
2449
- if transaction_id in self._pending_transactions:
2450
- self._pending_transactions.remove(transaction_id)
2451
-
2452
- def cancel_pending_transactions(self):
2453
- for txn_id in self._pending_transactions:
2454
- self.cancel_transaction(txn_id)
2455
-
2456
- def get_transaction_events(self, transaction_id: str, continuation_token:str=''):
2457
- results = self._exec(
2458
- f"SELECT {APP_NAME}.api.get_own_transaction_events(?, ?);",
2459
- [transaction_id, continuation_token],
2460
- )
2461
- if not results:
2462
- return {
2463
- "events": [],
2464
- "continuation_token": None
2465
- }
2466
- row = results[0][0]
2467
- return json.loads(row)
2468
-
2469
- #--------------------------------------------------
2470
- # Snowflake specific
2471
- #--------------------------------------------------
2472
-
2473
- def get_version(self):
2474
- results = self._exec(f"SELECT {APP_NAME}.app.get_release()")
2475
- if not results:
2476
- return None
2477
- return results[0][0]
2478
-
2479
- def list_warehouses(self):
2480
- results = self._exec("SHOW WAREHOUSES")
2481
- if not results:
2482
- return []
2483
- return [{"name":name}
2484
- for (name, *rest) in results]
2485
-
2486
- def list_compute_pools(self):
2487
- results = self._exec("SHOW COMPUTE POOLS")
2488
- if not results:
2489
- return []
2490
- return [{"name":name, "status":status, "min_nodes":min_nodes, "max_nodes":max_nodes, "instance_family":instance_family}
2491
- for (name, status, min_nodes, max_nodes, instance_family, *rest) in results]
2492
-
2493
- def list_roles(self):
2494
- results = self._exec("SELECT CURRENT_AVAILABLE_ROLES()")
2495
- if not results:
2496
- return []
2497
- # the response is a single row with a single column containing
2498
- # a stringified JSON array of role names:
2499
- row = results[0]
2500
- if not row:
2501
- return []
2502
- return [{"name": name} for name in json.loads(row[0])]
2503
-
2504
- def list_apps(self):
2505
- all_apps = self._exec(f"SHOW APPLICATIONS LIKE '{RAI_APP_NAME}'")
2506
- if not all_apps:
2507
- all_apps = self._exec("SHOW APPLICATIONS")
2508
- if not all_apps:
2509
- return []
2510
- return [{"name":name}
2511
- for (time, name, *rest) in all_apps]
2512
-
2513
- def list_databases(self):
2514
- results = self._exec("SHOW DATABASES")
2515
- if not results:
2516
- return []
2517
- return [{"name":name}
2518
- for (time, name, *rest) in results]
2519
-
2520
- def list_sf_schemas(self, database:str):
2521
- results = self._exec(f"SHOW SCHEMAS IN {database}")
2522
- if not results:
2523
- return []
2524
- return [{"name":name}
2525
- for (time, name, *rest) in results]
2526
-
2527
- def list_tables(self, database:str, schema:str):
2528
- results = self._exec(f"SHOW OBJECTS IN {database}.{schema}")
2529
- items = []
2530
- if results:
2531
- for (time, name, db_name, schema_name, kind, *rest) in results:
2532
- items.append({"name":name, "kind":kind.lower()})
2533
- return items
2534
-
2535
- def schema_info(self, database:str, schema:str, tables:Iterable[str]):
2536
- app_name = self.get_app_name()
2537
- # Only pass the db + schema as the identifier so that the resulting identity is correct
2538
- parser = IdentityParser(f"{database}.{schema}")
2539
-
2540
- with debugging.span("schema_info"):
2541
- with debugging.span("primary_keys") as span:
2542
- pk_query = f"SHOW PRIMARY KEYS IN SCHEMA {parser.identity};"
2543
- pks = self._exec(pk_query)
2544
- span["sql"] = pk_query
2545
-
2546
- with debugging.span("foreign_keys") as span:
2547
- fk_query = f"SHOW IMPORTED KEYS IN SCHEMA {parser.identity};"
2548
- fks = self._exec(fk_query)
2549
- span["sql"] = fk_query
2550
-
2551
-        # IdentityParser parses a single value (one with no ".") and, in this case, stores it in the db field
2552
- with debugging.span("columns") as span:
2553
- tables = ", ".join([f"'{IdentityParser(t).db}'" for t in tables])
2554
- query = textwrap.dedent(f"""
2555
- begin
2556
- SHOW COLUMNS IN SCHEMA {parser.identity};
2557
- let r resultset := (
2558
- SELECT
2559
- CASE
2560
- WHEN "table_name" = UPPER("table_name") THEN "table_name"
2561
- ELSE '"' || "table_name" || '"'
2562
- END as "table_name",
2563
- "column_name",
2564
- "data_type",
2565
- CASE
2566
- WHEN ARRAY_CONTAINS(PARSE_JSON("data_type"):"type", {app_name}.app.get_supported_column_types()) THEN TRUE
2567
- ELSE FALSE
2568
- END as "supported_type"
2569
- FROM table(result_scan(-1)) as t
2570
- WHERE "table_name" in ({tables})
2571
- );
2572
- return table(r);
2573
- end;
2574
- """)
2575
- span["sql"] = query
2576
- columns = self._exec(query)
2577
-
2578
- results = defaultdict(lambda: {"pks": [], "fks": {}, "columns": {}, "invalid_columns": {}})
2579
- if pks:
2580
- for row in pks:
2581
- results[row[3]]["pks"].append(row[4]) # type: ignore
2582
- if fks:
2583
- for row in fks:
2584
- results[row[7]]["fks"][row[8]] = row[3]
2585
- if columns:
2586
- # It seems that a SF parameter (QUOTED_IDENTIFIERS_IGNORE_CASE) can control
2587
-            # whether Snowflake will ignore case on `row.data_type`,
2588
- # so we have to use column indexes instead :(
2589
- for row in columns:
2590
- table_name = row[0]
2591
- column_name = row[1]
2592
- data_type = row[2]
2593
- supported_type = row[3]
2594
- # Filter out unsupported types
2595
- if supported_type:
2596
- results[table_name]["columns"][column_name] = data_type
2597
- else:
2598
- results[table_name]["invalid_columns"][column_name] = data_type
2599
- return results
2600
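
The columns query above leans on a Snowflake scripting pattern worth calling out: run SHOW, then immediately read its output through `result_scan(-1)`. A minimal standalone version (the schema name is made up, and `session` is assumed to be a Snowpark session):

    show_then_filter = '''
    begin
        SHOW COLUMNS IN SCHEMA MYDB.MYSCHEMA;
        let r resultset := (
            SELECT "table_name", "column_name", "data_type"
            FROM table(result_scan(-1))
        );
        return table(r);
    end;
    '''
    # session.sql(show_then_filter).collect()
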
-
2601
- def get_cloud_provider(self) -> str:
2602
- """
2603
-        Detect whether this Snowflake account is running on Azure or AWS, using Snowflake's CURRENT_REGION().
2604
-        Returns 'azure' or 'aws' (defaulting to 'aws' if detection fails).
2605
- """
2606
- if self._session:
2607
- try:
2608
- # Query Snowflake's current region using the built-in function
2609
- result = self._session.sql("SELECT CURRENT_REGION()").collect()
2610
- if result:
2611
- region_info = result[0][0]
2612
- # Check if the region string contains the cloud provider name
2613
- if isinstance(region_info, str):
2614
- region_str = region_info.lower()
2615
- # Check for cloud providers in the region string
2616
- if 'azure' in region_str:
2617
- return 'azure'
2618
- else:
2619
- return 'aws'
2620
- except Exception:
2621
- pass
2622
-
2623
- # Fallback to AWS as default if detection fails
2624
- return 'aws'
2625
-
2626
- #--------------------------------------------------
2627
- # Snowflake Wrapper
2628
- #--------------------------------------------------
2629
-
2630
- class PrimaryKey:
2631
- pass
2632
-
2633
- class _Snowflake:
2634
- def __init__(self, model, auto_import=False):
2635
- self._model = model
2636
- self._auto_import = auto_import
2637
- if not isinstance(model._client.resources, Resources):
2638
- raise ValueError("Snowflake model must be used with a snowflake config")
2639
- self._dbs = {}
2640
- imports = model._client.resources.list_imports(model=model.name)
2641
- self._import_structure(imports)
2642
-
2643
- def _import_structure(self, imports: list[Import]):
2644
- tree = self._dbs
2645
- # pre-create existing imports
2646
- schemas = set()
2647
- for item in imports:
2648
- parser = IdentityParser(item["name"])
2649
- database_name, schema_name, table_name = parser.to_list()[:-1]
2650
- database = getattr(self, database_name)
2651
- schema = getattr(database, schema_name)
2652
- schemas.add(schema)
2653
- schema._add(table_name, is_imported=True)
2654
- return tree
2655
-
2656
- def _safe_get(self, name:str) -> 'SnowflakeDB':
2657
- if name in self._dbs:
2659
- return self._dbs[name]
2660
- self._dbs[name] = SnowflakeDB(self, name)
2661
- return self._dbs[name]
2662
-
2663
- def __getattr__(self, name: str) -> 'SnowflakeDB':
2664
- return self._safe_get(name)
2665
-
2666
-
2667
- class Snowflake(_Snowflake):
2668
- def __init__(self, model: dsl.Graph, auto_import=False):
2669
- if model._config.get_bool("use_graph_index", USE_GRAPH_INDEX):
2670
- raise SnowflakeProxySourceError()
2671
- else:
2672
- debugging.warn(SnowflakeProxyAPIDeprecationWarning())
2673
-
2674
- super().__init__(model, auto_import)
2675
-
2676
- class SnowflakeDB:
2677
- def __init__(self, parent, name):
2678
- self._name = name
2679
- self._parent = parent
2680
- self._model = parent._model
2681
- self._schemas = {}
2682
-
2683
- def _safe_get(self, name: str) -> 'SnowflakeSchema':
2684
- if name in self._schemas:
2686
- return self._schemas[name]
2687
- self._schemas[name] = SnowflakeSchema(self, name)
2688
- return self._schemas[name]
2689
-
2690
- def __getattr__(self, name: str) -> 'SnowflakeSchema':
2691
- return self._safe_get(name)
2692
-
2693
- class SnowflakeSchema:
2694
- def __init__(self, parent, name):
2695
- self._name = name
2696
- self._parent = parent
2697
- self._model = parent._model
2698
- self._tables = {}
2699
- self._imported = set()
2700
- self._table_info = defaultdict(lambda: {"pks": [], "fks": {}, "columns": {}, "invalid_columns": {}})
2701
- self._dirty = True
2702
-
2703
- def _fetch_info(self):
2704
- if not self._dirty:
2705
- return
2706
- self._table_info = self._model._client.resources.schema_info(self._parent._name, self._name, list(self._tables.keys()))
2707
-
2708
- check_column_types = self._model._config.get("check_column_types", True)
2709
-
2710
- if check_column_types:
2711
- self._check_and_confirm_invalid_columns()
2712
-
2713
- self._dirty = False
2714
-
2715
- def _check_and_confirm_invalid_columns(self):
2716
- """Check for invalid columns across the schema's tables."""
2717
- tables_with_invalid_columns = {}
2718
- for table_name, table_info in self._table_info.items():
2719
- if table_info.get("invalid_columns"):
2720
- tables_with_invalid_columns[table_name] = table_info["invalid_columns"]
2721
-
2722
- if tables_with_invalid_columns:
2723
- from ..errors import UnsupportedColumnTypesWarning
2724
- UnsupportedColumnTypesWarning(tables_with_invalid_columns)
2725
-
2726
- def _add(self, name, is_imported=False):
2727
- if name in self._tables:
2728
- return self._tables[name]
2729
- self._dirty = True
2730
- if is_imported:
2731
- self._imported.add(name)
2732
- else:
2733
- self._tables[name] = SnowflakeTable(self, name)
2734
- return self._tables.get(name)
2735
-
2736
- def _safe_get(self, name: str) -> 'SnowflakeTable | None':
2737
- table = self._add(name)
2738
- return table
2739
-
2740
- def __getattr__(self, name: str) -> 'SnowflakeTable | None':
2741
- return self._safe_get(name)
2742
-
2743
-
2744
- class SnowflakeTable(dsl.Type):
2745
- def __init__(self, parent, name):
2746
- super().__init__(parent._model, f"sf_{name}")
2747
- # hack to make this work for pathfinder
2748
- self._type.parents.append(m.Builtins.PQFilterAnnotation)
2749
- self._name = name
2750
- self._model = parent._model
2751
- self._parent = parent
2752
- self._aliases = {}
2753
-        self._finalized = False
2754
- self._source = runtime_env.get_source()
2755
- relation_name = to_fqn_relation_name(self.fqname())
2756
- self._model.install_raw(f"declare {relation_name}")
2757
-
2758
- def __call__(self, *args, **kwargs):
2759
- self._lazy_init()
2760
- return super().__call__(*args, **kwargs)
2761
-
2762
- def add(self, *args, **kwargs):
2763
- self._lazy_init()
2764
- return super().add(*args, **kwargs)
2765
-
2766
- def extend(self, *args, **kwargs):
2767
- self._lazy_init()
2768
- return super().extend(*args, **kwargs)
2769
-
2770
- def known_properties(self):
2771
- self._lazy_init()
2772
- return super().known_properties()
2773
-
2774
- def _lazy_init(self):
2775
-        if self._finalized:
2776
- return
2777
-
2778
- parent = self._parent
2779
- name = self._name
2780
- use_graph_index = self._model._config.get("use_graph_index", USE_GRAPH_INDEX)
2781
-
2782
- if not use_graph_index and name not in parent._imported:
2783
- if self._parent._parent._parent._auto_import:
2784
- with Spinner(f"Creating stream for {self.fqname()}", f"Stream for {self.fqname()} created successfully"):
2785
- db_name = parent._parent._name
2786
- schema_name = parent._name
2787
- self._model._client.resources.create_import_stream(ImportSourceTable(db_name, schema_name, name), self._model.name)
2788
- print("")
2789
- parent._imported.add(name)
2790
- else:
2791
- imports = self._model._client.resources.list_imports(model=self._model.name)
2792
- for item in imports:
2793
- cur_name = item["name"].lower().split(".")[-1]
2794
- parent._imported.add(cur_name)
2795
- if name not in parent._imported:
2796
- exception = SnowflakeImportMissingException(runtime_env.get_source(), self.fqname(), self._model.name)
2797
- raise exception from None
2798
-
2799
- parent._fetch_info()
2800
- self._finalize()
2801
-
2802
- def _finalize(self):
2803
-        if self._finalized:
2804
- return
2805
-
2806
-        self._finalized = True
2807
- self._schema = self._parent._table_info[self._name]
2808
-
2809
- # Set the relation name to the sanitized version of the fully qualified name
2810
- relation_name = to_fqn_relation_name(self.fqname())
2811
-
2812
- model:dsl.Graph = self._model
2813
- edb = getattr(std.rel, relation_name)
2814
- edb._rel.parents.append(m.Builtins.EDB)
2815
- id_rel = getattr(std.rel, f"{relation_name}_pyrel_id")
2816
-
2817
- with model.rule(globalize=True, source=self._source):
2818
- id, val = dsl.create_vars(2)
2819
- edb(dsl.Symbol("METADATA$ROW_ID"), id, val)
2820
- std.rel.SHA1(id)
2821
- id_rel.add(id)
2822
-
2823
- with model.rule(dynamic=True, globalize=True, source=self._source):
2824
- prop, id, val = dsl.create_vars(3)
2825
- id_rel(id)
2826
- std.rel.SHA1(id)
2827
- self.add(snowflake_id=id)
2828
-
2829
- for prop, prop_type in self._schema["columns"].items():
2830
- _prop = prop
2831
- if _prop.startswith("_"):
2832
- _prop = "col" + prop
2833
-
2834
- prop_ident = sanitize_identifier(_prop.lower())
2835
-
2836
- with model.rule(dynamic=True, globalize=True, source=self._source):
2837
- id, val = dsl.create_vars(2)
2838
- edb(dsl.Symbol(prop), id, val)
2839
- std.rel.SHA1(id)
2840
- _prop = getattr(self, prop_ident)
2841
- if not _prop:
2842
- raise ValueError(f"Property {_prop} couldn't be accessed on {self.fqname()}")
2843
- if _prop.is_multi_valued:
2844
- inst = self(snowflake_id=id)
2845
- getattr(inst, prop_ident).add(val)
2846
- else:
2847
- self(snowflake_id=id).set(**{prop_ident: val})
2848
-
2849
- # Because we're bypassing a bunch of the normal Type.add machinery here,
2850
- # we need to manually account for the case where people are using value types.
2851
- def wrapped(x):
2852
- if not model._config.get("compiler.use_value_types", False):
2853
- return x
2854
- other_id = dsl.create_var()
2855
- model._action(dsl.build.construct(self._type, [x, other_id]))
2856
- return other_id
2857
-
2858
- # new UInt128 schema mapping rules
2859
- with model.rule(dynamic=True, globalize=True, source=self._source):
2860
- id = dsl.create_var()
2861
-            # This will generate an arity mismatch warning when used with the old SHA-1 Data Streams.
2862
-            # Ideally we would put the `@no_diagnostics(:ARITY_MISMATCH)` attribute on the relation using
2863
-            # the METADATA$KEY column, but that ended up being a more involved change than expected
2864
-            # just to avoid a non-blocking warning.
2865
- edb(dsl.Symbol("METADATA$KEY"), id)
2866
- std.rel.UInt128(id)
2867
- self.add(wrapped(id), snowflake_id=id)
2868
-
2869
- for prop, prop_type in self._schema["columns"].items():
2870
- _prop = prop
2871
- if _prop.startswith("_"):
2872
- _prop = "col" + prop
2873
-
2874
- prop_ident = sanitize_identifier(_prop.lower())
2875
- with model.rule(dynamic=True, globalize=True, source=self._source):
2876
- id, val = dsl.create_vars(2)
2877
- edb(dsl.Symbol(prop), id, val)
2878
- std.rel.UInt128(id)
2879
- _prop = getattr(self, prop_ident)
2880
- if not _prop:
2881
- raise ValueError(f"Property {_prop} couldn't be accessed on {self.fqname()}")
2882
- if _prop.is_multi_valued:
2883
- inst = self(id)
2884
- getattr(inst, prop_ident).add(val)
2885
- else:
2886
- model._check_property(_prop._prop)
2887
- raw_relation = getattr(std.rel, prop_ident)
2888
- dsl.tag(raw_relation, dsl.Builtins.FunctionAnnotation)
2889
- raw_relation.add(wrapped(id), val)
2890
-
2891
- def namespace(self):
2892
- return f"{self._parent._parent._name}.{self._parent._name}"
2893
-
2894
- def fqname(self):
2895
- return f"{self.namespace()}.{self._name}"
2896
-
2897
- def describe(self, **kwargs):
2898
- model = self._model
2899
- for k, v in kwargs.items():
2900
- if v is PrimaryKey:
2901
- self._schema["pks"] = [k]
2902
- elif isinstance(v, tuple):
2903
- (table, name) = v
2904
- if isinstance(table, SnowflakeTable):
2905
- fk_table = table
2906
- pk = fk_table._schema["pks"]
2907
- with model.rule():
2908
- inst = fk_table()
2909
- me = self()
2910
- getattr(inst, pk[0]) == getattr(me, k)
2911
- if getattr(self, name).is_multi_valued:
2912
- getattr(me, name).add(inst)
2913
- else:
2914
- me.set(**{name: inst})
2915
- else:
2916
- raise ValueError(f"Invalid foreign key {v}")
2917
- else:
2918
- raise ValueError(f"Invalid column {k}={v}")
2919
- return self
2920
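
A hypothetical use of `describe()`, wiring a primary key and a foreign key between two proxied tables; the model, database, and column names are made up:

    # sf = Snowflake(model)
    # customers = sf.SALES.PUBLIC.CUSTOMERS.describe(ID=PrimaryKey)
    # orders = sf.SALES.PUBLIC.ORDERS.describe(CUSTOMER_ID=(customers, "customer"))
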
-
2921
-class Provider(ProviderBase):
-    def __init__(
-        self,
-        profile: str | None = None,
-        config: Config | None = None,
-        resources: Resources | None = None,
-        generation: Generation | None = None,
-    ):
-        if resources:
-            self.resources = resources
-        else:
-            resource_class = Resources
-            if config and config.get("use_direct_access", USE_DIRECT_ACCESS):
-                resource_class = DirectAccessResources
-            self.resources = resource_class(profile=profile, config=config, generation=generation)
-
-    def list_streams(self, model: str):
-        return self.resources.list_imports(model=model)
-
-    def create_streams(self, sources: List[str], model: str, force=False):
-        if not self.resources.get_graph(model):
-            self.resources.create_graph(model)
-        def parse_source(raw: str):
-            parser = IdentityParser(raw)
-            assert parser.is_complete, "Snowflake table imports must be in `database.schema.table` format"
-            return ImportSourceTable(*parser.to_list())
-        for source in sources:
-            source_table = parse_source(source)
-            try:
-                with Spinner(f"Creating stream for {source_table.name}", f"Stream for {source_table.name} created successfully"):
-                    if force:
-                        self.resources.delete_import(source_table.name, model, True)
-                    self.resources.create_import_stream(source_table, model)
-            except Exception as e:
-                if "stream already exists" in f"{e}":
-                    raise Exception(f"\n\nStream '{source_table.name.upper()}' already exists.")
-                elif "engine not found" in f"{e}":
-                    raise Exception("\n\nNo engines found in a READY state. Please use `engines:create` to create an engine that will be used to initialize the target relation.")
-                else:
-                    raise e
-        with Spinner("Waiting for imports to complete", "Imports complete"):
-            self.resources.poll_imports(sources, model)
-
-    def delete_stream(self, stream_id: str, model: str):
-        return self.resources.delete_import(stream_id, model)
-
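# Editor's note: a hedged sketch of driving the stream helpers above; the
# profile, model, and table names are illustrative placeholders.
#
#     provider = Provider(profile="my_profile")
#     provider.create_streams(["sales_db.public.orders"], model="my_model")
#     for stream in provider.list_streams(model="my_model"):
#         print(stream)
#     provider.delete_stream("SALES_DB.PUBLIC.ORDERS", model="my_model")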
-    def sql(self, query: str, params: List[Any]=[], format: Literal["list", "pandas", "polars", "lazy"]="list"):
-        # note: default format cannot be pandas because .to_pandas() only works on SELECT queries
-        result = self.resources._exec(query, params, raw=True, help=False)
-        if format == "lazy":
-            return cast(snowflake.snowpark.DataFrame, result)
-        elif format == "list":
-            return cast(list, result.collect())
-        elif format == "pandas":
-            import pandas as pd
-            try:
-                # use to_pandas for SELECT queries
-                return cast(pd.DataFrame, result.to_pandas())
-            except Exception:
-                # handle non-SELECT queries like SHOW
-                return pd.DataFrame(result.collect())
-        elif format == "polars":
-            import polars as pl  # type: ignore
-            return pl.DataFrame(
-                [row.as_dict() for row in result.collect()],
-                orient="row",
-                strict=False,
-                infer_schema_length=None
-            )
-        else:
-            raise ValueError(f"Invalid format {format}. Should be one of 'list', 'pandas', 'polars', 'lazy'")
-
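# Editor's note: a hedged sketch of the `format` options above; the queries are
# illustrative and assume an already-configured `provider`.
#
#     rows = provider.sql("SELECT 1 AS one")                       # list of Rows
#     df = provider.sql("SHOW TERSE DATABASES", format="pandas")   # falls back to collect() for non-SELECT
#     lazy = provider.sql("SELECT 1 AS one", format="lazy")        # snowpark DataFrame, not yet collected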
-    def activate(self):
-        with Spinner("Activating RelationalAI app...", "RelationalAI app activated"):
-            self.sql("CALL RELATIONALAI.APP.ACTIVATE();")
-
-    def deactivate(self):
-        with Spinner("Deactivating RelationalAI app...", "RelationalAI app deactivated"):
-            self.sql("CALL RELATIONALAI.APP.DEACTIVATE();")
-
-    def drop_service(self):
-        warnings.warn(
-            "The drop_service method has been deprecated in favor of deactivate",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        self.deactivate()
-
-    def resume_service(self):
-        warnings.warn(
-            "The resume_service method has been deprecated in favor of activate",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        self.activate()
-
-
-#--------------------------------------------------
-# SnowflakeClient
-#--------------------------------------------------
-class SnowflakeClient(Client):
-    def create_database(self, isolated=True, nowait_durable=True, headers: Dict | None = None):
-        from v0.relationalai.tools.cli_helpers import validate_engine_name
-
-        assert isinstance(self.resources, Resources)
-
-        if self.last_database_version == len(self.resources.sources):
-            return
-
-        model = self._source_database
-        app_name = self.resources.get_app_name()
-        engine_name = self.resources.get_default_engine_name()
-        engine_size = self.resources.config.get_default_engine_size()
-
-        # Validate engine name
-        is_name_valid, _ = validate_engine_name(engine_name)
-        if not is_name_valid:
-            raise EngineNameValidationException(engine_name)
-
-        # Validate engine size
-        valid_sizes = self.resources.get_engine_sizes()
-        if not isinstance(engine_size, str) or engine_size not in valid_sizes:
-            raise InvalidEngineSizeError(str(engine_size), valid_sizes)
-
-        program_span_id = debugging.get_program_span_id()
-
-        query_attrs_dict = json.loads(headers.get("X-Query-Attributes", "{}")) if headers else {}
-        with debugging.span("poll_use_index", sources=self.resources.sources, model=model, engine=engine_name, **query_attrs_dict):
-            self.maybe_poll_use_index(
-                app_name=app_name,
-                sources=self.resources.sources,
-                model=model,
-                engine_name=engine_name,
-                engine_size=engine_size,
-                program_span_id=program_span_id,
-                headers=headers
-            )
-
-        self.last_database_version = len(self.resources.sources)
-        self._manage_packages()
-
-        if isolated and not self.keep_model:
-            atexit.register(self.delete_database)
-
-    def maybe_poll_use_index(
-        self,
-        app_name: str,
-        sources: Iterable[str],
-        model: str,
-        engine_name: str,
-        engine_size: str | None = None,
-        program_span_id: str | None = None,
-        headers: Dict | None = None,
-    ):
-        """Only call _poll_use_index if there are sources to process."""
-        assert isinstance(self.resources, Resources)
-        return self.resources.maybe_poll_use_index(
-            app_name=app_name,
-            sources=sources,
-            model=model,
-            engine_name=engine_name,
-            engine_size=engine_size,
-            program_span_id=program_span_id,
-            headers=headers
-        )
-
-
-#--------------------------------------------------
-# Graph
-#--------------------------------------------------
-
-def Graph(
-    name,
-    *,
-    profile: str | None = None,
-    config: Config,
-    dry_run: bool = False,
-    isolated: bool = True,
-    connection: Session | None = None,
-    keep_model: bool = False,
-    nowait_durable: bool = True,
-    format: str = "default",
-):
-
-    client_class = Client
-    resource_class = Resources
-    use_graph_index = config.get("use_graph_index", USE_GRAPH_INDEX)
-    use_monotype_operators = config.get("compiler.use_monotype_operators", False)
-    use_direct_access = config.get("use_direct_access", USE_DIRECT_ACCESS)
-
-    if use_graph_index:
-        client_class = SnowflakeClient
-    if use_direct_access:
-        resource_class = DirectAccessResources
-    client = client_class(
-        resource_class(generation=Generation.V0, profile=profile, config=config, connection=connection),
-        rel.Compiler(config),
-        name,
-        config,
-        dry_run=dry_run,
-        isolated=isolated,
-        keep_model=keep_model,
-        nowait_durable=nowait_durable
-    )
-    base_rel = """
-    @inline
-    def make_identity(x..., z):
-        rel_primitive_hash_tuple_uint128(x..., z)
-
-    @inline
-    def pyrel_default({F}, c, k..., v):
-        F(k..., v) or (not F(k..., _) and v = c)
-
-    @inline
-    def pyrel_unwrap(x in UInt128, y): y = x
-
-    @inline
-    def pyrel_dates_period_days(x in Date, y in Date, z in Int):
-        exists((u) | dates_period_days(x, y, u) and u = ::std::common::^Day[z])
-
-    @inline
-    def pyrel_datetimes_period_milliseconds(x in DateTime, y in DateTime, z in Int):
-        exists((u) | datetimes_period_milliseconds(x, y, u) and u = ^Millisecond[z])
-
-    @inline
-    def pyrel_bool_filter(a, b, {F}, z): { z = if_then_else[F(a, b), boolean_true, boolean_false] }
-
-    @inline
-    def pyrel_strftime(v, fmt, tz in String, s in String):
-        (Date(v) and s = format_date[v, fmt])
-        or (DateTime(v) and s = format_datetime[v, fmt, tz])
-
-    @inline
-    def pyrel_regex_match_all(pattern, string in String, pos in Int, offset in Int, match in String):
-        regex_match_all(pattern, string, offset, match) and offset >= pos
-
-    @inline
-    def pyrel_regex_match(pattern, string in String, pos in Int, offset in Int, match in String):
-        pyrel_regex_match_all(pattern, string, pos, offset, match) and offset = pos
-
-    @inline
-    def pyrel_regex_search(pattern, string in String, pos in Int, offset in Int, match in String):
-        enumerate(pyrel_regex_match_all[pattern, string, pos], 1, offset, match)
-
-    @inline
-    def pyrel_regex_sub(pattern, repl in String, string in String, result in String):
-        string_replace_multiple(string, {(last[regex_match_all[pattern, string]], repl)}, result)
-
-    @inline
-    def pyrel_capture_group(regex in Pattern, string in String, pos in Int, index, match in String):
-        (Integer(index) and capture_group_by_index(regex, string, pos, index, match)) or
-        (String(index) and capture_group_by_name(regex, string, pos, index, match))
-
-    declare __resource
-    declare __compiled_patterns
-    """
-    if use_monotype_operators:
-        base_rel += """
-
-        // use monotyped operators
-        from ::std::monotype import +, -, *, /, <, <=, >, >=
-        """
-    pyrel_base = dsl.build.raw_task(base_rel)
-    debugging.set_source(pyrel_base)
-    client.install("pyrel_base", pyrel_base)
-    return dsl.Graph(client, name, format=format)
-
-
-
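# Editor's note: a hedged construction sketch for `Graph` above; the model name
# is illustrative and the Config construction is an assumption (config loading
# lives elsewhere in this package).
#
#     config = Config()  # assumed: built from the active raiconfig profile
#     graph = Graph("my_model", config=config, keep_model=True)
#     # use_direct_access / use_graph_index in the config select the client and
#     # resource classes, as shown in the body above.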
-#--------------------------------------------------
-# Direct Access
-#--------------------------------------------------
-# Note: All direct access components should live in a separate file
-
-class DirectAccessResources(Resources):
-    """
-    Resources class for Direct Service Access, avoiding Snowflake service functions.
-    """
-    def __init__(
-        self,
-        profile: Union[str, None] = None,
-        config: Union[Config, None] = None,
-        connection: Union[Session, None] = None,
-        dry_run: bool = False,
-        reset_session: bool = False,
-        generation: Optional[Generation] = None,
-        language: str = "rel",
-    ):
-        super().__init__(
-            generation=generation,
-            profile=profile,
-            config=config,
-            connection=connection,
-            reset_session=reset_session,
-            dry_run=dry_run,
-            language=language,
-        )
-        self._endpoint_info = ConfigStore(ENDPOINT_FILE)
-        self._service_endpoint = ""
-        self._direct_access_client = None
-        self.generation = generation
-        self.database = ""
-
-    @property
-    def service_endpoint(self) -> str:
-        return self._retrieve_service_endpoint()
-
-    def _retrieve_service_endpoint(self, enforce_update=False) -> str:
-        account = self.config.get("account")
-        app_name = self.config.get("rai_app_name")
-        service_endpoint_key = f"{account}.{app_name}.service_endpoint"
-        if self._service_endpoint and not enforce_update:
-            return self._service_endpoint
-        if self._endpoint_info.get(service_endpoint_key, "") and not enforce_update:
-            self._service_endpoint = str(self._endpoint_info.get(service_endpoint_key, ""))
-            return self._service_endpoint
-
-        is_snowflake_notebook = isinstance(runtime_env, SnowbookEnvironment)
-        query = f"CALL {self.get_app_name()}.app.service_endpoint({not is_snowflake_notebook});"
-        result = self._exec(query)
-        assert result, f"Could not retrieve service endpoint for {self.get_app_name()}"
-        if is_snowflake_notebook:
-            self._service_endpoint = f"http://{result[0]['SERVICE_ENDPOINT']}"
-        else:
-            self._service_endpoint = f"https://{result[0]['SERVICE_ENDPOINT']}"
-
-        self._endpoint_info.set(service_endpoint_key, self._service_endpoint)
-        # Save the endpoint to `ENDPOINT_FILE` to avoid calling the endpoint with
-        # every pyrel execution.
-        try:
-            self._endpoint_info.save()
-        except Exception:
-            print("Failed to persist endpoints to file. This might slow down future executions.")
-
-        return self._service_endpoint
-
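# Editor's note: a minimal, self-contained sketch of the caching pattern used by
# _retrieve_service_endpoint above (in-memory value -> persisted file -> remote
# lookup, with enforce_update bypassing both). All names here are hypothetical.

import json
from pathlib import Path

class EndpointCache:
    def __init__(self, path: Path, fetch):
        self.path = path          # persisted cache file
        self.fetch = fetch        # remote lookup, called only on miss or refresh
        self.value = ""           # in-memory cache

    def get(self, key: str, enforce_update: bool = False) -> str:
        if self.value and not enforce_update:
            return self.value
        if self.path.exists() and not enforce_update:
            stored = json.loads(self.path.read_text()).get(key, "")
            if stored:
                self.value = stored
                return self.value
        self.value = self.fetch()
        try:
            self.path.write_text(json.dumps({key: self.value}))
        except OSError:
            pass  # persistence is best-effort, as in the method above
        return self.value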
-    @property
-    def direct_access_client(self) -> DirectAccessClient:
-        if self._direct_access_client:
-            return self._direct_access_client
-        self._direct_access_client = DirectAccessClient(
-            self.config, self.token_handler, self.service_endpoint, self.generation,
-        )
-        return self._direct_access_client
-
-    def request(
-        self,
-        endpoint: str,
-        payload: Dict[str, Any] | None = None,
-        headers: Dict[str, str] | None = None,
-        path_params: Dict[str, str] | None = None,
-        query_params: Dict[str, str] | None = None,
-        skip_auto_create: bool = False,
-        skip_engine_db_error_retry: bool = False,
-    ) -> requests.Response:
-        with debugging.span("direct_access_request"):
-            def _send_request():
-                return self.direct_access_client.request(
-                    endpoint=endpoint,
-                    payload=payload,
-                    headers=headers,
-                    path_params=path_params,
-                    query_params=query_params,
-                )
-            try:
-                response = _send_request()
-                if response.status_code != 200:
-                    # For 404 responses with skip_auto_create=True, return immediately and let the
-                    # caller handle it (e.g., get_engine checks for 404 and returns None so that
-                    # auto_create_engine can kick in). For skip_auto_create=False, fall through to
-                    # the auto-creation logic below.
-                    if response.status_code == 404 and skip_auto_create:
-                        return response
-
-                    try:
-                        message = response.json().get("message", "")
-                    except requests.exceptions.JSONDecodeError:
-                        # Can't parse the JSON response. For skip_auto_create=True (e.g., get_engine),
-                        # this should have been caught by the 404 check above, so treat it as an error.
-                        # For skip_auto_create=False, we check status_code explicitly below, so the
-                        # message isn't needed.
-                        if skip_auto_create:
-                            raise ResponseStatusException(
-                                f"Failed to parse error response from endpoint {endpoint}.", response
-                            )
-                        message = ""  # not used when we check status_code directly
-
-                    # Fix the engine on engine errors and retry. Skip GI setup when
-                    # skip_auto_create is True (to avoid recursion) or when
-                    # skip_engine_db_error_retry is True (so _exec_async_v2 can perform
-                    # the retry itself with the correct headers).
-                    if ((_is_engine_issue(message) and not skip_auto_create) or _is_database_issue(message)) and not skip_engine_db_error_retry:
-                        engine_name = payload.get("caller_engine_name", "") if payload else ""
-                        engine_name = engine_name or self.get_default_engine_name()
-                        engine_size = self.config.get_default_engine_size()
-                        self._poll_use_index(
-                            app_name=self.get_app_name(),
-                            sources=self.sources,
-                            model=self.database,
-                            engine_name=engine_name,
-                            engine_size=engine_size,
-                            headers=headers,
-                        )
-                        response = _send_request()
-            except requests.exceptions.ConnectionError as e:
-                if "NameResolutionError" in str(e):
-                    # When we cannot resolve the service endpoint, we assume it is outdated;
-                    # retrieve it again and retry the request.
-                    self.direct_access_client.service_endpoint = self._retrieve_service_endpoint(
-                        enforce_update=True,
-                    )
-                    return _send_request()
-                # re-raise in all other cases
-                raise e
-            return response
-
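# Editor's note: the control flow above boils down to "send, repair the likely
# precondition on a recognized failure, then send exactly once more". A minimal
# self-contained sketch of that shape (names hypothetical):

def send_with_one_repair(send, looks_repairable, repair):
    response = send()
    if looks_repairable(response):
        repair()           # e.g. re-provision the engine / refresh the index
        response = send()  # a second failure is returned to the caller as-is
    return response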
-    def _txn_request_with_gi_retry(
-        self,
-        payload: Dict,
-        headers: Dict[str, str],
-        query_params: Dict,
-        engine: Union[str, None],
-    ):
-        """Make a request with graph index retry logic.
-
-        Attempts the request with gi_setup_skipped=True first. If an engine or database
-        issue occurs, polls use_index and retries with gi_setup_skipped=False.
-        """
-        response = self.request(
-            "create_txn", payload=payload, headers=headers, query_params=query_params, skip_auto_create=True, skip_engine_db_error_retry=True
-        )
-
-        if response.status_code != 200:
-            try:
-                message = response.json().get("message", "")
-            except requests.exceptions.JSONDecodeError:
-                message = ""
-
-            if _is_engine_issue(message) or _is_database_issue(message):
-                engine_name = engine or self.get_default_engine_name()
-                engine_size = self.config.get_default_engine_size()
-                self._poll_use_index(
-                    app_name=self.get_app_name(),
-                    sources=self.sources,
-                    model=self.database,
-                    engine_name=engine_name,
-                    engine_size=engine_size,
-                    headers=headers,
-                )
-                headers["gi_setup_skipped"] = "False"
-                response = self.request(
-                    "create_txn", payload=payload, headers=headers, query_params=query_params, skip_auto_create=True, skip_engine_db_error_retry=True
-                )
-            else:
-                raise ResponseStatusException("Failed to create transaction.", response)
-
-        return response
-
-    def _exec_async_v2(
-        self,
-        database: str,
-        engine: Union[str, None],
-        raw_code: str,
-        inputs: Dict | None = None,
-        readonly=True,
-        nowait_durable=False,
-        headers: Dict[str, str] | None = None,
-        bypass_index=False,
-        language: str = "rel",
-        query_timeout_mins: int | None = None,
-        gi_setup_skipped: bool = False,
-    ):
-
-        with debugging.span("transaction") as txn_span:
-            with debugging.span("create_v2") as create_span:
-
-                use_graph_index = self.config.get("use_graph_index", USE_GRAPH_INDEX)
-
-                payload = {
-                    "dbname": database,
-                    "engine_name": engine,
-                    "query": raw_code,
-                    "v1_inputs": inputs,
-                    "nowait_durable": nowait_durable,
-                    "readonly": readonly,
-                    "language": language,
-                }
-                if query_timeout_mins is None and (timeout_value := self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS)) is not None:
-                    query_timeout_mins = int(timeout_value)
-                if query_timeout_mins is not None:
-                    payload["timeout_mins"] = query_timeout_mins
-                query_params = {"use_graph_index": str(use_graph_index and not bypass_index)}
-
-                # Add gi_setup_skipped and the program id to the headers
-                if headers is None:
-                    headers = {}
-                headers["gi_setup_skipped"] = str(gi_setup_skipped)
-                headers["pyrel_program_id"] = debugging.get_program_span_id() or ""
-
-                response = self._txn_request_with_gi_retry(
-                    payload, headers, query_params, engine
-                )
-
-                artifact_info = {}
-                response_content = response.json()
-
-                txn_id = response_content["transaction"]["id"]
-                state = response_content["transaction"]["state"]
-
-                txn_span["txn_id"] = txn_id
-                create_span["txn_id"] = txn_id
-                debugging.event("transaction_created", txn_span, txn_id=txn_id)
-
-                # Fast path: the transaction already finished
-                if state in ["COMPLETED", "ABORTED"]:
-                    if txn_id in self._pending_transactions:
-                        self._pending_transactions.remove(txn_id)
-
-                    # Process rows to get the rest of the artifacts
-                    for result in response_content.get("results", []):
-                        filename = result["filename"]
-                        # making keys uppercase to match the old behavior
-                        artifact_info[filename] = {k.upper(): v for k, v in result.items()}
-
-                # Slow path: the transaction isn't done yet; start polling
-                else:
-                    self._pending_transactions.append(txn_id)
-                    with debugging.span("wait", txn_id=txn_id):
-                        poll_with_specified_overhead(
-                            lambda: self._check_exec_async_status(txn_id, headers=headers), 0.1
-                        )
-                    artifact_info = self._list_exec_async_artifacts(txn_id, headers=headers)
-
-            with debugging.span("fetch"):
-                return self._download_results(artifact_info, txn_id, state)
-
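# Editor's note: a self-contained sketch of the contract the
# poll_with_specified_overhead call above relies on (the real helper is defined
# elsewhere in this package): call `check` until it returns True, sleeping in
# proportion to the time already elapsed so polling overhead stays bounded.
# The name and timeout below are hypothetical.

import time

def poll_with_bounded_overhead(check, overhead_rate: float, timeout_s: float = 300.0):
    start = time.monotonic()
    while not check():
        elapsed = time.monotonic() - start
        if elapsed > timeout_s:
            raise TimeoutError("polling timed out")
        # Early checks are frequent; later checks back off, keeping the total
        # overhead roughly proportional to `overhead_rate`.
        time.sleep(max(0.1, elapsed * overhead_rate))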
-    def _prepare_index(
-        self,
-        model: str,
-        engine_name: str,
-        engine_size: str = "",
-        language: str = "rel",
-        rai_relations: List[str] | None = None,
-        pyrel_program_id: str | None = None,
-        skip_pull_relations: bool = False,
-        headers: Dict | None = None,
-    ):
-        """
-        Prepare the index for the given engine and model.
-        """
-        with debugging.span("prepare_index"):
-            if headers is None:
-                headers = {}
-
-            payload = {
-                "model_name": model,
-                "caller_engine_name": engine_name,
-                "language": language,
-                "pyrel_program_id": pyrel_program_id,
-                "skip_pull_relations": skip_pull_relations,
-                "rai_relations": rai_relations or [],
-                "user_agent": get_pyrel_version(self.generation),
-            }
-            # Only include engine_size if it has a non-empty string value
-            if engine_size and engine_size.strip():
-                payload["caller_engine_size"] = engine_size
-
-            response = self.request(
-                "prepare_index", payload=payload, headers=headers
-            )
-
-            if response.status_code != 200:
-                raise ResponseStatusException("Failed to prepare index.", response)
-
-            return response.json()
-
-    def _poll_use_index(
-        self,
-        app_name: str,
-        sources: Iterable[str],
-        model: str,
-        engine_name: str,
-        engine_size: str | None = None,
-        program_span_id: str | None = None,
-        headers: Dict | None = None,
-    ):
-        return DirectUseIndexPoller(
-            self,
-            app_name=app_name,
-            sources=sources,
-            model=model,
-            engine_name=engine_name,
-            engine_size=engine_size,
-            language=self.language,
-            program_span_id=program_span_id,
-            headers=headers,
-            generation=self.generation,
-        ).poll()
-
-    def maybe_poll_use_index(
-        self,
-        app_name: str,
-        sources: Iterable[str],
-        model: str,
-        engine_name: str,
-        engine_size: str | None = None,
-        program_span_id: str | None = None,
-        headers: Dict | None = None,
-    ):
-        """Only poll if there are sources to process and the cache is not valid."""
-        sources_list = list(sources)
-        self.database = model
-        if sources_list:
-            poller = DirectUseIndexPoller(
-                self,
-                app_name=app_name,
-                sources=sources_list,
-                model=model,
-                engine_name=engine_name,
-                engine_size=engine_size,
-                language=self.language,
-                program_span_id=program_span_id,
-                headers=headers,
-                generation=self.generation,
-            )
-            # If the cache is still valid (data freshness has not expired), skip polling
-            if poller.cache.is_valid():
-                cached_sources = len(poller.cache.sources)
-                total_sources = len(sources_list)
-                cached_timestamp = poller.cache._metadata.get("cachedIndices", {}).get(poller.cache.key, {}).get("last_use_index_update_on", "")
-
-                message = f"Using cached data for {cached_sources}/{total_sources} data streams"
-                if cached_timestamp:
-                    print(f"\n{message} (cached at {cached_timestamp})\n")
-                else:
-                    print(f"\n{message}\n")
-            else:
-                return poller.poll()
-
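# Editor's note: a minimal sketch of the freshness check the poller cache above
# performs; the fields and TTL here are hypothetical stand-ins for the real cache.

import time

class FreshnessCache:
    def __init__(self, ttl_s: float):
        self.ttl_s = ttl_s
        self.updated_at = 0.0  # last successful use_index update

    def is_valid(self) -> bool:
        return (time.monotonic() - self.updated_at) < self.ttl_s

    def touch(self):
        self.updated_at = time.monotonic()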
-    def _check_exec_async_status(self, txn_id: str, headers: Dict[str, str] | None = None) -> bool:
-        """Check whether the given transaction has completed."""
-
-        with debugging.span("check_status"):
-            response = self.request(
-                "get_txn",
-                headers=headers,
-                path_params={"txn_id": txn_id},
-            )
-            assert response, f"No results from get_transaction('{txn_id}')"
-
-            response_content = response.json()
-            transaction = response_content["transaction"]
-            status: str = transaction["state"]
-
-            # Remove the transaction from the pending list if it's completed or aborted
-            if status in ["COMPLETED", "ABORTED"]:
-                if txn_id in self._pending_transactions:
-                    self._pending_transactions.remove(txn_id)
-
-            if status == "ABORTED" and transaction.get("abort_reason", "") == TXN_ABORT_REASON_TIMEOUT:
-                config_file_path = getattr(self.config, "file_path", None)
-                timeout_ms = int(transaction.get("timeout_ms", 0))
-                timeout_mins = timeout_ms // 60000 if timeout_ms > 0 else int(self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS) or DEFAULT_QUERY_TIMEOUT_MINS)
-                raise QueryTimeoutExceededException(
-                    timeout_mins=timeout_mins,
-                    query_id=txn_id,
-                    config_file_path=config_file_path,
-                )
-
-            # @TODO: Find some way to tunnel the ABORT_REASON out. Azure doesn't have this, but it's handy
-            return status in ["COMPLETED", "ABORTED"]
-
-    def _list_exec_async_artifacts(self, txn_id: str, headers: Dict[str, str] | None = None) -> Dict[str, Dict]:
-        """Grab the list of artifacts produced in the transaction and the URLs to retrieve their contents."""
-        with debugging.span("list_results"):
-            response = self.request(
-                "get_txn_artifacts",
-                headers=headers,
-                path_params={"txn_id": txn_id},
-            )
-            assert response, f"No results from get_transaction_artifacts('{txn_id}')"
-            artifact_info = {}
-            for result in response.json()["results"]:
-                filename = result["filename"]
-                # making keys uppercase to match the old behavior
-                artifact_info[filename] = {k.upper(): v for k, v in result.items()}
-            return artifact_info
-
-    def get_transaction_problems(self, txn_id: str) -> List[Dict[str, Any]]:
-        with debugging.span("get_transaction_problems"):
-            response = self.request(
-                "get_txn_problems",
-                path_params={"txn_id": txn_id},
-            )
-            response_content = response.json()
-            if not response_content:
-                return []
-            return response_content.get("problems", [])
-
-    def get_transaction_events(self, transaction_id: str, continuation_token: str = ''):
-        response = self.request(
-            "get_txn_events",
-            path_params={"txn_id": transaction_id, "stream_name": "profiler"},
-            query_params={"continuation_token": continuation_token},
-        )
-        response_content = response.json()
-        if not response_content:
-            return {
-                "events": [],
-                "continuation_token": None
-            }
-        return response_content
-
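# Editor's note: a hedged sketch of draining get_transaction_events above via
# its continuation token; `resources`, `handle`, and the txn id are illustrative,
# and the loop assumes the token is empty/None once the stream is exhausted.
#
#     token = ''
#     while True:
#         page = resources.get_transaction_events("my-txn-id", token)
#         for event in page["events"]:
#             handle(event)
#         token = page.get("continuation_token")
#         if not token:
#             break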
-    #--------------------------------------------------
-    # Databases
-    #--------------------------------------------------
-
-    def get_installed_packages(self, database: str) -> Union[Dict, None]:
-        use_graph_index = self.config.get("use_graph_index", USE_GRAPH_INDEX)
-        if use_graph_index:
-            response = self.request(
-                "get_model_package_versions",
-                payload={"model_name": database},
-            )
-        else:
-            response = self.request(
-                "get_package_versions",
-                path_params={"db_name": database},
-            )
-        if response.status_code == 404 and response.json().get("message", "") == "database not found":
-            return None
-        if response.status_code != 200:
-            raise ResponseStatusException(
-                f"Failed to retrieve package versions for {database}.", response
-            )
-
-        content = response.json()
-        if not content:
-            return None
-
-        return safe_json_loads(content["package_versions"])
-
-    def get_database(self, database: str):
-        with debugging.span("get_database", dbname=database):
-            if not database:
-                raise ValueError("Database name must be provided to get a database.")
-            response = self.request(
-                "get_db",
-                path_params={},
-                query_params={"name": database},
-            )
-            if response.status_code != 200:
-                raise ResponseStatusException(f"Failed to get db. db:{database}", response)
-
-            response_content = response.json()
-
-            if response_content.get("databases") and len(response_content["databases"]) == 1:
-                db = response_content["databases"][0]
-                return {
-                    "id": db["id"],
-                    "name": db["name"],
-                    "created_by": db.get("created_by"),
-                    "created_on": ms_to_timestamp(db.get("created_on")),
-                    "deleted_by": db.get("deleted_by"),
-                    "deleted_on": ms_to_timestamp(db.get("deleted_on")),
-                    "state": db["state"],
-                }
-            else:
-                return None
-
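# Editor's note: created_on/deleted_on arrive as epoch milliseconds; a
# self-contained sketch of what an ms_to_timestamp-style converter does
# (the real helper is defined elsewhere in this package):

from datetime import datetime, timezone

def ms_to_utc(ms):
    # Tolerate missing values, mirroring the .get() calls above.
    if not ms:
        return None
    return datetime.fromtimestamp(ms / 1000, tz=timezone.utc)

# ms_to_utc(1700000000000) -> 2023-11-14 22:13:20+00:00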
-    def create_graph(self, name: str):
-        with debugging.span("create_model", dbname=name):
-            return self._create_database(name, "")
-
-    def delete_graph(self, name: str, force=False, language: str = "rel"):
-        prop_hdrs = debugging.gen_current_propagation_headers()
-        if self.config.get("use_graph_index", USE_GRAPH_INDEX):
-            keep_database = not force and self.config.get("reuse_model", True)
-            with debugging.span("release_index", name=name, keep_database=keep_database, language=language):
-                response = self.request(
-                    "release_index",
-                    payload={
-                        "model_name": name,
-                        "keep_database": keep_database,
-                        "language": language,
-                        "user_agent": get_pyrel_version(self.generation),
-                    },
-                    headers=prop_hdrs,
-                )
-                if (
-                    response.status_code != 200
-                    and not (
-                        response.status_code == 404
-                        and "database not found" in response.json().get("message", "")
-                    )
-                ):
-                    raise ResponseStatusException(f"Failed to release index. Model: {name}", response)
-        else:
-            with debugging.span("delete_model", name=name):
-                self._delete_database(name, headers=prop_hdrs)
-
-    def clone_graph(self, target_name: str, source_name: str, nowait_durable=True, force=False):
-        if force and self.get_graph(target_name):
-            self.delete_graph(target_name)
-        with debugging.span("clone_model", target_name=target_name, source_name=source_name):
-            return self._create_database(target_name, source_name)
-
-    def _delete_database(self, name: str, headers: Dict = {}):
-        with debugging.span("_delete_database", dbname=name):
-            response = self.request(
-                "delete_db",
-                path_params={"db_name": name},
-                query_params={},
-                headers=headers,
-            )
-            if response.status_code != 200:
-                raise ResponseStatusException(f"Failed to delete db. db:{name}", response)
-
-    def _create_database(self, name: str, source_name: str):
-        with debugging.span("_create_database", dbname=name):
-            payload = {
-                "name": name,
-                "source_name": source_name,
-            }
-            response = self.request(
-                "create_db", payload=payload, headers={}, query_params={},
-            )
-            if response.status_code != 200:
-                raise ResponseStatusException(f"Failed to create db. db:{name}", response)
-
-    #--------------------------------------------------
-    # Engines
-    #--------------------------------------------------
-
-    def list_engines(self, state: str | None = None):
-        response = self.request("list_engines")
-        if response.status_code != 200:
-            raise ResponseStatusException(
-                "Failed to retrieve engines.", response
-            )
-        response_content = response.json()
-        if not response_content:
-            return []
-        engines = [
-            {
-                "name": engine["name"],
-                "id": engine["id"],
-                "size": engine["size"],
-                "state": engine["status"],  # callers are expecting 'state'
-                "created_by": engine["created_by"],
-                "created_on": engine["created_on"],
-                "updated_on": engine["updated_on"],
-            }
-            for engine in response_content.get("engines", [])
-            if state is None or engine.get("status") == state
-        ]
-        return sorted(engines, key=lambda x: x["name"])
-
-    def get_engine(self, name: str):
-        response = self.request("get_engine", path_params={"engine_name": name, "engine_type": "logic"}, skip_auto_create=True)
-        if response.status_code == 404:  # the service returns 404 when the engine is not found
-            return None
-        elif response.status_code != 200:
-            raise ResponseStatusException(
-                f"Failed to retrieve engine {name}.", response
-            )
-        engine = response.json()
-        if not engine:
-            return None
-        engine_state: EngineState = {
-            "name": engine["name"],
-            "id": engine["id"],
-            "size": engine["size"],
-            "state": engine["status"],  # callers are expecting 'state'
-            "created_by": engine["created_by"],
-            "created_on": engine["created_on"],
-            "updated_on": engine["updated_on"],
-            "version": engine["version"],
-            "auto_suspend": engine["auto_suspend_mins"],
-            "suspends_at": engine["suspends_at"],
-        }
-        return engine_state
-
-    def _create_engine(
-        self,
-        name: str,
-        size: str | None = None,
-        auto_suspend_mins: int | None = None,
-        is_async: bool = False,
-        headers: Dict[str, str] | None = None
-    ):
-        # Only async engine creation is supported via direct access
-        if not is_async:
-            return super()._create_engine(name, size, auto_suspend_mins, is_async, headers=headers)
-        payload: Dict[str, Any] = {
-            "name": name,
-        }
-        if auto_suspend_mins is not None:
-            payload["auto_suspend_mins"] = auto_suspend_mins
-        if size is not None:
-            payload["size"] = size
-        response = self.request(
-            "create_engine",
-            payload=payload,
-            path_params={"engine_type": "logic"},
-            headers=headers,
-            skip_auto_create=True,
-        )
-        if response.status_code != 200:
-            raise ResponseStatusException(
-                f"Failed to create engine {name} with size {size}.", response
-            )
-
-    def delete_engine(self, name: str, force: bool = False, headers={}):
-        response = self.request(
-            "delete_engine",
-            path_params={"engine_name": name, "engine_type": "logic"},
-            headers=headers,
-            skip_auto_create=True,
-        )
-        if response.status_code != 200:
-            raise ResponseStatusException(
-                f"Failed to delete engine {name}.", response
-            )
-
-    def suspend_engine(self, name: str):
-        response = self.request(
-            "suspend_engine",
-            path_params={"engine_name": name, "engine_type": "logic"},
-            skip_auto_create=True,
-        )
-        if response.status_code != 200:
-            raise ResponseStatusException(
-                f"Failed to suspend engine {name}.", response
-            )
-
-    def resume_engine_async(self, name: str, headers={}):
-        response = self.request(
-            "resume_engine",
-            path_params={"engine_name": name, "engine_type": "logic"},
-            headers=headers,
-            skip_auto_create=True,
-        )
-        if response.status_code != 200:
-            raise ResponseStatusException(
-                f"Failed to resume engine {name}.", response
-            )
-        return {}
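# Editor's note: a hedged sketch of the engine lifecycle exposed above; the
# engine name and size are illustrative, and READY polling is simplified.
#
#     resources = DirectAccessResources(profile="my_profile")
#     resources._create_engine("my_engine", size="HIGHMEM_X64_S", is_async=True)
#     while (engine := resources.get_engine("my_engine")) and engine["state"] != "READY":
#         ...  # wait and re-poll; get_engine returns None on 404
#     resources.suspend_engine("my_engine")
#     resources.delete_engine("my_engine")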