dataproc-spark-connect 0.2.1__py2.py3-none-any.whl → 0.7.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataproc_spark_connect-0.7.0.dist-info/METADATA +98 -0
- dataproc_spark_connect-0.7.0.dist-info/RECORD +12 -0
- google/cloud/dataproc_spark_connect/__init__.py +14 -8
- google/cloud/dataproc_spark_connect/client/core.py +34 -8
- google/cloud/dataproc_spark_connect/client/proxy.py +15 -12
- google/cloud/dataproc_spark_connect/exceptions.py +27 -0
- google/cloud/dataproc_spark_connect/pypi_artifacts.py +48 -0
- google/cloud/dataproc_spark_connect/session.py +339 -242
- dataproc_spark_connect-0.2.1.dist-info/METADATA +0 -119
- dataproc_spark_connect-0.2.1.dist-info/RECORD +0 -10
- {dataproc_spark_connect-0.2.1.dist-info → dataproc_spark_connect-0.7.0.dist-info}/LICENSE +0 -0
- {dataproc_spark_connect-0.2.1.dist-info → dataproc_spark_connect-0.7.0.dist-info}/WHEEL +0 -0
- {dataproc_spark_connect-0.2.1.dist-info → dataproc_spark_connect-0.7.0.dist-info}/top_level.txt +0 -0

--- google/cloud/dataproc_spark_connect/session.py (0.2.1)
+++ google/cloud/dataproc_spark_connect/session.py (0.7.0)
@@ -11,36 +11,37 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+import atexit
+import datetime
 import json
 import logging
 import os
 import random
 import string
+import threading
 import time
-import
-from time import sleep
-from typing import Any, cast, ClassVar, Dict, Optional
+import tqdm
 
 from google.api_core import retry
-from google.api_core.future.polling import POLLING_PREDICATE
 from google.api_core.client_options import ClientOptions
-from google.api_core.exceptions import FailedPrecondition, InvalidArgument, NotFound
-from google.
-
+from google.api_core.exceptions import Aborted, FailedPrecondition, InvalidArgument, NotFound, PermissionDenied
+from google.api_core.future.polling import POLLING_PREDICATE
 from google.cloud.dataproc_spark_connect.client import DataprocChannelBuilder
+from google.cloud.dataproc_spark_connect.exceptions import DataprocSparkConnectException
+from google.cloud.dataproc_spark_connect.pypi_artifacts import PyPiArtifacts
 from google.cloud.dataproc_v1 import (
+    AuthenticationConfig,
     CreateSessionRequest,
     GetSessionRequest,
     Session,
     SessionControllerClient,
-    SessionTemplate,
     TerminateSessionRequest,
 )
-from google.
-from google.protobuf.text_format import ParseError
+from google.cloud.dataproc_v1.types import sessions
 from pyspark.sql.connect.session import SparkSession
 from pyspark.sql.utils import to_str
-
+from typing import Any, cast, ClassVar, Dict, Optional
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
@@ -61,11 +62,13 @@ class DataprocSparkSession(SparkSession):
     >>> spark = (
     ...     DataprocSparkSession.builder
     ...     .appName("Word Count")
-    ...     .
+    ...     .dataprocSessionConfig(Session())
     ...     .getOrCreate()
     ... )  # doctest: +SKIP
     """
 
+    _DEFAULT_RUNTIME_VERSION = "2.2"
+
     _active_s8s_session_uuid: ClassVar[Optional[str]] = None
     _project_id = None
     _region = None
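
Note: the class docstring now demonstrates `dataprocSessionConfig`, the renamed configuration entry point. A minimal sketch of the 0.7.0 builder surface, assuming application-default credentials; the `projectId` setter name is inferred from the unchanged context around the setter hunks below and may differ:

    from google.cloud.dataproc_spark_connect import DataprocSparkSession
    from google.cloud.dataproc_v1 import Session

    # An empty Session() picks up the defaults introduced in this release
    # (runtime version "2.2" unless overridden).
    spark = (
        DataprocSparkSession.builder
        .projectId("my-project")      # assumed setter name
        .location("us-central1")      # new in 0.7.0, replacing the old setter
        .dataprocSessionConfig(Session())
        .getOrCreate()
    )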
@@ -74,7 +77,12 @@ class DataprocSparkSession(SparkSession):
 
     class Builder(SparkSession.Builder):
 
-
+        _dataproc_runtime_to_spark_version = {
+            "1.2": "3.5",
+            "2.2": "3.5",
+            "2.3": "3.5",
+            "3.0": "4.0",
+        }
 
         _session_static_configs = [
             "spark.executor.cores",
@@ -90,10 +98,10 @@ class DataprocSparkSession(SparkSession):
             self._options: Dict[str, Any] = {}
             self._channel_builder: Optional[DataprocChannelBuilder] = None
             self._dataproc_config: Optional[Session] = None
-            self._project_id = os.
-            self._region = os.
+            self._project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
+            self._region = os.getenv("GOOGLE_CLOUD_REGION")
             self._client_options = ClientOptions(
-                api_endpoint=os.
+                api_endpoint=os.getenv(
                     "GOOGLE_CLOUD_DATAPROC_API_ENDPOINT",
                     f"{self._region}-dataproc.googleapis.com",
                 )
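
Note: the builder now seeds project, region, and API endpoint from environment variables via `os.getenv`. A sketch of driving everything through the environment (variable names taken from this hunk; values are placeholders):

    import os

    os.environ["GOOGLE_CLOUD_PROJECT"] = "my-project"
    os.environ["GOOGLE_CLOUD_REGION"] = "us-central1"
    # Optional override; defaults to f"{region}-dataproc.googleapis.com".
    os.environ["GOOGLE_CLOUD_DATAPROC_API_ENDPOINT"] = "us-central1-dataproc.googleapis.com"

    from google.cloud.dataproc_spark_connect import DataprocSparkSession

    spark = DataprocSparkSession.builder.getOrCreate()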
@@ -112,15 +120,15 @@ class DataprocSparkSession(SparkSession):
             self._project_id = project_id
             return self
 
-        def
-            self._region =
-            self._client_options.api_endpoint = os.
+        def location(self, location):
+            self._region = location
+            self._client_options.api_endpoint = os.getenv(
                 "GOOGLE_CLOUD_DATAPROC_API_ENDPOINT",
                 f"{self._region}-dataproc.googleapis.com",
             )
             return self
 
-        def
+        def dataprocSessionConfig(self, dataproc_config: Session):
             with self._lock:
                 self._dataproc_config = dataproc_config
                 for k, v in dataproc_config.runtime_config.properties.items():
@@ -135,14 +143,14 @@ class DataprocSparkSession(SparkSession):
                 else:
                     return self
 
-        def create(self) -> "
+        def create(self) -> "DataprocSparkSession":
             raise NotImplemented(
                 "DataprocSparkSession allows session creation only through getOrCreate"
             )
 
         def __create_spark_connect_session_from_s8s(
-            self, session_response
-        ) -> "
+            self, session_response, session_name
+        ) -> "DataprocSparkSession":
             DataprocSparkSession._active_s8s_session_uuid = (
                 session_response.uuid
             )
@@ -152,10 +160,14 @@ class DataprocSparkSession(SparkSession):
             spark_connect_url = session_response.runtime_info.endpoints.get(
                 "Spark Connect Server"
             )
-
-            url = f"{spark_connect_url.replace('.com/', '.com:443/')};session_id={session_response.uuid};use_ssl=true"
+            url = f"{spark_connect_url}/;session_id={session_response.uuid};use_ssl=true"
             logger.debug(f"Spark Connect URL: {url}")
-            self._channel_builder = DataprocChannelBuilder(
+            self._channel_builder = DataprocChannelBuilder(
+                url,
+                is_active_callback=lambda: is_s8s_session_active(
+                    session_name, self._client_options
+                ),
+            )
 
             assert self._channel_builder is not None
             session = DataprocSparkSession(connection=self._channel_builder)
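
Note: the connection string no longer rewrites the endpoint with `.replace('.com/', '.com:443/')`; the session parameters are appended after a `/;` separator, and the channel builder now takes an `is_active_callback` so a session that went away server-side can be detected. Shape of the resulting URL, with a made-up endpoint (the real value comes from `session_response.runtime_info.endpoints`):

    # Hypothetical endpoint value; illustration of the URL shape only.
    spark_connect_url = "sc://example-endpoint.googleusercontent.com"
    session_uuid = "a1b2c3d4-0000-0000-0000-000000000000"  # session_response.uuid
    url = f"{spark_connect_url}/;session_id={session_uuid};use_ssl=true"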
@@ -164,77 +176,104 @@ class DataprocSparkSession(SparkSession):
             self.__apply_options(session)
             return session
 
-        def __create(self) -> "
+        def __create(self) -> "DataprocSparkSession":
             with self._lock:
 
                 if self._options.get("spark.remote", False):
                     raise NotImplemented(
-                        "DataprocSparkSession does not support connecting to an existing remote server"
+                        "DataprocSparkSession does not support connecting to an existing Spark Connect remote server"
                     )
 
                 from google.cloud.dataproc_v1 import SessionControllerClient
 
                 dataproc_config: Session = self._get_dataproc_config()
-                session_template: SessionTemplate = self._get_session_template()
 
-                self.
-                    dataproc_config, session_template
-                )
+                self._validate_version(dataproc_config)
 
-                spark_connect_session = self._get_spark_connect_session(
-                    dataproc_config, session_template
-                )
-
-                spark = self._get_spark(dataproc_config, session_template)
-
-                if not spark_connect_session:
-                    dataproc_config.spark_connect_session = {}
-                if not spark:
-                    dataproc_config.spark = {}
-                os.environ["SPARK_CONNECT_MODE_ENABLED"] = "1"
-                session_request = CreateSessionRequest()
                 session_id = self.generate_dataproc_session_id()
-
-                session_request.session_id = session_id
                 dataproc_config.name = f"projects/{self._project_id}/locations/{self._region}/sessions/{session_id}"
                 logger.debug(
-                    f"
+                    f"Dataproc Session configuration:\n{dataproc_config}"
                 )
+
+                session_request = CreateSessionRequest()
+                session_request.session_id = session_id
                 session_request.session = dataproc_config
                 session_request.parent = (
                     f"projects/{self._project_id}/locations/{self._region}"
                 )
 
-                logger.debug("Creating
+                logger.debug("Creating Dataproc Session")
                 DataprocSparkSession._active_s8s_session_id = session_id
                 s8s_creation_start_time = time.time()
-
-
-
-
-
-
-
-
-
-                    "Creating Spark session. It may take few minutes."
+
+                stop_create_session_pbar = False
+
+                def create_session_pbar():
+                    iterations = 150
+                    pbar = tqdm.trange(
+                        iterations,
+                        bar_format="{bar}",
+                        ncols=80,
                     )
+                    for i in pbar:
+                        if stop_create_session_pbar:
+                            break
+                        # Last iteration
+                        if i >= iterations - 1:
+                            # Sleep until session created
+                            while not stop_create_session_pbar:
+                                time.sleep(1)
+                        else:
+                            time.sleep(1)
+
+                    pbar.close()
+                    # Print new line after the progress bar
+                    print()
+
+                create_session_pbar_thread = threading.Thread(
+                    target=create_session_pbar
+                )
+
+                try:
+                    if (
+                        os.getenv(
+                            "DATAPROC_SPARK_CONNECT_SESSION_TERMINATE_AT_EXIT",
+                            "false",
+                        )
+                        == "true"
+                    ):
+                        atexit.register(
+                            lambda: terminate_s8s_session(
+                                self._project_id,
+                                self._region,
+                                session_id,
+                                self._client_options,
+                            )
+                        )
                     operation = SessionControllerClient(
                         client_options=self._client_options
                     ).create_session(session_request)
                     print(
-                        f"
+                        f"Creating Dataproc Session: https://console.cloud.google.com/dataproc/interactive/{self._region}/{session_id}?project={self._project_id}"
                     )
+                    create_session_pbar_thread.start()
                     session_response: Session = operation.result(
-                        polling=
+                        polling=retry.Retry(
+                            predicate=POLLING_PREDICATE,
+                            initial=5.0,  # seconds
+                            maximum=5.0,  # seconds
+                            multiplier=1.0,
+                            timeout=600,  # seconds
+                        )
                     )
-
-
-
-
-
-
-
+                    stop_create_session_pbar = True
+                    create_session_pbar_thread.join()
+                    print("Dataproc Session was successfully created")
+                    file_path = (
+                        DataprocSparkSession._get_active_session_file_path()
+                    )
+                    if file_path is not None:
                         try:
                             session_data = {
                                 "session_name": session_response.name,
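
Note: session creation now renders a fixed-width tqdm progress bar on a helper thread while the create RPC polls every 5 seconds (up to 600 seconds). A self-contained sketch of the same thread-plus-flag pattern, with the RPC stubbed out by a sleep; the real code additionally parks on the last bar tick until the flag flips, which this sketch omits:

    import threading
    import time

    import tqdm

    stop = False  # shared flag, flipped by the main thread when work finishes

    def pbar_worker():
        # 150 one-second ticks, bar glyphs only, 80 columns wide.
        for _ in tqdm.trange(150, bar_format="{bar}", ncols=80):
            if stop:
                break
            time.sleep(1)

    worker = threading.Thread(target=pbar_worker)
    worker.start()
    time.sleep(3)  # stand-in for operation.result(polling=...)
    stop = True
    worker.join()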
@@ -247,75 +286,67 @@ class DataprocSparkSession(SparkSession):
                                 json.dump(session_data, json_file, indent=4)
                         except Exception as e:
                             logger.error(
-                                f"Exception while writing active session to file {file_path}
+                                f"Exception while writing active session to file {file_path}, {e}"
                             )
-                except InvalidArgument as e:
+                except (InvalidArgument, PermissionDenied) as e:
+                    stop_create_session_pbar = True
+                    if create_session_pbar_thread.is_alive():
+                        create_session_pbar_thread.join()
                     DataprocSparkSession._active_s8s_session_id = None
-                    raise
-                        f"Error while creating
-                    )
+                    raise DataprocSparkConnectException(
+                        f"Error while creating Dataproc Session: {e.message}"
+                    )
                 except Exception as e:
+                    stop_create_session_pbar = True
+                    if create_session_pbar_thread.is_alive():
+                        create_session_pbar_thread.join()
                     DataprocSparkSession._active_s8s_session_id = None
                     raise RuntimeError(
-                        f"Error while creating
-                    ) from
+                        f"Error while creating Dataproc Session"
+                    ) from e
 
                 logger.debug(
-                    f"
+                    f"Dataproc Session created: {session_id} in {int(time.time() - s8s_creation_start_time)} seconds"
                 )
                 return self.__create_spark_connect_session_from_s8s(
-                    session_response
+                    session_response, dataproc_config.name
                 )
 
-        def
-            self,
-        ) -> Optional[
-            session_name = f"projects/{self._project_id}/locations/{self._region}/sessions/{s8s_session_id}"
-            get_session_request = GetSessionRequest()
-            get_session_request.name = session_name
-            state = None
-            try:
-                get_session_response = SessionControllerClient(
-                    client_options=self._client_options
-                ).get_session(get_session_request)
-                state = get_session_response.state
-            except Exception as e:
-                logger.debug(f"{s8s_session_id} deleted: {e}")
-                return None
-
-            if state is not None and (
-                state == Session.State.ACTIVE or state == Session.State.CREATING
-            ):
-                return get_session_response
-            return None
-
-        def _get_exiting_active_session(self) -> Optional["SparkSession"]:
+        def _get_exiting_active_session(
+            self,
+        ) -> Optional["DataprocSparkSession"]:
             s8s_session_id = DataprocSparkSession._active_s8s_session_id
-
+            session_name = f"projects/{self._project_id}/locations/{self._region}/sessions/{s8s_session_id}"
+            session_response = None
+            session = None
+            if s8s_session_id is not None:
+                session_response = get_active_s8s_session_response(
+                    session_name, self._client_options
+                )
+                session = DataprocSparkSession.getActiveSession()
 
-            session = DataprocSparkSession.getActiveSession()
             if session is None:
                 session = DataprocSparkSession._default_session
 
             if session_response is not None:
                 print(
-                    f"Using existing
+                    f"Using existing Dataproc Session (configuration changes may not be applied): https://console.cloud.google.com/dataproc/interactive/{self._region}/{s8s_session_id}?project={self._project_id}"
                 )
                 if session is None:
                     session = self.__create_spark_connect_session_from_s8s(
-                        session_response
+                        session_response, session_name
                     )
                 return session
             else:
-                logger.info(
-                    f"Session: {s8s_session_id} not active, stopping previous spark session and creating new"
-                )
                 if session is not None:
+                    print(
+                        f"{s8s_session_id} Dataproc Session is not active, stopping and creating a new one"
+                    )
                     session.stop()
 
             return None
 
-        def getOrCreate(self) -> "
+        def getOrCreate(self) -> "DataprocSparkSession":
             with DataprocSparkSession._lock:
                 session = self._get_exiting_active_session()
                 if session is None:
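
Note: `InvalidArgument` and `PermissionDenied` are now surfaced as the package's own `DataprocSparkConnectException` (from the new exceptions.py), so callers can handle user-fixable failures separately from unexpected `RuntimeError`s. A sketch:

    from google.cloud.dataproc_spark_connect import DataprocSparkSession
    from google.cloud.dataproc_spark_connect.exceptions import (
        DataprocSparkConnectException,
    )

    try:
        spark = DataprocSparkSession.builder.getOrCreate()
    except DataprocSparkConnectException as e:
        # Bad request or missing IAM permission; the message is user-facing.
        print(f"Session creation rejected: {e}")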
@@ -330,21 +361,52 @@ class DataprocSparkSession(SparkSession):
                 dataproc_config = self._dataproc_config
                 for k, v in self._options.items():
                     dataproc_config.runtime_config.properties[k] = v
-
-
-
+            dataproc_config.spark_connect_session = (
+                sessions.SparkConnectConfig()
+            )
+            if not dataproc_config.runtime_config.version:
+                dataproc_config.runtime_config.version = (
+                    DataprocSparkSession._DEFAULT_RUNTIME_VERSION
+                )
+            if (
+                not dataproc_config.environment_config.execution_config.authentication_config.user_workload_authentication_type
+                and "DATAPROC_SPARK_CONNECT_AUTH_TYPE" in os.environ
+            ):
+                dataproc_config.environment_config.execution_config.authentication_config.user_workload_authentication_type = AuthenticationConfig.AuthenticationType[
+                    os.getenv("DATAPROC_SPARK_CONNECT_AUTH_TYPE")
                 ]
-
-
-
-
-
-
-
-
-
-
-
+            if (
+                not dataproc_config.environment_config.execution_config.service_account
+                and "DATAPROC_SPARK_CONNECT_SERVICE_ACCOUNT" in os.environ
+            ):
+                dataproc_config.environment_config.execution_config.service_account = os.getenv(
+                    "DATAPROC_SPARK_CONNECT_SERVICE_ACCOUNT"
+                )
+            if (
+                not dataproc_config.environment_config.execution_config.subnetwork_uri
+                and "DATAPROC_SPARK_CONNECT_SUBNET" in os.environ
+            ):
+                dataproc_config.environment_config.execution_config.subnetwork_uri = os.getenv(
+                    "DATAPROC_SPARK_CONNECT_SUBNET"
+                )
+            if (
+                not dataproc_config.environment_config.execution_config.ttl
+                and "DATAPROC_SPARK_CONNECT_TTL_SECONDS" in os.environ
+            ):
+                dataproc_config.environment_config.execution_config.ttl = {
+                    "seconds": int(
+                        os.getenv("DATAPROC_SPARK_CONNECT_TTL_SECONDS")
+                    )
+                }
+            if (
+                not dataproc_config.environment_config.execution_config.idle_ttl
+                and "DATAPROC_SPARK_CONNECT_IDLE_TTL_SECONDS" in os.environ
+            ):
+                dataproc_config.environment_config.execution_config.idle_ttl = {
+                    "seconds": int(
+                        os.getenv("DATAPROC_SPARK_CONNECT_IDLE_TTL_SECONDS")
+                    )
+                }
             if "COLAB_NOTEBOOK_RUNTIME_ID" in os.environ:
                 dataproc_config.labels["colab-notebook-runtime-id"] = (
                     os.environ["COLAB_NOTEBOOK_RUNTIME_ID"]
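
Note: fields left unset on the Session config now fall back to environment variables. A sketch listing them (variable names are taken from this hunk; the auth-type value is an assumption about `AuthenticationConfig.AuthenticationType` member names, and the other values are placeholders):

    import os

    os.environ["DATAPROC_SPARK_CONNECT_AUTH_TYPE"] = "END_USER_CREDENTIALS"  # assumed member name
    os.environ["DATAPROC_SPARK_CONNECT_SERVICE_ACCOUNT"] = "sa@my-project.iam.gserviceaccount.com"
    os.environ["DATAPROC_SPARK_CONNECT_SUBNET"] = "projects/my-project/regions/us-central1/subnetworks/default"
    os.environ["DATAPROC_SPARK_CONNECT_TTL_SECONDS"] = "3600"
    os.environ["DATAPROC_SPARK_CONNECT_IDLE_TTL_SECONDS"] = "1800"
    # Separately (see the creation hunk above), opt in to best-effort
    # termination of the session at interpreter exit:
    os.environ["DATAPROC_SPARK_CONNECT_SESSION_TERMINATE_AT_EXIT"] = "true"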
@@ -355,95 +417,38 @@ class DataprocSparkSession(SparkSession):
                 ]
             return dataproc_config
 
-        def
-
-                GetSessionTemplateRequest,
-                SessionTemplateControllerClient,
-            )
-
-            session_template = None
-            if self._dataproc_config and self._dataproc_config.session_template:
-                session_template = self._dataproc_config.session_template
-                get_session_template_request = GetSessionTemplateRequest()
-                get_session_template_request.name = session_template
-                client = SessionTemplateControllerClient(
-                    client_options=self._client_options
-                )
-                try:
-                    session_template = client.get_session_template(
-                        get_session_template_request
-                    )
-                except Exception as e:
-                    logger.error(
-                        f"Failed to get session template {session_template}: {e}"
-                    )
-                    raise
-            return session_template
+        def _validate_version(self, dataproc_config):
+            trim_version = lambda v: ".".join(v.split(".")[:2])
 
-
-            trimmed_version = lambda v: ".".join(v.split(".")[:2])
-            version = None
+            version = dataproc_config.runtime_config.version
             if (
-
-
-                and dataproc_config.runtime_config.version
-            ):
-                version = dataproc_config.runtime_config.version
-            elif (
-                session_template
-                and session_template.runtime_config
-                and session_template.runtime_config.version
-            ):
-                version = session_template.runtime_config.version
-
-            if not version:
-                version = "3.0"
-                dataproc_config.runtime_config.version = version
-            elif (
-                trimmed_version(version)
-                not in self._dataproc_runtime_spark_version
+                trim_version(version)
+                not in self._dataproc_runtime_to_spark_version
             ):
                 raise ValueError(
-                    f"
-                    f"Supported versions: {self.
+                    f"Specified {version} Dataproc Spark runtime version is not supported. "
+                    f"Supported runtime versions: {self._dataproc_runtime_to_spark_version.keys()}"
                 )
 
-            server_version = self.
-
+            server_version = self._dataproc_runtime_to_spark_version[
+                trim_version(version)
             ]
+
             import importlib.metadata
 
             dataproc_connect_version = importlib.metadata.version(
                 "dataproc-spark-connect"
            )
             client_version = importlib.metadata.version("pyspark")
-
-
-
-
-
-            logger.warning(
-                f"client and server on different versions: {version_message}"
+            if trim_version(client_version) != trim_version(server_version):
+                print(
+                    f"Spark Connect client and server use different versions:\n"
+                    f"- Dataproc Spark Connect client {dataproc_connect_version} (PySpark {client_version})\n"
+                    f"- Dataproc Spark runtime {version} (Spark {server_version})"
                 )
-
-
-        def
-            spark_connect_session = None
-            if dataproc_config and dataproc_config.spark_connect_session:
-                spark_connect_session = dataproc_config.spark_connect_session
-            elif session_template and session_template.spark_connect_session:
-                spark_connect_session = session_template.spark_connect_session
-            return spark_connect_session
-
-        def _get_spark(self, dataproc_config, session_template):
-            spark = None
-            if dataproc_config and dataproc_config.spark:
-                spark = dataproc_config.spark
-            elif session_template and session_template.spark:
-                spark = session_template.spark
-            return spark
-
-        def generate_dataproc_session_id(self):
+
+        @staticmethod
+        def generate_dataproc_session_id():
             timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
             suffix_length = 6
             random_suffix = "".join(
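
Note: version validation now keys off MAJOR.MINOR only. How `trim_version` and the runtime map interact, for example:

    trim_version = lambda v: ".".join(v.split(".")[:2])

    assert trim_version("2.2.27") == "2.2"  # runtime 2.2.x -> key "2.2" -> Spark "3.5"
    assert trim_version("3.5.1") == "3.5"   # a PySpark 3.5.x client matches runtime 2.2
    # A key outside {"1.2", "2.2", "2.3", "3.0"} raises ValueError; a
    # client/server MAJOR.MINOR mismatch only prints the warning shown above.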
@@ -456,75 +461,102 @@ class DataprocSparkSession(SparkSession):
     def _repr_html_(self) -> str:
         if not self._active_s8s_session_id:
             return """
-              <div>No Active Dataproc
+              <div>No Active Dataproc Session</div>
             """
 
         s8s_session = f"https://console.cloud.google.com/dataproc/interactive/{self._region}/{self._active_s8s_session_id}"
         ui = f"{s8s_session}/sparkApplications/applications"
-        version = ""
         return f"""
           <div>
             <p><b>Spark Connect</b></p>
 
-            <p><a href="{s8s_session}">Dataproc Session</a></p>
-            <p><a href="{ui}">Spark UI</a></p>
+            <p><a href="{s8s_session}?project={self._project_id}">Dataproc Session</a></p>
+            <p><a href="{ui}?project={self._project_id}">Spark UI</a></p>
           </div>
         """
 
-
-
-
-
-        ]
+    @staticmethod
+    def _remove_stopped_session_from_file():
+        file_path = DataprocSparkSession._get_active_session_file_path()
+        if file_path is not None:
             try:
                 with open(file_path, "w"):
                     pass
             except Exception as e:
                 logger.error(
-                    f"Exception while removing active session in file {file_path}
+                    f"Exception while removing active session in file {file_path}, {e}"
                 )
 
+    def addArtifacts(
+        self,
+        *artifact: str,
+        pyfile: bool = False,
+        archive: bool = False,
+        file: bool = False,
+        pypi: bool = False,
+    ) -> None:
+        """
+        Add artifact(s) to the client session. Currently only local files & pypi installations are supported.
+
+        .. versionadded:: 3.5.0
+
+        Parameters
+        ----------
+        *artifact : tuple of str
+            Artifact's URIs to add.
+        pyfile : bool
+            Whether to add them as Python dependencies such as .py, .egg, .zip or .jar files.
+            The pyfiles are directly inserted into the path when executing Python functions
+            in executors.
+        archive : bool
+            Whether to add them as archives such as .zip, .jar, .tar.gz, .tgz, or .tar files.
+            The archives are unpacked on the executor side automatically.
+        file : bool
+            Add a file to be downloaded with this Spark job on every node.
+            The ``path`` passed can only be a local file for now.
+        pypi : bool
+            This option is only available with DataprocSparkSession. e.g. `spark.addArtifacts("spacy==3.8.4", "torch", pypi=True)`
+            Installs PyPi package (with its dependencies) in the active Spark session on the driver and executors.
+
+        Notes
+        -----
+        This is an API dedicated to Spark Connect client only. With regular Spark Session, it throws
+        an exception.
+        Regarding pypi: Popular packages are already pre-installed in s8s runtime.
+        https://cloud.google.com/dataproc-serverless/docs/concepts/versions/spark-runtime-2.2#python_libraries
+        If there are conflicts/package doesn't exist, it throws an exception.
+        """
+        if sum([pypi, file, pyfile, archive]) > 1:
+            raise ValueError(
+                "'pyfile', 'archive', 'file' and/or 'pypi' cannot be True together."
+            )
+        if pypi:
+            artifacts = PyPiArtifacts(set(artifact))
+            logger.debug("Making addArtifact call to install packages")
+            self.addArtifact(
+                artifacts.write_packages_config(self._active_s8s_session_uuid),
+                file=True,
+            )
+        else:
+            super().addArtifacts(
+                *artifact, pyfile=pyfile, archive=archive, file=file
+            )
+
+    @staticmethod
+    def _get_active_session_file_path():
+        return os.getenv("DATAPROC_SPARK_CONNECT_ACTIVE_SESSION_FILE_PATH")
+
     def stop(self) -> None:
         with DataprocSparkSession._lock:
             if DataprocSparkSession._active_s8s_session_id is not None:
-
-
-
-
+                terminate_s8s_session(
+                    DataprocSparkSession._project_id,
+                    DataprocSparkSession._region,
+                    DataprocSparkSession._active_s8s_session_id,
+                    self._client_options,
                 )
-                terminate_session_request = TerminateSessionRequest()
-                session_name = f"projects/{DataprocSparkSession._project_id}/locations/{DataprocSparkSession._region}/sessions/{DataprocSparkSession._active_s8s_session_id}"
-                terminate_session_request.name = session_name
-                state = None
-                try:
-                    SessionControllerClient(
-                        client_options=self._client_options
-                    ).terminate_session(terminate_session_request)
-                    get_session_request = GetSessionRequest()
-                    get_session_request.name = session_name
-                    state = Session.State.ACTIVE
-                    while (
-                        state != Session.State.TERMINATING
-                        and state != Session.State.TERMINATED
-                        and state != Session.State.FAILED
-                    ):
-                        session = SessionControllerClient(
-                            client_options=self._client_options
-                        ).get_session(get_session_request)
-                        state = session.state
-                        sleep(1)
-                except NotFound:
-                    logger.debug(
-                        f"Session {DataprocSparkSession._active_s8s_session_id} already deleted"
-                    )
-                except FailedPrecondition:
-                    logger.debug(
-                        f"Session {DataprocSparkSession._active_s8s_session_id} already terminated manually or terminated automatically through session ttl limits"
-                    )
-                if state is not None and state == Session.State.FAILED:
-                    raise RuntimeError("Serverless session termination failed")
 
-            self.
+            self._remove_stopped_session_from_file()
             DataprocSparkSession._active_s8s_session_uuid = None
             DataprocSparkSession._active_s8s_session_id = None
             DataprocSparkSession._project_id = None
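
Note: `addArtifacts` gains a `pypi=True` mode that writes a package list to the session and installs it on the driver and executors. Usage, taken from the docstring above:

    spark.addArtifacts("spacy==3.8.4", "torch", pypi=True)

    # The flags are mutually exclusive; this raises ValueError:
    # spark.addArtifacts("pkg", pypi=True, file=True)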
@@ -538,3 +570,68 @@ class DataprocSparkSession(SparkSession):
                 DataprocSparkSession._active_session, "session", None
             ):
                 DataprocSparkSession._active_session.session = None
+
+
+def terminate_s8s_session(
+    project_id, region, active_s8s_session_id, client_options=None
+):
+    from google.cloud.dataproc_v1 import SessionControllerClient
+
+    logger.debug(f"Terminating Dataproc Session: {active_s8s_session_id}")
+    terminate_session_request = TerminateSessionRequest()
+    session_name = f"projects/{project_id}/locations/{region}/sessions/{active_s8s_session_id}"
+    terminate_session_request.name = session_name
+    state = None
+    try:
+        session_client = SessionControllerClient(client_options=client_options)
+        session_client.terminate_session(terminate_session_request)
+        get_session_request = GetSessionRequest()
+        get_session_request.name = session_name
+        state = Session.State.ACTIVE
+        while (
+            state != Session.State.TERMINATING
+            and state != Session.State.TERMINATED
+            and state != Session.State.FAILED
+        ):
+            session = session_client.get_session(get_session_request)
+            state = session.state
+            time.sleep(1)
+    except NotFound:
+        logger.debug(
+            f"{active_s8s_session_id} Dataproc Session already deleted"
+        )
+    # Client will get 'Aborted' error if session creation is still in progress and
+    # 'FailedPrecondition' if another termination is still in progress.
+    # Both are retryable, but we catch it and let TTL take care of cleanups.
+    except (FailedPrecondition, Aborted):
+        logger.debug(
+            f"{active_s8s_session_id} Dataproc Session already terminated manually or automatically due to TTL"
+        )
+    if state is not None and state == Session.State.FAILED:
+        raise RuntimeError("Dataproc Session termination failed")
+
+
+def get_active_s8s_session_response(
+    session_name, client_options
+) -> Optional[sessions.Session]:
+    get_session_request = GetSessionRequest()
+    get_session_request.name = session_name
+    try:
+        get_session_response = SessionControllerClient(
+            client_options=client_options
+        ).get_session(get_session_request)
+        state = get_session_response.state
+    except Exception as e:
+        print(f"{session_name} Dataproc Session deleted: {e}")
+        return None
+    if state is not None and (
+        state == Session.State.ACTIVE or state == Session.State.CREATING
+    ):
+        return get_session_response
+    return None
+
+
+def is_s8s_session_active(session_name, client_options) -> bool:
+    if get_active_s8s_session_response(session_name, client_options) is None:
+        return False
+    return True