flyteplugins-databricks 2.0.0b54 (flyteplugins_databricks-2.0.0b54-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flyteplugins/__init__.py (file without changes)
flyteplugins/databricks/__init__.py
@@ -0,0 +1,4 @@
+ from flyteplugins.databricks.connector import DatabricksConnector
+ from flyteplugins.databricks.task import Databricks
+
+ __all__ = ["Databricks", "DatabricksConnector"]
flyteplugins/databricks/connector.py
@@ -0,0 +1,151 @@
+ import http
+ import json
+ import os
+ import typing
+ from dataclasses import dataclass
+ from typing import Optional
+
+ import aiohttp
+ from flyte import logger
+ from flyte.connectors import AsyncConnector, ConnectorRegistry, Resource, ResourceMeta
+ from flyte.connectors.utils import convert_to_flyte_phase
+ from flyteidl2.core.execution_pb2 import TaskExecution, TaskLog
+ from flyteidl2.core.tasks_pb2 import TaskTemplate
+ from google.protobuf.json_format import MessageToDict
+
+ DATABRICKS_API_ENDPOINT = "/api/2.1/jobs"
+ DEFAULT_DATABRICKS_INSTANCE_ENV_KEY = "FLYTE_DATABRICKS_INSTANCE"
+
+
+ @dataclass
+ class DatabricksJobMetadata(ResourceMeta):
+     databricks_instance: str
+     run_id: str
+
+
+ def _get_databricks_job_spec(task_template: TaskTemplate) -> dict:
+     custom = MessageToDict(task_template.custom)
+     container = task_template.container
+     envs = task_template.container.env
+     databricks_job = custom.get("databricksConf")
+     if databricks_job is None:
+         raise ValueError("Missing Databricks job configuration in task template.")
+     if databricks_job.get("existing_cluster_id") is None:
+         new_cluster = databricks_job.get("new_cluster")
+         if new_cluster is None:
+             raise ValueError("Either existing_cluster_id or new_cluster must be specified")
+         if not new_cluster.get("docker_image"):
+             new_cluster["docker_image"] = {"url": container.image}
+         if not new_cluster.get("spark_conf"):
+             new_cluster["spark_conf"] = custom.get("sparkConf", {})
+         if not new_cluster.get("spark_env_vars"):
+             new_cluster["spark_env_vars"] = {env.key: env.value for env in envs}
+         else:
+             new_cluster["spark_env_vars"].update({env.key: env.value for env in envs})
+     # https://docs.databricks.com/api/workspace/jobs/submit
+     databricks_job["spark_python_task"] = {
+         "python_file": "flyteplugins/databricks/entrypoint.py",
+         "parameters": list(container.args),
+         "source": "GIT",
+     }
+     # https://github.com/flyteorg/flytetools/blob/master/flyteplugins/databricks/entrypoint.py
+     databricks_job["git_source"] = {
+         "git_url": "https://github.com/flyteorg/flytetools",
+         "git_provider": "gitHub",
+         "git_commit": "194364210c47c49ce66c419e8fb68d6f9c06fd7e",
+     }
+
+     logger.debug(f"databricks_job spec: {databricks_job}")
+     return databricks_job
+
+
+ class DatabricksConnector(AsyncConnector):
+     name: str = "Databricks Connector"
+     task_type_name: str = "databricks"
+     metadata_type: type = DatabricksJobMetadata
+
+     async def create(
+         self,
+         task_template: TaskTemplate,
+         inputs: Optional[typing.Dict[str, typing.Any]] = None,
+         databricks_token: Optional[str] = None,
+         **kwargs,
+     ) -> DatabricksJobMetadata:
+         data = json.dumps(_get_databricks_job_spec(task_template))
+         custom = MessageToDict(task_template.custom)
+         databricks_instance = custom.get("databricksInstance", os.getenv(DEFAULT_DATABRICKS_INSTANCE_ENV_KEY))
+
+         if not databricks_instance:
+             raise ValueError(
+                 f"Missing databricks instance. Please set the value through the task config or"
+                 f" set the {DEFAULT_DATABRICKS_INSTANCE_ENV_KEY} environment variable in the connector."
+             )
+
+         databricks_url = f"https://{databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/submit"
+
+         async with aiohttp.ClientSession() as session:
+             async with session.post(databricks_url, headers=get_header(databricks_token), data=data) as resp:
+                 response = await resp.json()
+                 if resp.status != http.HTTPStatus.OK:
+                     raise RuntimeError(f"Failed to create databricks job with error: {response}")
+
+         return DatabricksJobMetadata(databricks_instance=databricks_instance, run_id=str(response["run_id"]))
+
+     async def get(
+         self, resource_meta: DatabricksJobMetadata, databricks_token: Optional[str] = None, **kwargs
+     ) -> Resource:
+         databricks_instance = resource_meta.databricks_instance
+         databricks_url = (
+             f"https://{databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/get?run_id={resource_meta.run_id}"
+         )
+
+         async with aiohttp.ClientSession() as session:
+             async with session.get(databricks_url, headers=get_header(databricks_token)) as resp:
+                 if resp.status != http.HTTPStatus.OK:
+                     raise RuntimeError(f"Failed to get databricks job {resource_meta.run_id} with error: {resp.reason}")
+                 response = await resp.json()
+
+         cur_phase = TaskExecution.UNDEFINED
+         message = ""
+         state = response.get("state")
+
+         # The databricks job's state is determined by life_cycle_state and result_state.
+         # https://docs.databricks.com/en/workflows/jobs/jobs-2.0-api.html#runresultstate
+         if state:
+             life_cycle_state = state.get("life_cycle_state")
+             if result_state_is_available(life_cycle_state):
+                 result_state = state.get("result_state")
+                 cur_phase = convert_to_flyte_phase(result_state)
+             else:
+                 cur_phase = convert_to_flyte_phase(life_cycle_state)
+
+             message = state.get("state_message")
+
+         job_id = response.get("job_id")
+         databricks_console_url = f"https://{databricks_instance}/#job/{job_id}/run/{resource_meta.run_id}"
+         log_links = [TaskLog(uri=databricks_console_url, name="Databricks Console")]
+
+         return Resource(phase=cur_phase, message=message, log_links=log_links)
+
+     async def delete(self, resource_meta: DatabricksJobMetadata, databricks_token: Optional[str] = None, **kwargs):
+         databricks_url = f"https://{resource_meta.databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/cancel"
+         data = json.dumps({"run_id": resource_meta.run_id})
+
+         async with aiohttp.ClientSession() as session:
+             async with session.post(databricks_url, headers=get_header(databricks_token), data=data) as resp:
+                 if resp.status != http.HTTPStatus.OK:
+                     raise RuntimeError(
+                         f"Failed to cancel databricks job {resource_meta.run_id} with error: {resp.reason}"
+                     )
+                 await resp.json()
+
+
+ def get_header(token: Optional[str]) -> typing.Dict[str, str]:
+     return {"Authorization": f"Bearer {token}", "content-type": "application/json"}
+
+
+ def result_state_is_available(life_cycle_state: str) -> bool:
+     return life_cycle_state == "TERMINATED"
+
+
+ ConnectorRegistry.register(DatabricksConnector())
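As a reading aid (not part of the package contents): `_get_databricks_job_spec` merges the task's `databricksConf` with the task container details when no `existing_cluster_id` is given, then attaches a `spark_python_task` and a pinned `git_source` pointing at the flytetools entrypoint. A sketch of the resulting payload POSTed to `/api/2.1/jobs/runs/submit`, with illustrative values only:

```python
# Illustrative runs/submit payload assembled by _get_databricks_job_spec.
# Concrete values (image, args, env vars) come from the task template at runtime.
payload = {
    "run_name": "flyte databricks plugin",
    "new_cluster": {
        "spark_version": "13.3.x-scala2.12",
        "docker_image": {"url": "ghcr.io/example/flyte-task:latest"},  # filled from the task container image
        "spark_conf": {},                                              # filled from sparkConf if not already set
        "spark_env_vars": {"EXAMPLE_ENV": "value"},                    # task container env vars
    },
    "spark_python_task": {
        "python_file": "flyteplugins/databricks/entrypoint.py",
        "parameters": ["..."],  # the task container args, passed through unchanged
        "source": "GIT",
    },
    "git_source": {
        "git_url": "https://github.com/flyteorg/flytetools",
        "git_provider": "gitHub",
        "git_commit": "194364210c47c49ce66c419e8fb68d6f9c06fd7e",
    },
}
```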
flyteplugins/databricks/task.py
@@ -0,0 +1,67 @@
+ from dataclasses import dataclass
+ from typing import Any, Dict, Optional, Union
+
+ from flyte._task_plugins import TaskPluginRegistry
+ from flyte.connectors import AsyncConnectorExecutorMixin
+ from flyte.models import SerializationContext
+ from flyteidl2.plugins.spark_pb2 import SparkApplication, SparkJob
+ from google.protobuf.json_format import MessageToDict
+
+ from flyteplugins.spark import Spark
+ from flyteplugins.spark.task import PysparkFunctionTask
+
+
+ @dataclass
+ class Databricks(Spark):
+     """
+     Use this to configure a Databricks task. Tasks marked with this config automatically execute
+     natively on the Databricks platform as a distributed Spark job.
+
+     Args:
+         databricks_conf: Databricks job configuration compliant with API version 2.1, supporting 2.0 use cases.
+             For the configuration structure, see https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure
+             For updates in API 2.1, refer to https://docs.databricks.com/en/workflows/jobs/jobs-api-updates.html
+         databricks_instance: Domain name of your deployment. Use the form <account>.cloud.databricks.com.
+         databricks_token: The name of the secret containing the Databricks token used for authentication.
+     """
+
+     databricks_conf: Optional[Dict[str, Union[str, dict]]] = None
+     databricks_instance: Optional[str] = None
+     databricks_token: Optional[str] = None
+
+
+ class DatabricksFunctionTask(AsyncConnectorExecutorMixin, PysparkFunctionTask):
+     """
+     Actual plugin that transforms the local Python code for execution within a Spark context.
+     """
+
+     plugin_config: Databricks
+
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.task_type = "databricks"
+
+     def custom_config(self, sctx: SerializationContext) -> Dict[str, Any]:
+         driver_pod = self.plugin_config.driver_pod.to_k8s_pod() if self.plugin_config.driver_pod else None
+         executor_pod = self.plugin_config.executor_pod.to_k8s_pod() if self.plugin_config.executor_pod else None
+
+         job = SparkJob(
+             sparkConf=self.plugin_config.spark_conf,
+             hadoopConf=self.plugin_config.hadoop_conf,
+             mainApplicationFile=self.plugin_config.applications_path or "local://" + sctx.get_entrypoint_path(),
+             executorPath=self.plugin_config.executor_path or sctx.interpreter_path,
+             mainClass="",
+             applicationType=SparkApplication.PYTHON,
+             driverPod=driver_pod,
+             executorPod=executor_pod,
+             databricksConf=self.plugin_config.databricks_conf,
+             databricksInstance=self.plugin_config.databricks_instance,
+         )
+
+         cfg = MessageToDict(job)
+         cfg["secrets"] = {"databricks_token": self.plugin_config.databricks_token}
+
+         return cfg
+
+
+ TaskPluginRegistry.register(Databricks, DatabricksFunctionTask)
flyteplugins_databricks-2.0.0b54.dist-info/METADATA
@@ -0,0 +1,57 @@
+ Metadata-Version: 2.4
+ Name: flyteplugins-databricks
+ Version: 2.0.0b54
+ Summary: Databricks plugin for flyte
+ Author-email: Kevin Su <pingsutw@users.noreply.github.com>
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ Requires-Dist: flyte[connector]
+ Requires-Dist: aiohttp
+ Requires-Dist: nest-asyncio
+ Requires-Dist: flyteplugins-spark
+
+ # Databricks Plugin for Flyte
+
+ This plugin provides Databricks integration for Flyte, enabling you to run Spark jobs on Databricks as Flyte tasks.
+
+ ## Installation
+
+ ```bash
+ pip install flyteplugins-databricks
+ ```
+
+ ## Usage
+
+ ```python
+ from flyteplugins.databricks import Databricks, DatabricksConnector
+
+ @task(task_config=Databricks(
+     databricks_conf={
+         "run_name": "flyte databricks plugin",
+         "new_cluster": {
+             "spark_version": "13.3.x-scala2.12",
+             "autoscale": {
+                 "min_workers": 1,
+                 "max_workers": 1,
+             },
+             "node_type_id": "m6i.large",
+             "num_workers": 1,
+             "aws_attributes": {
+                 "availability": "SPOT_WITH_FALLBACK",
+                 "instance_profile_arn": "arn:aws:iam::339713193121:instance-profile/databricks-demo",
+                 "ebs_volume_type": "GENERAL_PURPOSE_SSD",
+                 "ebs_volume_count": 1,
+                 "ebs_volume_size": 100,
+                 "first_on_demand": 1,
+             },
+         },
+         # "existing_cluster_id": "1113-204018-tb9vr2fm", # use existing cluster id if you want
+         "timeout_seconds": 3600,
+         "max_retries": 1,
+     },
+     databricks_instance="mycompany.cloud.databricks.com",
+ ))
+ def my_spark_task() -> int:
+     # Your Spark code here
+     return 42
+ ```
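As a reading aid (not part of the package contents): per the `Databricks` config and the connector above, `databricks_token` takes the name of the secret that holds your Databricks API token, and if `databricks_instance` is not set in the task config the connector falls back to the `FLYTE_DATABRICKS_INSTANCE` environment variable set where it runs. A minimal sketch of the config alone; the secret name below is illustrative, not something the plugin defines:

```python
from flyteplugins.databricks import Databricks

# Cluster id is the example value from the README above; "databricks-token" is a
# hypothetical secret name that holds the Databricks personal access token.
config = Databricks(
    databricks_conf={
        "run_name": "flyte databricks plugin",
        "existing_cluster_id": "1113-204018-tb9vr2fm",  # reuse an existing cluster
        "timeout_seconds": 3600,
    },
    databricks_instance="mycompany.cloud.databricks.com",
    databricks_token="databricks-token",
)
```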
flyteplugins_databricks-2.0.0b54.dist-info/RECORD
@@ -0,0 +1,9 @@
+ flyteplugins/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ flyteplugins/databricks/__init__.py,sha256=FfSNiu0JuKflhA0_i0VrNuwE_SX08lY4CBH1nY8AyD0,167
+ flyteplugins/databricks/connector.py,sha256=ysg8ALL0Bmnmi32BszOTlpS0-kQWz843sX8qIwqXiaA,6576
+ flyteplugins/databricks/task.py,sha256=K5xO5ezsiL4XxVT6XtYNI6K0K_VjfMk2-bjRTw6ceE4,2864
+ flyteplugins_databricks-2.0.0b54.dist-info/METADATA,sha256=q83muoNogwa-1yg1ya_CHqNo1E9U-orQY8yc1BgXkPY,1689
+ flyteplugins_databricks-2.0.0b54.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+ flyteplugins_databricks-2.0.0b54.dist-info/entry_points.txt,sha256=TrAGjurydxqKpeLlNdhOv0Jc8WrmaQ6XIANTnWF9LCE,86
+ flyteplugins_databricks-2.0.0b54.dist-info/top_level.txt,sha256=cgd779rPu9EsvdtuYgUxNHHgElaQvPn74KhB5XSeMBE,13
+ flyteplugins_databricks-2.0.0b54.dist-info/RECORD,,
flyteplugins_databricks-2.0.0b54.dist-info/WHEEL
@@ -0,0 +1,5 @@
+ Wheel-Version: 1.0
+ Generator: setuptools (80.10.2)
+ Root-Is-Purelib: true
+ Tag: py3-none-any
+
flyteplugins_databricks-2.0.0b54.dist-info/entry_points.txt
@@ -0,0 +1,2 @@
+ [flyte.connectors]
+ databricks = flyteplugins.databricks.connector:DatabricksConnector
flyteplugins_databricks-2.0.0b54.dist-info/top_level.txt
@@ -0,0 +1 @@
+ flyteplugins