apache-airflow-providers-snowflake 6.3.1__tar.gz → 6.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (71)
  1. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/PKG-INFO +6 -6
  2. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/README.rst +3 -3
  3. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/docs/changelog.rst +23 -0
  4. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/docs/connections/snowflake.rst +2 -0
  5. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/docs/index.rst +3 -3
  6. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/provider.yaml +2 -1
  7. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/pyproject.toml +3 -3
  8. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/src/airflow/providers/snowflake/__init__.py +1 -1
  9. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/src/airflow/providers/snowflake/hooks/snowflake.py +32 -7
  10. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py +10 -2
  11. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/src/airflow/providers/snowflake/operators/snowflake.py +17 -10
  12. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/src/airflow/providers/snowflake/transfers/copy_into_snowflake.py +12 -3
  13. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/src/airflow/providers/snowflake/triggers/snowflake_trigger.py +1 -4
  14. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/src/airflow/providers/snowflake/utils/openlineage.py +15 -4
  15. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/tests/unit/snowflake/decorators/test_snowpark.py +19 -10
  16. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/tests/unit/snowflake/hooks/test_snowflake.py +137 -2
  17. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/tests/unit/snowflake/hooks/test_snowflake_sql_api.py +40 -0
  18. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/tests/unit/snowflake/operators/test_snowflake.py +42 -0
  19. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/tests/unit/snowflake/transfers/test_copy_into_snowflake.py +23 -0
  20. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/tests/unit/snowflake/triggers/test_snowflake.py +4 -3
  21. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/tests/unit/snowflake/utils/test_openlineage.py +3 -0
  22. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/docs/.latest-doc-only-change.txt +0 -0
  23. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/docs/commits.rst +0 -0
  24. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/docs/conf.py +0 -0
  25. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/docs/decorators/index.rst +0 -0
  26. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/docs/decorators/snowpark.rst +0 -0
  27. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/docs/installing-providers-from-sources.rst +0 -0
  28. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/docs/integration-logos/Snowflake.png +0 -0
  29. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/docs/operators/copy_into_snowflake.rst +0 -0
  30. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/docs/operators/index.rst +0 -0
  31. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/docs/operators/snowflake.rst +0 -0
  32. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/docs/operators/snowpark.rst +0 -0
  33. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/docs/security.rst +0 -0
  34. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/src/airflow/__init__.py +0 -0
  35. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/src/airflow/providers/__init__.py +0 -0
  36. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/src/airflow/providers/snowflake/LICENSE +0 -0
  37. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/src/airflow/providers/snowflake/decorators/__init__.py +0 -0
  38. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/src/airflow/providers/snowflake/decorators/snowpark.py +0 -0
  39. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/src/airflow/providers/snowflake/get_provider_info.py +0 -0
  40. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/src/airflow/providers/snowflake/hooks/__init__.py +0 -0
  41. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/src/airflow/providers/snowflake/operators/__init__.py +0 -0
  42. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/src/airflow/providers/snowflake/operators/snowpark.py +0 -0
  43. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/src/airflow/providers/snowflake/transfers/__init__.py +0 -0
  44. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/src/airflow/providers/snowflake/triggers/__init__.py +0 -0
  45. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/src/airflow/providers/snowflake/utils/__init__.py +0 -0
  46. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/src/airflow/providers/snowflake/utils/common.py +0 -0
  47. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/src/airflow/providers/snowflake/utils/snowpark.py +0 -0
  48. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/src/airflow/providers/snowflake/utils/sql_api_generate_jwt.py +0 -0
  49. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/src/airflow/providers/snowflake/version_compat.py +0 -0
  50. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/tests/conftest.py +0 -0
  51. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/tests/system/__init__.py +0 -0
  52. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/tests/system/snowflake/__init__.py +0 -0
  53. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/tests/system/snowflake/example_copy_into_snowflake.py +0 -0
  54. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/tests/system/snowflake/example_snowflake.py +0 -0
  55. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/tests/system/snowflake/example_snowflake_snowflake_op_template_file.sql +0 -0
  56. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/tests/system/snowflake/example_snowpark_decorator.py +0 -0
  57. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/tests/system/snowflake/example_snowpark_operator.py +0 -0
  58. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/tests/unit/__init__.py +0 -0
  59. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/tests/unit/snowflake/__init__.py +0 -0
  60. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/tests/unit/snowflake/decorators/__init__.py +0 -0
  61. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/tests/unit/snowflake/hooks/__init__.py +0 -0
  62. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/tests/unit/snowflake/hooks/test_sql.py +0 -0
  63. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/tests/unit/snowflake/operators/__init__.py +0 -0
  64. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/tests/unit/snowflake/operators/test_snowflake_sql.py +0 -0
  65. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/tests/unit/snowflake/operators/test_snowpark.py +0 -0
  66. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/tests/unit/snowflake/transfers/__init__.py +0 -0
  67. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/tests/unit/snowflake/triggers/__init__.py +0 -0
  68. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/tests/unit/snowflake/utils/__init__.py +0 -0
  69. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/tests/unit/snowflake/utils/test_common.py +0 -0
  70. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/tests/unit/snowflake/utils/test_snowpark.py +0 -0
  71. {apache_airflow_providers_snowflake-6.3.1 → apache_airflow_providers_snowflake-6.4.0}/tests/unit/snowflake/utils/test_sql_api_generate_jwt.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: apache-airflow-providers-snowflake
- Version: 6.3.1
+ Version: 6.4.0
  Summary: Provider package apache-airflow-providers-snowflake for Apache Airflow
  Keywords: airflow-provider,snowflake,airflow,integration
  Author-email: Apache Software Foundation <dev@airflow.apache.org>
@@ -30,8 +30,8 @@ Requires-Dist: snowflake-sqlalchemy>=1.4.0
  Requires-Dist: snowflake-snowpark-python>=1.17.0;python_version<'3.12'
  Requires-Dist: apache-airflow-providers-openlineage>=2.3.0 ; extra == "openlineage"
  Project-URL: Bug Tracker, https://github.com/apache/airflow/issues
- Project-URL: Changelog, https://airflow.apache.org/docs/apache-airflow-providers-snowflake/6.3.1/changelog.html
- Project-URL: Documentation, https://airflow.apache.org/docs/apache-airflow-providers-snowflake/6.3.1
+ Project-URL: Changelog, https://airflow.apache.org/docs/apache-airflow-providers-snowflake/6.4.0/changelog.html
+ Project-URL: Documentation, https://airflow.apache.org/docs/apache-airflow-providers-snowflake/6.4.0
  Project-URL: Mastodon, https://fosstodon.org/@airflow
  Project-URL: Slack Chat, https://s.apache.org/airflow-slack
  Project-URL: Source Code, https://github.com/apache/airflow
@@ -63,7 +63,7 @@ Provides-Extra: openlineage

  Package ``apache-airflow-providers-snowflake``

- Release: ``6.3.1``
+ Release: ``6.4.0``


  `Snowflake <https://www.snowflake.com/>`__
@@ -76,7 +76,7 @@ This is a provider package for ``snowflake`` provider. All classes for this prov
  are in ``airflow.providers.snowflake`` python package.

  You can find package information and changelog for the provider
- in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-snowflake/6.3.1/>`_.
+ in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-snowflake/6.4.0/>`_.

  Installation
  ------------
@@ -125,5 +125,5 @@ Dependent package
  ================================================================================================================== =================

  The changelog for the provider package can be found in the
- `changelog <https://airflow.apache.org/docs/apache-airflow-providers-snowflake/6.3.1/changelog.html>`_.
+ `changelog <https://airflow.apache.org/docs/apache-airflow-providers-snowflake/6.4.0/changelog.html>`_.

@@ -23,7 +23,7 @@

  Package ``apache-airflow-providers-snowflake``

- Release: ``6.3.1``
+ Release: ``6.4.0``


  `Snowflake <https://www.snowflake.com/>`__
@@ -36,7 +36,7 @@ This is a provider package for ``snowflake`` provider. All classes for this prov
  are in ``airflow.providers.snowflake`` python package.

  You can find package information and changelog for the provider
- in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-snowflake/6.3.1/>`_.
+ in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-snowflake/6.4.0/>`_.

  Installation
  ------------
@@ -85,4 +85,4 @@ Dependent package
  ================================================================================================================== =================

  The changelog for the provider package can be found in the
- `changelog <https://airflow.apache.org/docs/apache-airflow-providers-snowflake/6.3.1/changelog.html>`_.
+ `changelog <https://airflow.apache.org/docs/apache-airflow-providers-snowflake/6.4.0/changelog.html>`_.
@@ -27,6 +27,29 @@
  Changelog
  ---------

+ 6.4.0
+ .....
+
+ Features
+ ~~~~~~~~
+
+ * ``Extend SnowflakeHook OAuth implementation to support external IDPs and client_credentials grant (#51620)``
+
+ Bug Fixes
+ ~~~~~~~~~
+
+ * ``fix: make query_ids in SnowflakeSqlApiOperator in deferrable mode consistent (#51542)``
+ * ``fix: Duplicate region in Snowflake URI no longer breaks OpenLineage (#50831)``
+ * ``Do not allow semicolons in CopyFromExternalStageToSnowflakeOperator fieldS (#51734)``
+
+ Misc
+ ~~~~
+
+ * ``Port ''ti.run'' to Task SDK execution path (#50141)``
+
+ .. Below changes are excluded from the changelog. Move them to
+    appropriate section above if needed. Do not delete the lines(!):
+
  6.3.1
  .....

@@ -58,6 +58,8 @@ Extra (optional)
  * ``warehouse``: Snowflake warehouse name.
  * ``role``: Snowflake role.
  * ``authenticator``: To connect using OAuth set this parameter ``oauth``.
+ * ``token_endpoint``: Specify token endpoint for external OAuth provider.
+ * ``grant_type``: Specify grant type for OAuth authentication. Currently supported: ``refresh_token`` (default), ``client_credentials``.
  * ``refresh_token``: Specify refresh_token for OAuth connection.
  * ``private_key_file``: Specify the path to the private key file.
  * ``private_key_content``: Specify the content of the private key file in base64 encoded format. You can use the following Python code to encode the private key:
@@ -78,7 +78,7 @@ apache-airflow-providers-snowflake package
  `Snowflake <https://www.snowflake.com/>`__


- Release: 6.3.1
+ Release: 6.4.0

  Provider package
  ----------------
@@ -138,5 +138,5 @@ Downloading official packages
  You can download officially released packages and verify their checksums and signatures from the
  `Official Apache Download site <https://downloads.apache.org/airflow/providers/>`_

- * `The apache-airflow-providers-snowflake 6.3.1 sdist package <https://downloads.apache.org/airflow/providers/apache_airflow_providers_snowflake-6.3.1.tar.gz>`_ (`asc <https://downloads.apache.org/airflow/providers/apache_airflow_providers_snowflake-6.3.1.tar.gz.asc>`__, `sha512 <https://downloads.apache.org/airflow/providers/apache_airflow_providers_snowflake-6.3.1.tar.gz.sha512>`__)
- * `The apache-airflow-providers-snowflake 6.3.1 wheel package <https://downloads.apache.org/airflow/providers/apache_airflow_providers_snowflake-6.3.1-py3-none-any.whl>`_ (`asc <https://downloads.apache.org/airflow/providers/apache_airflow_providers_snowflake-6.3.1-py3-none-any.whl.asc>`__, `sha512 <https://downloads.apache.org/airflow/providers/apache_airflow_providers_snowflake-6.3.1-py3-none-any.whl.sha512>`__)
+ * `The apache-airflow-providers-snowflake 6.4.0 sdist package <https://downloads.apache.org/airflow/providers/apache_airflow_providers_snowflake-6.4.0.tar.gz>`_ (`asc <https://downloads.apache.org/airflow/providers/apache_airflow_providers_snowflake-6.4.0.tar.gz.asc>`__, `sha512 <https://downloads.apache.org/airflow/providers/apache_airflow_providers_snowflake-6.4.0.tar.gz.sha512>`__)
+ * `The apache-airflow-providers-snowflake 6.4.0 wheel package <https://downloads.apache.org/airflow/providers/apache_airflow_providers_snowflake-6.4.0-py3-none-any.whl>`_ (`asc <https://downloads.apache.org/airflow/providers/apache_airflow_providers_snowflake-6.4.0-py3-none-any.whl.asc>`__, `sha512 <https://downloads.apache.org/airflow/providers/apache_airflow_providers_snowflake-6.4.0-py3-none-any.whl.sha512>`__)
@@ -22,12 +22,13 @@ description: |
  `Snowflake <https://www.snowflake.com/>`__

  state: ready
- source-date-epoch: 1747133792
+ source-date-epoch: 1749896974
  # Note that those versions are maintained by release manager - do not update them manually
  # with the exception of case where other provider in sources has >= new provider version.
  # In such case adding >= NEW_VERSION and bumping to NEW_VERSION in a provider have
  # to be done in the same PR
  versions:
+ - 6.4.0
  - 6.3.1
  - 6.3.0
  - 6.2.2
@@ -25,7 +25,7 @@ build-backend = "flit_core.buildapi"

  [project]
  name = "apache-airflow-providers-snowflake"
- version = "6.3.1"
+ version = "6.4.0"
  description = "Provider package apache-airflow-providers-snowflake for Apache Airflow"
  readme = "README.rst"
  authors = [
@@ -116,8 +116,8 @@ apache-airflow-providers-common-sql = {workspace = true}
  apache-airflow-providers-standard = {workspace = true}

  [project.urls]
- "Documentation" = "https://airflow.apache.org/docs/apache-airflow-providers-snowflake/6.3.1"
- "Changelog" = "https://airflow.apache.org/docs/apache-airflow-providers-snowflake/6.3.1/changelog.html"
+ "Documentation" = "https://airflow.apache.org/docs/apache-airflow-providers-snowflake/6.4.0"
+ "Changelog" = "https://airflow.apache.org/docs/apache-airflow-providers-snowflake/6.4.0/changelog.html"
  "Bug Tracker" = "https://github.com/apache/airflow/issues"
  "Source Code" = "https://github.com/apache/airflow"
  "Slack Chat" = "https://s.apache.org/airflow-slack"
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version

  __all__ = ["__version__"]

- __version__ = "6.3.1"
+ __version__ = "6.4.0"

  if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
      "2.10.0"
@@ -136,6 +136,9 @@ class SnowflakeHook(DbApiHook):
  "session_parameters": "session parameters",
  "client_request_mfa_token": "client request mfa token",
  "client_store_temporary_credential": "client store temporary credential (externalbrowser mode)",
+ "grant_type": "refresh_token client_credentials",
+ "token_endpoint": "token endpoint",
+ "refresh_token": "refresh token",
  },
  indent=1,
  ),
@@ -200,18 +203,32 @@ class SnowflakeHook(DbApiHook):

  return account_identifier

- def get_oauth_token(self, conn_config: dict | None = None) -> str:
+ def get_oauth_token(
+ self,
+ conn_config: dict | None = None,
+ token_endpoint: str | None = None,
+ grant_type: str = "refresh_token",
+ ) -> str:
  """Generate temporary OAuth access token using refresh token in connection details."""
  if conn_config is None:
  conn_config = self._get_conn_params

- url = f"https://{conn_config['account']}.snowflakecomputing.com/oauth/token-request"
+ url = token_endpoint or f"https://{conn_config['account']}.snowflakecomputing.com/oauth/token-request"

  data = {
- "grant_type": "refresh_token",
- "refresh_token": conn_config["refresh_token"],
+ "grant_type": grant_type,
  "redirect_uri": conn_config.get("redirect_uri", "https://localhost.com"),
  }
+
+ if grant_type == "refresh_token":
+ data |= {
+ "refresh_token": conn_config["refresh_token"],
+ }
+ elif grant_type == "client_credentials":
+ pass # no setup necessary for client credentials grant.
+ else:
+ raise ValueError(f"Unknown grant_type: {grant_type}")
+
  response = requests.post(
  url,
  data=data,
@@ -226,7 +243,8 @@ class SnowflakeHook(DbApiHook):
  except requests.exceptions.HTTPError as e: # pragma: no cover
  msg = f"Response: {e.response.content.decode()} Status Code: {e.response.status_code}"
  raise AirflowException(msg)
- return response.json()["access_token"]
+ token = response.json()["access_token"]
+ return token

  @cached_property
  def _get_conn_params(self) -> dict[str, str | None]:
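For orientation, a minimal usage sketch of the extended method above; the connection id and endpoint URL are placeholders for illustration, not values taken from this release:

    from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook

    hook = SnowflakeHook(snowflake_conn_id="snowflake_oauth")  # hypothetical connection id

    # Default behaviour is unchanged: refresh_token grant against Snowflake's own token endpoint.
    token = hook.get_oauth_token()

    # New in 6.4.0: point at an external IdP and use the client_credentials grant.
    token = hook.get_oauth_token(
        token_endpoint="https://idp.example.com/oauth/token",  # placeholder IdP URL
        grant_type="client_credentials",
    )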
@@ -329,14 +347,21 @@ class SnowflakeHook(DbApiHook):
  if refresh_token:
  conn_config["refresh_token"] = refresh_token
  conn_config["authenticator"] = "oauth"
+
+ if conn_config.get("authenticator") == "oauth":
+ token_endpoint = self._get_field(extra_dict, "token_endpoint") or ""
  conn_config["client_id"] = conn.login
  conn_config["client_secret"] = conn.password
+ conn_config["token"] = self.get_oauth_token(
+ conn_config=conn_config,
+ token_endpoint=token_endpoint,
+ grant_type=extra_dict.get("grant_type", "refresh_token"),
+ )
+
  conn_config.pop("login", None)
  conn_config.pop("user", None)
  conn_config.pop("password", None)

- conn_config["token"] = self.get_oauth_token(conn_config=conn_config)
-
  # configure custom target hostname and port, if specified
  snowflake_host = extra_dict.get("host")
  snowflake_port = extra_dict.get("port")
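A sketch of a connection that would exercise this new code path, assuming a connection named snowflake_oauth defined via an environment variable; every value below is a placeholder:

    import json
    import os

    os.environ["AIRFLOW_CONN_SNOWFLAKE_OAUTH"] = json.dumps(
        {
            "conn_type": "snowflake",
            "login": "my-client-id",         # becomes client_id for the OAuth request
            "password": "my-client-secret",  # becomes client_secret for the OAuth request
            "extra": {
                "account": "my_account",
                "warehouse": "my_wh",
                "database": "my_db",
                "role": "my_role",
                "authenticator": "oauth",
                "token_endpoint": "https://idp.example.com/oauth/token",  # external IdP, placeholder
                "grant_type": "client_credentials",
            },
        }
    )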
@@ -137,6 +137,7 @@ class SnowflakeSqlApiHook(SnowflakeHook):
  When executing the statement, Snowflake replaces placeholders (? and :name) in
  the statement with these specified values.
  """
+ self.query_ids = []
  conn_config = self._get_conn_params

  req_id = uuid.uuid4()
@@ -222,14 +223,21 @@ class SnowflakeSqlApiHook(SnowflakeHook):
  }
  return headers

- def get_oauth_token(self, conn_config: dict[str, Any] | None = None) -> str:
+ def get_oauth_token(
+ self,
+ conn_config: dict[str, Any] | None = None,
+ token_endpoint: str | None = None,
+ grant_type: str = "refresh_token",
+ ) -> str:
  """Generate temporary OAuth access token using refresh token in connection details."""
  warnings.warn(
  "This method is deprecated. Please use `get_oauth_token` method from `SnowflakeHook` instead. ",
  AirflowProviderDeprecationWarning,
  stacklevel=2,
  )
- return super().get_oauth_token(conn_config=conn_config)
+ return super().get_oauth_token(
+ conn_config=conn_config, token_endpoint=token_endpoint, grant_type=grant_type
+ )

  def get_request_url_header_params(self, query_id: str) -> tuple[dict[str, Any], dict[str, Any], str]:
  """
@@ -20,6 +20,7 @@ from __future__ import annotations
  import time
  from collections.abc import Iterable, Mapping, Sequence
  from datetime import timedelta
+ from functools import cached_property
  from typing import TYPE_CHECKING, Any, SupportsAbs, cast

  from airflow.configuration import conf
@@ -390,6 +391,7 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
  self.bindings = bindings
  self.execute_async = False
  self.deferrable = deferrable
+ self.query_ids: list[str] = []
  if any([warehouse, database, role, schema, authenticator, session_parameters]): # pragma: no cover
  hook_params = kwargs.pop("hook_params", {}) # pragma: no cover
  kwargs["hook_params"] = {
@@ -403,6 +405,16 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
  }
  super().__init__(conn_id=snowflake_conn_id, **kwargs) # pragma: no cover

+ @cached_property
+ def _hook(self):
+ return SnowflakeSqlApiHook(
+ snowflake_conn_id=self.snowflake_conn_id,
+ token_life_time=self.token_life_time,
+ token_renewal_delta=self.token_renewal_delta,
+ deferrable=self.deferrable,
+ **self.hook_params,
+ )
+
  def execute(self, context: Context) -> None:
  """
  Make a POST API request to snowflake by using SnowflakeSQL and execute the query to get the ids.
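The hook is now built lazily in a cached_property instead of inside execute, so execute and execute_complete share one instance. A small sketch of the resulting behaviour, mirroring the unit test added further down; task id and connection id are placeholders:

    from airflow.providers.snowflake.operators.snowflake import SnowflakeSqlApiOperator

    op = SnowflakeSqlApiOperator(
        task_id="run_sql",                      # placeholder task id
        snowflake_conn_id="snowflake_default",  # placeholder connection id
        sql="select 1",
        statement_count=1,
        deferrable=True,
    )
    assert op._hook is op._hook  # cached_property returns the same SnowflakeSqlApiHook instance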
@@ -410,13 +422,6 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
  By deferring the SnowflakeSqlApiTrigger class passed along with query ids.
  """
  self.log.info("Executing: %s", self.sql)
- self._hook = SnowflakeSqlApiHook(
- snowflake_conn_id=self.snowflake_conn_id,
- token_life_time=self.token_life_time,
- token_renewal_delta=self.token_renewal_delta,
- deferrable=self.deferrable,
- **self.hook_params,
- )
  self.query_ids = self._hook.execute_query(
  self.sql, # type: ignore[arg-type]
  statement_count=self.statement_count,
@@ -504,9 +509,11 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
  msg = f"{event['status']}: {event['message']}"
  raise AirflowException(msg)
  if "status" in event and event["status"] == "success":
- hook = SnowflakeSqlApiHook(snowflake_conn_id=self.snowflake_conn_id)
- query_ids = cast("list[str]", event["statement_query_ids"])
- hook.check_query_output(query_ids)
+ self.query_ids = cast("list[str]", event["statement_query_ids"])
+ self._hook.check_query_output(self.query_ids)
  self.log.info("%s completed successfully.", self.task_id)
+ # Re-assign query_ids to hook after coming back from deferral to be consistent for listeners.
+ if not self._hook.query_ids:
+ self._hook.query_ids = self.query_ids
  else:
  self.log.info("%s completed successfully.", self.task_id)
@@ -27,6 +27,15 @@ from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
  from airflow.providers.snowflake.utils.common import enclose_param


+ def _validate_parameter(param_name: str, value: str | None) -> str | None:
+ """Validate that the parameter doesn't contain any invalid pattern."""
+ if value is None:
+ return None
+ if ";" in value:
+ raise ValueError(f"Invalid {param_name}: semicolons (;) not allowed.")
+ return value
+
+
  class CopyFromExternalStageToSnowflakeOperator(BaseOperator):
  """
  Executes a COPY INTO command to load files from an external stage from clouds to Snowflake.
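With the helper above wired into the constructor (see the following hunk), a table or stage containing a semicolon is rejected at operator construction time. A minimal sketch that mirrors the new unit test; argument values are placeholders:

    from airflow.providers.snowflake.transfers.copy_into_snowflake import (
        CopyFromExternalStageToSnowflakeOperator,
    )

    # Raises ValueError: "Invalid table: semicolons (;) not allowed."
    CopyFromExternalStageToSnowflakeOperator(
        task_id="copy_into",                 # placeholder task id
        table="my_table; drop table users",  # rejected: contains ';'
        stage="my_stage",
        database="",
        schema="my_schema",
        file_format="CSV",
    )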
@@ -91,8 +100,8 @@ class CopyFromExternalStageToSnowflakeOperator(BaseOperator):
  ):
  super().__init__(**kwargs)
  self.files = files
- self.table = table
- self.stage = stage
+ self.table = _validate_parameter("table", table)
+ self.stage = _validate_parameter("stage", stage)
  self.prefix = prefix
  self.file_format = file_format
  self.schema = schema
@@ -126,7 +135,7 @@ class CopyFromExternalStageToSnowflakeOperator(BaseOperator):
  if self.schema:
  into = f"{self.schema}.{self.table}"
  else:
- into = self.table
+ into = self.table # type: ignore[assignment]

  if self.columns_array:
  into = f"{into}({', '.join(self.columns_array)})"
@@ -74,7 +74,6 @@ class SnowflakeSqlApiTrigger(BaseTrigger):
  self.token_renewal_delta,
  )
  try:
- statement_query_ids: list[str] = []
  for query_id in self.query_ids:
  while True:
  statement_status = await self.get_query_status(query_id)
@@ -84,12 +83,10 @@ class SnowflakeSqlApiTrigger(BaseTrigger):
  if statement_status["status"] == "error":
  yield TriggerEvent(statement_status)
  return
- if statement_status["status"] == "success":
- statement_query_ids.extend(statement_status["statement_handles"])
  yield TriggerEvent(
  {
  "status": "success",
- "statement_query_ids": statement_query_ids,
+ "statement_query_ids": self.query_ids,
  }
  )
  except Exception as e:
@@ -52,7 +52,15 @@ def fix_account_name(name: str) -> str:
  account, region = spl
  cloud = "aws"
  else:
- account, region, cloud = spl
+ # region can easily get duplicated without crashing snowflake, so we need to handle that as well
+ # eg. account_locator.europe-west3.gcp.europe-west3.gcp will be ok for snowflake
+ account, region, cloud, *rest = spl
+ rest = [x for x in rest if x not in (region, cloud)]
+ if rest: # Not sure what could be left here, but leaving this just in case
+ log.warning(
+ "Unexpected parts found in Snowflake uri hostname and will be ignored by OpenLineage: %s",
+ rest,
+ )
  return f"{account}.{region}.{cloud}"

  # Check for existing accounts with cloud names
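Based on the change above and the new test cases further down, duplicated region/cloud segments are now collapsed instead of raising an unpacking error; roughly:

    from airflow.providers.snowflake.utils.openlineage import fix_account_name

    fix_account_name("xy12345.us-west-2.gcp")                # -> "xy12345.us-west-2.gcp" (unchanged)
    fix_account_name("xy12345.us-west-2.gcp.us-west-2.gcp")  # -> "xy12345.us-west-2.gcp" (duplicate dropped)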
@@ -72,13 +80,16 @@ def fix_snowflake_sqlalchemy_uri(uri: str) -> str:
  """
  Fix snowflake sqlalchemy connection URI to OpenLineage structure.

- Snowflake sqlalchemy connection URI has following structure:
+ Snowflake sqlalchemy connection URI has the following structure:
  'snowflake://<user_login_name>:<password>@<account_identifier>/<database_name>/<schema_name>?warehouse=<warehouse_name>&role=<role_name>'
  We want account identifier normalized. It can have two forms:
- - newer, in form of <organization>-<id>. In this case we want to do nothing.
- - older, composed of <id>-<region>-<cloud> where region and cloud can be
+ - newer, in form of <organization_id>-<account_id>. In this case we want to do nothing.
+ - older, composed of <account_locator>.<region>.<cloud> where region and cloud can be
  optional in some cases. If <cloud> is omitted, it's AWS.
  If region and cloud are omitted, it's AWS us-west-1
+
+ Current doc on Snowflake account identifiers:
+ https://docs.snowflake.com/en/user-guide/admin-account-identifier
  """
  try:
  parts = urlparse(uri)
@@ -24,6 +24,7 @@ from unittest import mock
  import pytest

  from airflow.decorators import task
+ from airflow.providers.snowflake.version_compat import AIRFLOW_V_3_0_PLUS
  from airflow.utils import timezone

  if TYPE_CHECKING:
@@ -156,7 +157,7 @@ class TestSnowparkDecorator:
  mock_snowflake_hook.return_value.get_snowpark_session.assert_called_once()

  @mock.patch("airflow.providers.snowflake.operators.snowpark.SnowflakeHook")
- def test_snowpark_decorator_multiple_output(self, mock_snowflake_hook, dag_maker):
+ def test_snowpark_decorator_multiple_output(self, mock_snowflake_hook, dag_maker, request):
  @task.snowpark(
  task_id=TASK_ID,
  snowflake_conn_id=CONN_ID,
@@ -171,15 +172,23 @@ class TestSnowparkDecorator:
  assert session == mock_snowflake_hook.return_value.get_snowpark_session.return_value
  return {"a": 1, "b": "2"}

- with dag_maker(dag_id=TEST_DAG_ID):
- ret = func()
-
- dr = dag_maker.create_dagrun()
- ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
- ti = dr.get_task_instances()[0]
- assert ti.xcom_pull(key="a") == 1
- assert ti.xcom_pull(key="b") == "2"
- assert ti.xcom_pull() == {"a": 1, "b": "2"}
+ if AIRFLOW_V_3_0_PLUS:
+ run_task = request.getfixturevalue("run_task")
+ op = func().operator
+ run_task(task=op)
+ assert run_task.xcom.get(key="a") == 1
+ assert run_task.xcom.get(key="b") == "2"
+ assert run_task.xcom.get(key="return_value") == {"a": 1, "b": "2"}
+ else:
+ with dag_maker(dag_id=TEST_DAG_ID):
+ ret = func()
+
+ dr = dag_maker.create_dagrun()
+ ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ ti = dr.get_task_instances()[0]
+ assert ti.xcom_pull(key="a") == 1
+ assert ti.xcom_pull(key="b") == "2"
+ assert ti.xcom_pull() == {"a": 1, "b": "2"}
  mock_snowflake_hook.assert_called_once()
  mock_snowflake_hook.return_value.get_snowpark_session.assert_called_once()

@@ -53,14 +53,13 @@ BASE_CONNECTION_KWARGS: dict = {
  },
  }

- CONN_PARAMS_OAUTH = {
+ CONN_PARAMS_OAUTH_BASE = {
  "account": "airflow",
  "application": "AIRFLOW",
  "authenticator": "oauth",
  "database": "db",
  "client_id": "test_client_id",
  "client_secret": "test_client_pw",
- "refresh_token": "secrettoken",
  "region": "af_region",
  "role": "af_role",
  "schema": "public",
@@ -68,6 +67,8 @@ CONN_PARAMS_OAUTH = {
  "warehouse": "af_wh",
  }

+ CONN_PARAMS_OAUTH = CONN_PARAMS_OAUTH_BASE | {"refresh_token": "secrettoken"}
+

  @pytest.fixture
  def unencrypted_temporary_private_key(tmp_path: Path) -> Path:
@@ -559,6 +560,112 @@ class TestPytestSnowflakeHook:
  assert "region" in conn_params_extra_keys
  assert "account" in conn_params_extra_keys

+ @mock.patch("requests.post")
+ @mock.patch(
+ "airflow.providers.snowflake.hooks.snowflake.SnowflakeHook._get_conn_params",
+ new_callable=PropertyMock,
+ )
+ def test_get_conn_params_should_support_oauth_with_token_endpoint(
+ self, mock_get_conn_params, requests_post
+ ):
+ requests_post.return_value = Mock(
+ status_code=200,
+ json=lambda: {
+ "access_token": "supersecretaccesstoken",
+ "expires_in": 600,
+ "refresh_token": "secrettoken",
+ "token_type": "Bearer",
+ "username": "test_user",
+ },
+ )
+ connection_kwargs = {
+ **BASE_CONNECTION_KWARGS,
+ "login": "test_client_id",
+ "password": "test_client_secret",
+ "extra": {
+ "database": "db",
+ "account": "airflow",
+ "warehouse": "af_wh",
+ "region": "af_region",
+ "role": "af_role",
+ "refresh_token": "secrettoken",
+ "authenticator": "oauth",
+ "token_endpoint": "https://www.example.com/oauth/token",
+ },
+ }
+ mock_get_conn_params.return_value = connection_kwargs
+ with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
+ hook = SnowflakeHook(snowflake_conn_id="test_conn")
+ conn_params = hook._get_conn_params
+
+ conn_params_keys = conn_params.keys()
+ conn_params_extra = conn_params.get("extra", {})
+ conn_params_extra_keys = conn_params_extra.keys()
+
+ assert "authenticator" in conn_params_extra_keys
+ assert conn_params_extra["authenticator"] == "oauth"
+ assert conn_params_extra["token_endpoint"] == "https://www.example.com/oauth/token"
+
+ assert "user" not in conn_params_keys
+ assert "password" in conn_params_keys
+ assert "refresh_token" in conn_params_extra_keys
+ # Mandatory fields to generate account_identifier `https://<account>.<region>`
+ assert "region" in conn_params_extra_keys
+ assert "account" in conn_params_extra_keys
+
+ @mock.patch("requests.post")
+ @mock.patch(
+ "airflow.providers.snowflake.hooks.snowflake.SnowflakeHook._get_conn_params",
+ new_callable=PropertyMock,
+ )
+ def test_get_conn_params_should_support_oauth_with_client_credentials(
+ self, mock_get_conn_params, requests_post
+ ):
+ requests_post.return_value = Mock(
+ status_code=200,
+ json=lambda: {
+ "access_token": "supersecretaccesstoken",
+ "expires_in": 600,
+ "refresh_token": "secrettoken",
+ "token_type": "Bearer",
+ "username": "test_user",
+ },
+ )
+ connection_kwargs = {
+ **BASE_CONNECTION_KWARGS,
+ "login": "test_client_id",
+ "password": "test_client_secret",
+ "extra": {
+ "database": "db",
+ "account": "airflow",
+ "warehouse": "af_wh",
+ "region": "af_region",
+ "role": "af_role",
+ "authenticator": "oauth",
+ "token_endpoint": "https://www.example.com/oauth/token",
+ "grant_type": "client_credentials",
+ },
+ }
+ mock_get_conn_params.return_value = connection_kwargs
+ with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
+ hook = SnowflakeHook(snowflake_conn_id="test_conn")
+ conn_params = hook._get_conn_params
+
+ conn_params_keys = conn_params.keys()
+ conn_params_extra = conn_params.get("extra", {})
+ conn_params_extra_keys = conn_params_extra.keys()
+
+ assert "authenticator" in conn_params_extra_keys
+ assert conn_params_extra["authenticator"] == "oauth"
+ assert conn_params_extra["grant_type"] == "client_credentials"
+
+ assert "user" not in conn_params_keys
+ assert "password" in conn_params_keys
+ assert "refresh_token" not in conn_params_extra_keys
+ # Mandatory fields to generate account_identifier `https://<account>.<region>`
+ assert "region" in conn_params_extra_keys
+ assert "account" in conn_params_extra_keys
+
  def test_should_add_partner_info(self):
  with mock.patch.dict(
  "os.environ",
@@ -917,3 +1024,31 @@ class TestPytestSnowflakeHook:
  headers={"Content-Type": "application/x-www-form-urlencoded"},
  auth=basic_auth,
  )
+
+ @mock.patch("airflow.providers.snowflake.hooks.snowflake.HTTPBasicAuth")
+ @mock.patch("requests.post")
+ @mock.patch(
+ "airflow.providers.snowflake.hooks.snowflake.SnowflakeHook._get_conn_params",
+ new_callable=PropertyMock,
+ )
+ def test_get_oauth_token_with_token_endpoint(self, mock_conn_param, requests_post, mock_auth):
+ """Test get_oauth_token method makes the right http request"""
+ basic_auth = {"Authorization": "Basic usernamepassword"}
+ token_endpoint = "https://example.com/oauth/token"
+ mock_conn_param.return_value = CONN_PARAMS_OAUTH
+ requests_post.return_value.status_code = 200
+ mock_auth.return_value = basic_auth
+
+ hook = SnowflakeHook(snowflake_conn_id="mock_conn_id")
+ hook.get_oauth_token(conn_config=CONN_PARAMS_OAUTH, token_endpoint=token_endpoint)
+
+ requests_post.assert_called_once_with(
+ token_endpoint,
+ data={
+ "grant_type": "refresh_token",
+ "refresh_token": CONN_PARAMS_OAUTH["refresh_token"],
+ "redirect_uri": "https://localhost.com",
+ },
+ headers={"Content-Type": "application/x-www-form-urlencoded"},
+ auth=basic_auth,
+ )
@@ -203,6 +203,46 @@ class TestSnowflakeSqlApiHook:
  query_ids = hook.execute_query(sql, statement_count)
  assert query_ids == expected_query_ids

+ @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests")
+ @mock.patch(
+ "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params",
+ new_callable=PropertyMock,
+ )
+ @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_headers")
+ def test_execute_query_multiple_times_give_fresh_query_ids_each_time(
+ self, mock_get_header, mock_conn_param, mock_requests
+ ):
+ """Test execute_query method, run query by mocking post request method and return the query ids"""
+ sql, statement_count, expected_response, expected_query_ids = (
+ SQL_MULTIPLE_STMTS,
+ 4,
+ {"statementHandles": ["uuid2", "uuid3"]},
+ ["uuid2", "uuid3"],
+ )
+
+ mock_requests.codes.ok = 200
+ mock_requests.post.side_effect = [
+ create_successful_response_mock(expected_response),
+ ]
+ status_code_mock = mock.PropertyMock(return_value=200)
+ type(mock_requests.post.return_value).status_code = status_code_mock
+
+ hook = SnowflakeSqlApiHook("mock_conn_id")
+ query_ids = hook.execute_query(sql, statement_count)
+ assert query_ids == expected_query_ids
+
+ sql, statement_count, expected_response, expected_query_ids = (
+ SINGLE_STMT,
+ 1,
+ {"statementHandle": "uuid"},
+ ["uuid"],
+ )
+ mock_requests.post.side_effect = [
+ create_successful_response_mock(expected_response),
+ ]
+ query_ids = hook.execute_query(sql, statement_count)
+ assert query_ids == expected_query_ids
+
  @pytest.mark.parametrize(
  "sql,statement_count,expected_response, expected_query_ids",
  [(SINGLE_STMT, 1, {"statementHandle": "uuid"}, ["uuid"])],
@@ -332,6 +332,48 @@ class TestSnowflakeSqlApiOperator:
  operator.execute_complete(context=None, event=mock_event)
  mock_log_info.assert_called_with("%s completed successfully.", TASK_ID)

+ @pytest.mark.parametrize(
+ "mock_event",
+ [
+ None,
+ ({"status": "success", "statement_query_ids": ["uuid", "uuid"]}),
+ ],
+ )
+ @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.check_query_output")
+ def test_snowflake_sql_api_execute_complete_reassigns_query_ids(self, mock_conn, mock_event):
+ """Tests execute_complete assert with successful message"""
+
+ operator = SnowflakeSqlApiOperator(
+ task_id=TASK_ID,
+ snowflake_conn_id=CONN_ID,
+ sql=SQL_MULTIPLE_STMTS,
+ statement_count=4,
+ deferrable=True,
+ )
+ expected_query_ids = mock_event["statement_query_ids"] if mock_event else []
+
+ assert operator.query_ids == []
+ assert operator._hook.query_ids == []
+
+ operator.execute_complete(context=None, event=mock_event)
+
+ assert operator.query_ids == expected_query_ids
+ assert operator._hook.query_ids == expected_query_ids
+
+ def test_snowflake_sql_api_caches_hook(self):
+ """Tests execute_complete assert with successful message"""
+
+ operator = SnowflakeSqlApiOperator(
+ task_id=TASK_ID,
+ snowflake_conn_id=CONN_ID,
+ sql=SQL_MULTIPLE_STMTS,
+ statement_count=4,
+ deferrable=True,
+ )
+ hook1 = operator._hook
+ hook2 = operator._hook
+ assert hook1 is hook2
+
  @mock.patch("airflow.providers.snowflake.operators.snowflake.SnowflakeSqlApiOperator.defer")
  def test_snowflake_sql_api_execute_operator_failed_before_defer(
  self, mock_defer, mock_execute_query, mock_get_sql_api_query_status
@@ -549,3 +549,26 @@ class TestCopyFromExternalStageToSnowflake:
  },
  job_facets={"sql": SQLJobFacet(query=expected_sql)},
  )
+
+ def test_init_with_invalid_parameters(self):
+ # Test by passing invalid parameters to the operator
+ import re
+
+ with pytest.raises(ValueError, match=re.escape("Invalid table: semicolons (;) not allowed.")):
+ _ = CopyFromExternalStageToSnowflakeOperator(
+ task_id="test",
+ table="table; some_new_table",
+ stage="stage",
+ database="",
+ schema="schema",
+ file_format="CSV",
+ )
+ with pytest.raises(ValueError, match=re.escape("Invalid stage: semicolons (;) not allowed.")):
+ _ = CopyFromExternalStageToSnowflakeOperator(
+ task_id="test",
+ table="table",
+ stage="stage; some_new_stabge",
+ database="",
+ schema="schema",
+ file_format="CSV",
+ )
@@ -30,12 +30,13 @@ POLL_INTERVAL = 1.0
  LIFETIME = timedelta(minutes=59)
  RENEWAL_DELTA = timedelta(minutes=54)
  MODULE = "airflow.providers.snowflake"
+ QUERY_IDS = ["uuid"]


  class TestSnowflakeSqlApiTrigger:
  TRIGGER = SnowflakeSqlApiTrigger(
  poll_interval=POLL_INTERVAL,
- query_ids=["uuid"],
+ query_ids=QUERY_IDS,
  snowflake_conn_id="test_conn",
  token_life_time=LIFETIME,
  token_renewal_delta=RENEWAL_DELTA,
@@ -82,8 +83,8 @@ class TestSnowflakeSqlApiTrigger:
  Test SnowflakeSqlApiTrigger run method with success status and mock the get_sql_api_query_status
  result and get_query_status to False.
  """
- mock_get_query_status.return_value = {"status": "success", "statement_handles": ["uuid", "uuid1"]}
  statement_query_ids = ["uuid", "uuid1"]
+ mock_get_query_status.return_value = {"status": "success", "statement_handles": statement_query_ids}
  mock_get_sql_api_query_status_async.return_value = {
  "message": "Statement executed successfully.",
  "status": "success",
@@ -92,7 +93,7 @@ class TestSnowflakeSqlApiTrigger:

  generator = self.TRIGGER.run()
  actual = await generator.asend(None)
- assert TriggerEvent({"status": "success", "statement_query_ids": statement_query_ids}) == actual
+ assert TriggerEvent({"status": "success", "statement_query_ids": QUERY_IDS}) == actual

  @pytest.mark.asyncio
  @mock.patch(f"{MODULE}.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_sql_api_query_status_async")
@@ -93,6 +93,9 @@ def test_snowflake_sqlite_account_urls(source, target):
  ("xy12345", "xy12345.us-west-1.aws"), # No '-' or '_' in name
  ("xy12345.us-west-1.aws", "xy12345.us-west-1.aws"), # Already complete locator
  ("xy12345.us-west-2.gcp", "xy12345.us-west-2.gcp"), # Already complete locator for GCP
+ ("xy12345.us-west-2.gcp.us-west-2.gcp", "xy12345.us-west-2.gcp"), # Duplicated region
+ ("xy12345.us-west-2.gcp.us-west-2.gcp.us-west-2.gcp", "xy12345.us-west-2.gcp"), # Triple region
+ ("xy12345.us-west-2.gcp.some_random_part", "xy12345.us-west-2.gcp"), # Suffix to locator, ignored
  ("xy12345aws", "xy12345aws.us-west-1.aws"), # AWS without '-' or '_'
  ("xy12345-aws", "xy12345-aws"), # AWS with '-'
  ("xy12345_gcp-europe-west1", "xy12345.europe-west1.gcp"), # GCP with '_'