flyteplugins-databricks 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,4 @@
1
+ from flyteplugins.databricks.connector import DatabricksConnector
2
+ from flyteplugins.databricks.task import Databricks
3
+
4
+ __all__ = ["Databricks", "DatabricksConnector"]
@@ -0,0 +1,151 @@
1
+ import http
2
+ import json
3
+ import os
4
+ import typing
5
+ from dataclasses import dataclass
6
+ from typing import Optional
7
+
8
+ import aiohttp
9
+ from flyte import logger
10
+ from flyte.connectors import AsyncConnector, ConnectorRegistry, Resource, ResourceMeta
11
+ from flyte.connectors.utils import convert_to_flyte_phase
12
+ from flyteidl2.core.execution_pb2 import TaskExecution, TaskLog
13
+ from flyteidl2.core.tasks_pb2 import TaskTemplate
14
+ from google.protobuf.json_format import MessageToDict
15
+
16
+ DATABRICKS_API_ENDPOINT = "/api/2.1/jobs"
17
+ DEFAULT_DATABRICKS_INSTANCE_ENV_KEY = "FLYTE_DATABRICKS_INSTANCE"
18
+
19
+
20
@dataclass
class DatabricksJobMetadata(ResourceMeta):
    """Connector-side handle for a submitted Databricks job run.

    Instances are serialized between ``create``/``get``/``delete`` calls so the
    connector can locate the run again without holding any local state.
    """

    # Workspace hostname, e.g. "<account>.cloud.databricks.com" (no scheme).
    databricks_instance: str
    # Run id returned by the Jobs runs/submit endpoint, stored as a string.
    run_id: str
24
+
25
+
26
def _get_databricks_job_spec(task_template: TaskTemplate) -> dict:
    """Build the request body for the Databricks Jobs 2.1 ``runs/submit`` API.

    Starts from the user-supplied ``databricksConf`` in the task template's custom
    field, fills cluster defaults (container image, spark conf, env vars) when a
    new cluster is requested, and attaches the Flyte entrypoint script fetched
    from the pinned flytetools git commit.

    Args:
        task_template: Task template carrying the container spec and the
            Databricks/Spark configuration in its ``custom`` proto struct.

    Returns:
        A dict ready to be JSON-encoded and POSTed to ``runs/submit``.

    Raises:
        ValueError: If ``databricksConf`` is missing, or neither
            ``existing_cluster_id`` nor ``new_cluster`` is specified.
    """
    custom = MessageToDict(task_template.custom)
    container = task_template.container
    envs = container.env
    databricks_job = custom.get("databricksConf")
    if databricks_job is None:
        raise ValueError("Missing Databricks job configuration in task template.")
    if databricks_job.get("existing_cluster_id") is None:
        new_cluster = databricks_job.get("new_cluster")
        if new_cluster is None:
            raise ValueError("Either existing_cluster_id or new_cluster must be specified")
        # Default the cluster image to the task's container image so the run
        # executes in the same environment the task was built for.
        if not new_cluster.get("docker_image"):
            new_cluster["docker_image"] = {"url": container.image}
        if not new_cluster.get("spark_conf"):
            new_cluster["spark_conf"] = custom.get("sparkConf", {})
        # Propagate the task container's env vars into the Spark env; explicit
        # user-provided spark_env_vars take the base and are extended in place.
        container_env = {env.key: env.value for env in envs}
        if not new_cluster.get("spark_env_vars"):
            new_cluster["spark_env_vars"] = container_env
        else:
            new_cluster["spark_env_vars"].update(container_env)
    # https://docs.databricks.com/api/workspace/jobs/submit
    databricks_job["spark_python_task"] = {
        "python_file": "flyteplugins/databricks/entrypoint.py",
        "parameters": list(container.args),
        "source": "GIT",
    }
    # https://github.com/flyteorg/flytetools/blob/master/flyteplugins/databricks/entrypoint.py
    databricks_job["git_source"] = {
        "git_url": "https://github.com/flyteorg/flytetools",
        "git_provider": "gitHub",
        "git_commit": "194364210c47c49ce66c419e8fb68d6f9c06fd7e",
    }

    # Fix: the original passed the payload as an extra positional logging arg
    # with no %-placeholder, which makes the logging module raise a formatting
    # error when DEBUG is enabled. Use lazy %-style formatting instead.
    logger.debug("databricks_job spec: %s", databricks_job)
    return databricks_job
60
+
61
+
62
class DatabricksConnector(AsyncConnector):
    """Async Flyte connector that executes ``databricks`` tasks through the
    Databricks Jobs 2.1 REST API (submit / poll / cancel)."""

    name: str = "Databricks Connector"
    task_type_name: str = "databricks"
    metadata_type: type = DatabricksJobMetadata

    async def create(
        self,
        task_template: TaskTemplate,
        inputs: Optional[typing.Dict[str, typing.Any]] = None,
        databricks_token: Optional[str] = None,
        **kwargs,
    ) -> DatabricksJobMetadata:
        """Submit a one-time run and return metadata identifying it.

        The instance comes from the task config's ``databricksInstance``,
        falling back to the connector's environment variable.
        """
        payload = json.dumps(_get_databricks_job_spec(task_template))
        custom = MessageToDict(task_template.custom)
        databricks_instance = custom.get("databricksInstance", os.getenv(DEFAULT_DATABRICKS_INSTANCE_ENV_KEY))

        if not databricks_instance:
            raise ValueError(
                f"Missing databricks instance. Please set the value through the task config or"
                f" set the {DEFAULT_DATABRICKS_INSTANCE_ENV_KEY} environment variable in the connector."
            )

        databricks_url = f"https://{databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/submit"

        async with aiohttp.ClientSession() as client:
            async with client.post(databricks_url, headers=get_header(databricks_token), data=payload) as resp:
                # Parse the body first: the error response carries the details.
                body = await resp.json()
                if resp.status != http.HTTPStatus.OK:
                    raise RuntimeError(f"Failed to create databricks job with error: {body}")

        return DatabricksJobMetadata(databricks_instance=databricks_instance, run_id=str(body["run_id"]))

    async def get(
        self, resource_meta: DatabricksJobMetadata, databricks_token: Optional[str] = None, **kwargs
    ) -> Resource:
        """Poll the run and map its state onto a Flyte task phase."""
        databricks_instance = resource_meta.databricks_instance
        databricks_url = (
            f"https://{databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/get?run_id={resource_meta.run_id}"
        )

        async with aiohttp.ClientSession() as client:
            async with client.get(databricks_url, headers=get_header(databricks_token)) as resp:
                if resp.status != http.HTTPStatus.OK:
                    raise RuntimeError(f"Failed to get databricks job {resource_meta.run_id} with error: {resp.reason}")
                body = await resp.json()

        cur_phase = TaskExecution.UNDEFINED
        message = ""
        state = body.get("state")

        # The databricks job's state is determined by life_cycle_state and result_state.
        # https://docs.databricks.com/en/workflows/jobs/jobs-2.0-api.html#runresultstate
        if state:
            life_cycle_state = state.get("life_cycle_state")
            if result_state_is_available(life_cycle_state):
                cur_phase = convert_to_flyte_phase(state.get("result_state"))
            else:
                cur_phase = convert_to_flyte_phase(life_cycle_state)
            message = state.get("state_message")

        job_id = body.get("job_id")
        console_url = f"https://{databricks_instance}/#job/{job_id}/run/{resource_meta.run_id}"
        return Resource(
            phase=cur_phase,
            message=message,
            log_links=[TaskLog(uri=console_url, name="Databricks Console")],
        )

    async def delete(self, resource_meta: DatabricksJobMetadata, databricks_token: Optional[str] = None, **kwargs):
        """Request cancellation of the run via ``runs/cancel``."""
        databricks_url = f"https://{resource_meta.databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/cancel"
        payload = json.dumps({"run_id": resource_meta.run_id})

        async with aiohttp.ClientSession() as client:
            async with client.post(databricks_url, headers=get_header(databricks_token), data=payload) as resp:
                if resp.status != http.HTTPStatus.OK:
                    raise RuntimeError(
                        f"Failed to cancel databricks job {resource_meta.run_id} with error: {resp.reason}"
                    )
                await resp.json()
141
+
142
+
143
def get_header(token: Optional[str]) -> typing.Dict[str, str]:
    """Build the HTTP headers for Databricks REST API calls.

    Args:
        token: Databricks access token. Annotated Optional (fix: the original
            hint said ``str``, but every call site passes ``Optional[str]``);
            a ``None`` token yields a header the server will reject with 403.

    Returns:
        Headers carrying bearer authorization and a JSON content type.
    """
    return {"Authorization": f"Bearer {token}", "content-type": "application/json"}
145
+
146
+
147
def result_state_is_available(life_cycle_state: str) -> bool:
    """Tell whether the run's ``result_state`` can be read.

    Databricks only populates ``result_state`` once the run's life-cycle state
    has reached its terminal value.
    """
    terminal_state = "TERMINATED"
    return life_cycle_state == terminal_state
149
+
150
+
151
+ ConnectorRegistry.register(DatabricksConnector())
@@ -0,0 +1,66 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any, Dict, Optional, Union
3
+
4
+ from flyte._task_plugins import TaskPluginRegistry
5
+ from flyte.connectors import AsyncConnectorExecutorMixin
6
+ from flyte.models import SerializationContext
7
+ from flyteidl2.plugins.spark_pb2 import SparkApplication, SparkJob
8
+ from flyteplugins.spark import Spark
9
+ from flyteplugins.spark.task import PysparkFunctionTask
10
+ from google.protobuf.json_format import MessageToDict
11
+
12
+
13
@dataclass
class Databricks(Spark):
    """
    Use this to configure a Databricks task. Tasks marked with this config will
    automatically execute natively on the Databricks platform as a distributed
    Spark job.

    Args:
        databricks_conf: Databricks job configuration compliant with API version 2.1, supporting 2.0 use cases.
            For the configuration structure, see: https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure
            For updates in API 2.1, refer to: https://docs.databricks.com/en/workflows/jobs/jobs-api-updates.html
        databricks_instance: Domain name of your deployment. Use the form <account>.cloud.databricks.com.
        databricks_token: the name of the secret containing the Databricks token for authentication.
    """

    databricks_conf: Optional[Dict[str, Union[str, dict]]] = None
    databricks_instance: Optional[str] = None
    databricks_token: Optional[str] = None
30
+
31
+
32
class DatabricksFunctionTask(AsyncConnectorExecutorMixin, PysparkFunctionTask):
    """Plugin task that serializes local Python code for execution on
    Databricks via the Databricks connector, within a Spark context."""

    plugin_config: Databricks

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Route execution through the "databricks" connector rather than the
        # base Spark path.
        self.task_type = "databricks"

    def custom_config(self, sctx: SerializationContext) -> Dict[str, Any]:
        """Serialize the plugin config into the ``custom`` dict shipped with
        the task template."""
        cfg = self.plugin_config

        spark_job = SparkJob(
            sparkConf=cfg.spark_conf,
            hadoopConf=cfg.hadoop_conf,
            mainApplicationFile=cfg.applications_path or "local://" + sctx.get_entrypoint_path(),
            executorPath=cfg.executor_path or sctx.interpreter_path,
            mainClass="",
            applicationType=SparkApplication.PYTHON,
            driverPod=cfg.driver_pod.to_k8s_pod() if cfg.driver_pod else None,
            executorPod=cfg.executor_pod.to_k8s_pod() if cfg.executor_pod else None,
            databricksConf=cfg.databricks_conf,
            databricksInstance=cfg.databricks_instance,
        )

        custom = MessageToDict(spark_job)
        # The connector resolves this secret name to obtain the API token.
        custom["secrets"] = {"databricks_token": cfg.databricks_token}
        return custom
64
+
65
+
66
+ TaskPluginRegistry.register(Databricks, DatabricksFunctionTask)
@@ -0,0 +1,57 @@
1
+ Metadata-Version: 2.4
2
+ Name: flyteplugins-databricks
3
+ Version: 2.0.0
4
+ Summary: Databricks plugin for flyte
5
+ Author-email: Kevin Su <pingsutw@users.noreply.github.com>
6
+ Requires-Python: >=3.10
7
+ Description-Content-Type: text/markdown
8
+ Requires-Dist: flyte[connector]
9
+ Requires-Dist: aiohttp
10
+ Requires-Dist: nest-asyncio
11
+ Requires-Dist: flyteplugins-spark
12
+
13
+ # Databricks Plugin for Flyte
14
+
15
+ This plugin provides Databricks integration for Flyte, enabling you to run Spark jobs on Databricks as Flyte tasks.
16
+
17
+ ## Installation
18
+
19
+ ```bash
20
+ pip install flyteplugins-databricks
21
+ ```
22
+
23
+ ## Usage
24
+
25
+ ```python
26
+ from flyte import task
+ from flyteplugins.databricks import Databricks, DatabricksConnector
27
+
28
+ @task(task_config=Databricks(
29
+ databricks_conf={
30
+ "run_name": "flyte databricks plugin",
31
+ "new_cluster": {
32
+ "spark_version": "13.3.x-scala2.12",
33
+ "autoscale": {
34
+ "min_workers": 1,
35
+ "max_workers": 1,
36
+ },
37
+ "node_type_id": "m6i.large",
38
+ "num_workers": 1,
39
+ "aws_attributes": {
40
+ "availability": "SPOT_WITH_FALLBACK",
41
+ "instance_profile_arn": "arn:aws:iam::339713193121:instance-profile/databricks-demo",
42
+ "ebs_volume_type": "GENERAL_PURPOSE_SSD",
43
+ "ebs_volume_count": 1,
44
+ "ebs_volume_size": 100,
45
+ "first_on_demand": 1,
46
+ },
47
+ },
48
+ # "existing_cluster_id": "1113-204018-tb9vr2fm", # use existing cluster id if you want
49
+ "timeout_seconds": 3600,
50
+ "max_retries": 1,
51
+ },
52
+ databricks_instance="mycompany.cloud.databricks.com",
53
+ ))
54
+ def my_spark_task() -> int:
55
+ # Your Spark code here
56
+ return 42
57
+ ```
@@ -0,0 +1,8 @@
1
+ flyteplugins/databricks/__init__.py,sha256=FfSNiu0JuKflhA0_i0VrNuwE_SX08lY4CBH1nY8AyD0,167
2
+ flyteplugins/databricks/connector.py,sha256=ysg8ALL0Bmnmi32BszOTlpS0-kQWz843sX8qIwqXiaA,6576
3
+ flyteplugins/databricks/task.py,sha256=ZscXO3PoocpCiBa9PM3Joe0_KytV7mw2GvCl7B69rVg,2863
4
+ flyteplugins_databricks-2.0.0.dist-info/METADATA,sha256=WOsjPV9HE0Q_XbzkvOsXvDfnx8frVAF1cxLBiUpVB7E,1686
5
+ flyteplugins_databricks-2.0.0.dist-info/WHEEL,sha256=YCfwYGOYMi5Jhw2fU4yNgwErybb2IX5PEwBKV4ZbdBo,91
6
+ flyteplugins_databricks-2.0.0.dist-info/entry_points.txt,sha256=TrAGjurydxqKpeLlNdhOv0Jc8WrmaQ6XIANTnWF9LCE,86
7
+ flyteplugins_databricks-2.0.0.dist-info/top_level.txt,sha256=cgd779rPu9EsvdtuYgUxNHHgElaQvPn74KhB5XSeMBE,13
8
+ flyteplugins_databricks-2.0.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (82.0.0)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,2 @@
1
+ [flyte.connectors]
2
+ databricks = flyteplugins.databricks.connector:DatabricksConnector
@@ -0,0 +1 @@
1
+ flyteplugins