jupyter-databricks-kernel 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jupyter_databricks_kernel/__init__.py +3 -0
- jupyter_databricks_kernel/__main__.py +8 -0
- jupyter_databricks_kernel/config.py +134 -0
- jupyter_databricks_kernel/executor.py +244 -0
- jupyter_databricks_kernel/kernel.py +306 -0
- jupyter_databricks_kernel/sync.py +761 -0
- jupyter_databricks_kernel-0.1.0.dist-info/METADATA +197 -0
- jupyter_databricks_kernel-0.1.0.dist-info/RECORD +10 -0
- jupyter_databricks_kernel-0.1.0.dist-info/WHEEL +4 -0
- jupyter_databricks_kernel-0.1.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
"""Configuration management for Databricks kernel."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
import sys
|
|
7
|
+
import tomllib
|
|
8
|
+
from dataclasses import dataclass, field
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
|
|
13
|
+
class SyncConfig:
|
|
14
|
+
"""Configuration for file synchronization.
|
|
15
|
+
|
|
16
|
+
The sync module applies default exclusion patterns automatically.
|
|
17
|
+
When use_gitignore is True, .gitignore rules are also applied.
|
|
18
|
+
User-specified exclude patterns are applied in addition to those defaults.
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
enabled: bool = True
|
|
22
|
+
source: str = "."
|
|
23
|
+
exclude: list[str] = field(default_factory=list)
|
|
24
|
+
max_size_mb: float | None = None
|
|
25
|
+
max_file_size_mb: float | None = None
|
|
26
|
+
use_gitignore: bool = True
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
|
|
30
|
+
class Config:
|
|
31
|
+
"""Main configuration for the Databricks kernel."""
|
|
32
|
+
|
|
33
|
+
cluster_id: str | None = None
|
|
34
|
+
sync: SyncConfig = field(default_factory=SyncConfig)
|
|
35
|
+
|
|
36
|
+
@classmethod
|
|
37
|
+
def load(cls, config_path: Path | None = None) -> Config:
|
|
38
|
+
"""Load configuration from environment variables and pyproject.toml.
|
|
39
|
+
|
|
40
|
+
Priority order:
|
|
41
|
+
1. Environment variables (highest priority)
|
|
42
|
+
2. pyproject.toml [tool.jupyter-databricks-kernel]
|
|
43
|
+
3. Default values
|
|
44
|
+
|
|
45
|
+
Args:
|
|
46
|
+
config_path: Optional path to the config file.
|
|
47
|
+
Defaults to pyproject.toml in current directory.
|
|
48
|
+
|
|
49
|
+
Returns:
|
|
50
|
+
Loaded configuration.
|
|
51
|
+
"""
|
|
52
|
+
config = cls()
|
|
53
|
+
|
|
54
|
+
# Load from environment variables (highest priority)
|
|
55
|
+
config.cluster_id = os.environ.get("DATABRICKS_CLUSTER_ID")
|
|
56
|
+
|
|
57
|
+
# Determine config file path
|
|
58
|
+
if config_path is None:
|
|
59
|
+
config_path = Path.cwd() / "pyproject.toml"
|
|
60
|
+
|
|
61
|
+
# Load from config file if it exists
|
|
62
|
+
if config_path.exists():
|
|
63
|
+
config._load_from_pyproject(config_path)
|
|
64
|
+
|
|
65
|
+
return config
|
|
66
|
+
|
|
67
|
+
def _load_from_pyproject(self, config_path: Path) -> None:
|
|
68
|
+
"""Load configuration from pyproject.toml.
|
|
69
|
+
|
|
70
|
+
Args:
|
|
71
|
+
config_path: Path to pyproject.toml.
|
|
72
|
+
"""
|
|
73
|
+
try:
|
|
74
|
+
with open(config_path, "rb") as f:
|
|
75
|
+
data = tomllib.load(f)
|
|
76
|
+
except tomllib.TOMLDecodeError as e:
|
|
77
|
+
print(
|
|
78
|
+
f"Warning: Failed to parse {config_path}: {e}. "
|
|
79
|
+
"Using default configuration.",
|
|
80
|
+
file=sys.stderr,
|
|
81
|
+
)
|
|
82
|
+
return
|
|
83
|
+
|
|
84
|
+
# Get [tool.jupyter-databricks-kernel] section
|
|
85
|
+
tool_config = data.get("tool", {}).get("jupyter-databricks-kernel", {})
|
|
86
|
+
if not tool_config:
|
|
87
|
+
return
|
|
88
|
+
|
|
89
|
+
# Override cluster_id if specified in file (but env var has priority)
|
|
90
|
+
if "cluster_id" in tool_config and self.cluster_id is None:
|
|
91
|
+
self.cluster_id = tool_config["cluster_id"]
|
|
92
|
+
|
|
93
|
+
# Load sync configuration
|
|
94
|
+
if "sync" in tool_config:
|
|
95
|
+
sync_data = tool_config["sync"]
|
|
96
|
+
if "enabled" in sync_data:
|
|
97
|
+
self.sync.enabled = sync_data["enabled"]
|
|
98
|
+
if "source" in sync_data:
|
|
99
|
+
self.sync.source = sync_data["source"]
|
|
100
|
+
if "exclude" in sync_data:
|
|
101
|
+
self.sync.exclude = sync_data["exclude"]
|
|
102
|
+
if "max_size_mb" in sync_data:
|
|
103
|
+
self.sync.max_size_mb = sync_data["max_size_mb"]
|
|
104
|
+
if "max_file_size_mb" in sync_data:
|
|
105
|
+
self.sync.max_file_size_mb = sync_data["max_file_size_mb"]
|
|
106
|
+
if "use_gitignore" in sync_data:
|
|
107
|
+
self.sync.use_gitignore = sync_data["use_gitignore"]
|
|
108
|
+
|
|
109
|
+
def validate(self) -> list[str]:
|
|
110
|
+
"""Validate the configuration.
|
|
111
|
+
|
|
112
|
+
Note: Authentication is handled by the Databricks SDK, which
|
|
113
|
+
automatically resolves credentials from environment variables,
|
|
114
|
+
CLI config, or cloud provider authentication.
|
|
115
|
+
|
|
116
|
+
Returns:
|
|
117
|
+
List of validation error messages. Empty if valid.
|
|
118
|
+
"""
|
|
119
|
+
errors: list[str] = []
|
|
120
|
+
|
|
121
|
+
if not self.cluster_id:
|
|
122
|
+
errors.append(
|
|
123
|
+
"DATABRICKS_CLUSTER_ID environment variable is not set. "
|
|
124
|
+
"Please set it to your Databricks cluster ID."
|
|
125
|
+
)
|
|
126
|
+
|
|
127
|
+
# Validate sync size limits
|
|
128
|
+
if self.sync.max_size_mb is not None and self.sync.max_size_mb <= 0:
|
|
129
|
+
errors.append("max_size_mb must be a positive number.")
|
|
130
|
+
|
|
131
|
+
if self.sync.max_file_size_mb is not None and self.sync.max_file_size_mb <= 0:
|
|
132
|
+
errors.append("max_file_size_mb must be a positive number.")
|
|
133
|
+
|
|
134
|
+
return errors
|
|
@@ -0,0 +1,244 @@
|
|
|
1
|
+
"""Databricks execution context management."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import re
|
|
7
|
+
import time
|
|
8
|
+
from dataclasses import dataclass, replace
|
|
9
|
+
from datetime import timedelta
|
|
10
|
+
from typing import TYPE_CHECKING
|
|
11
|
+
|
|
12
|
+
from databricks.sdk import WorkspaceClient
|
|
13
|
+
from databricks.sdk.service import compute
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from .config import Config
|
|
17
|
+
|
|
18
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# --- Retry and timeout tuning ------------------------------------------------
# Pause before a reconnection attempt, in seconds.
RECONNECT_DELAY_SECONDS = 1.0
# Upper bound on waiting for an execution context to be created.
CONTEXT_CREATION_TIMEOUT = timedelta(minutes=5)
# Upper bound on waiting for a single command to finish.
COMMAND_EXECUTION_TIMEOUT = timedelta(minutes=10)

# --- Context error detection -------------------------------------------------
# Alternatives that identify an error message as being about the execution
# context itself. They are OR-ed together into one case-insensitive pattern.
_CONTEXT_ERROR_ALTERNATIVES = (
    r"context\s*(not\s*found|does\s*not\s*exist|is\s*invalid|expired)",
    r"invalid\s*context",
    r"\bcontext_id\b",
    r"execution\s*context",
)
CONTEXT_ERROR_PATTERN = re.compile(
    "|".join(_CONTEXT_ERROR_ALTERNATIVES),
    re.IGNORECASE,
)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass
class ExecutionResult:
    """Result of a command execution.

    Attributes:
        status: "ok" or "error" when a result was parsed; otherwise the
            stringified remote command status (or "unknown").
        output: Text output of the command, if any.
        error: Error message when the command or the API call failed.
        traceback: Error summary split into lines, when available.
        reconnected: True when the result was obtained after the execution
            context had to be recreated.
    """

    status: str
    output: str | None = None
    error: str | None = None
    traceback: list[str] | None = None
    reconnected: bool = False


class DatabricksExecutor:
    """Manages Databricks execution context and command execution."""

    def __init__(self, config: Config) -> None:
        """Initialize the executor.

        No network activity happens here; the client and the execution
        context are created lazily on first use.

        Args:
            config: Kernel configuration.
        """
        self.config = config
        self.client: WorkspaceClient | None = None
        self.context_id: str | None = None

    def _ensure_client(self) -> WorkspaceClient:
        """Ensure the WorkspaceClient is initialized.

        Returns:
            The WorkspaceClient instance.
        """
        if self.client is None:
            self.client = WorkspaceClient()
        return self.client

    def create_context(self) -> None:
        """Create an execution context on the Databricks cluster.

        Does nothing when a context already exists. If the API response
        carries no id, context_id stays None and the caller must handle it.

        Raises:
            ValueError: If no cluster ID is configured.
        """
        if self.context_id is not None:
            return  # Context already exists

        if not self.config.cluster_id:
            raise ValueError("Cluster ID is not configured")

        client = self._ensure_client()
        response = client.command_execution.create(
            cluster_id=self.config.cluster_id,
            language=compute.Language.PYTHON,
        ).result(timeout=CONTEXT_CREATION_TIMEOUT)

        if response and response.id:
            self.context_id = response.id

    def reconnect(self) -> None:
        """Recreate the execution context.

        Destroys the old context (if any) and creates a new one.
        Used when the existing context becomes invalid.
        """
        logger.info("Reconnecting: creating new execution context")
        # Try to destroy old context to avoid resource leak on cluster
        # Ignore errors since context may already be invalid
        try:
            self.destroy_context()
        except Exception as e:
            logger.debug("Failed to destroy old context: %s", e)
        self.context_id = None
        self.create_context()

    def _is_context_invalid_error(self, error: Exception) -> bool:
        """Check if an error indicates the context is invalid.

        Only matches errors that specifically relate to execution context,
        not general errors like "File not found" or "Variable not found".

        Args:
            error: The exception to check.

        Returns:
            True if the error indicates context invalidation.
        """
        error_str = str(error)

        # Must contain "context" to be considered a context error (case-insensitive)
        if "context" not in error_str.lower():
            return False

        # Use pre-compiled pattern for efficient matching
        return CONTEXT_ERROR_PATTERN.search(error_str) is not None

    def execute(self, code: str, *, allow_reconnect: bool = True) -> ExecutionResult:
        """Execute code on the Databricks cluster.

        Failures are reported through the returned result rather than
        raised: a missing cluster ID, a failed context creation, and API
        errors during execution all yield a result with status "error".

        Args:
            code: The Python code to execute.
            allow_reconnect: If True, attempt to reconnect on context errors.

        Returns:
            Execution result containing output or error.
        """
        # Check the cluster ID first: create_context() raises ValueError
        # without one, which would pre-empt this friendlier error result.
        if not self.config.cluster_id:
            return ExecutionResult(
                status="error",
                error="Cluster ID is not configured",
            )

        if self.context_id is None:
            try:
                self.create_context()
            except Exception as e:
                logger.error("Context creation failed: %s", e)
                return ExecutionResult(
                    status="error",
                    error=f"Failed to create execution context: {e}",
                )

        # create_context() leaves context_id unset when the API response
        # carried no id.
        if self.context_id is None:
            return ExecutionResult(
                status="error",
                error="Failed to create execution context",
            )

        try:
            return self._execute_internal(code)
        except Exception as e:
            if not (allow_reconnect and self._is_context_invalid_error(e)):
                logger.error("Execution failed: %s", e)
                return ExecutionResult(
                    status="error",
                    error=str(e),
                )

            logger.warning("Context invalid, attempting reconnection: %s", e)
            try:
                # Wait before reconnection to avoid hammering the API
                time.sleep(RECONNECT_DELAY_SECONDS)
                self.reconnect()
                result = self._execute_internal(code)
                return replace(result, reconnected=True)
            except Exception as retry_error:
                logger.error("Reconnection failed: %s", retry_error)
                return ExecutionResult(
                    status="error",
                    error=f"Reconnection failed: {retry_error}",
                )

    def _execute_internal(self, code: str) -> ExecutionResult:
        """Internal execution without reconnection logic.

        Args:
            code: The Python code to execute.

        Returns:
            Execution result containing output or error.

        Raises:
            Exception: If execution fails due to API errors.
        """
        client = self._ensure_client()
        response = client.command_execution.execute(
            cluster_id=self.config.cluster_id,
            context_id=self.context_id,
            language=compute.Language.PYTHON,
            command=code,
        ).result(timeout=COMMAND_EXECUTION_TIMEOUT)

        if response is None:
            return ExecutionResult(
                status="error",
                error="No response from Databricks",
            )

        # Parse the response
        status = str(response.status) if response.status else "unknown"

        # Without results there is nothing to parse: pass the raw status on.
        if not response.results:
            return ExecutionResult(status=status)

        results = response.results

        # A populated cause marks the command as failed.
        if results.cause:
            return ExecutionResult(
                status="error",
                error=results.cause,
                traceback=results.summary.split("\n") if results.summary else None,
            )

        # Prefer the data payload; fall back to the summary text.
        output = None
        if results.data is not None:
            output = str(results.data)
        elif results.summary:
            output = results.summary

        return ExecutionResult(
            status="ok",
            output=output,
        )

    def destroy_context(self) -> None:
        """Destroy the execution context.

        Does nothing when there is no context or no cluster ID. The local
        context_id is cleared even if the API call raises.
        """
        if self.context_id is None:
            return

        if not self.config.cluster_id:
            return

        try:
            client = self._ensure_client()
            client.command_execution.destroy(
                cluster_id=self.config.cluster_id,
                context_id=self.context_id,
            )
        finally:
            self.context_id = None
|
|
@@ -0,0 +1,306 @@
|
|
|
1
|
+
"""Databricks Session Kernel for Jupyter."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import uuid
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from ipykernel.kernelbase import Kernel
|
|
9
|
+
|
|
10
|
+
from . import __version__
|
|
11
|
+
from .config import Config
|
|
12
|
+
from .executor import DatabricksExecutor
|
|
13
|
+
from .sync import FileSync
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class DatabricksKernel(Kernel):
    """Jupyter kernel that executes code on a remote Databricks cluster."""

    implementation = "databricks-session-kernel"
    implementation_version = __version__
    language = "python"
    language_version = "3.11"
    language_info = {
        "name": "python",
        "mimetype": "text/x-python",
        "file_extension": ".py",
    }
    banner = "Databricks Session Kernel - Execute Python on Databricks clusters"

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the Databricks kernel.

        Loads the configuration and generates a short session id. The
        executor and file sync are created lazily on first execution.
        """
        super().__init__(**kwargs)
        self._kernel_config = Config.load()
        self._session_id = str(uuid.uuid4())[:8]
        self.executor: DatabricksExecutor | None = None
        self.file_sync: FileSync | None = None
        self._initialized = False
        # Remote path of the last successful upload; used to re-run the
        # setup code after a reconnection.
        self._last_dbfs_path: str | None = None

    def _emit_stderr(self, text: str) -> None:
        """Send text to the frontend as a stderr stream message."""
        self.send_response(
            self.iopub_socket,
            "stream",
            {"name": "stderr", "text": text},
        )

    def _report_error(
        self, ename: str, evalue: str, traceback: list[str], silent: Any
    ) -> dict[str, Any]:
        """Publish an error (unless silent) and build the error reply.

        Args:
            ename: Error name shown to the user.
            evalue: Error message.
            traceback: Traceback lines.
            silent: When truthy, nothing is published on iopub.

        Returns:
            The execute_reply content for a failed execution.
        """
        if not silent:
            self.send_response(
                self.iopub_socket,
                "error",
                {
                    "ename": ename,
                    "evalue": evalue,
                    "traceback": traceback,
                },
            )
        return {
            "status": "error",
            "execution_count": self.execution_count,
            "ename": ename,
            "evalue": evalue,
            "traceback": traceback,
        }

    def _initialize(self) -> bool:
        """Initialize the Databricks connection.

        Validation errors are reported to the user on stderr.

        Returns:
            True if initialization succeeded, False otherwise.
        """
        if self._initialized:
            return True

        # Validate configuration
        errors = self._kernel_config.validate()
        if errors:
            for error in errors:
                self._emit_stderr(f"Configuration error: {error}\n")
            return False

        # Initialize executor and file sync (reuse existing if available)
        if self.executor is None:
            self.executor = DatabricksExecutor(self._kernel_config)
        if self.file_sync is None:
            self.file_sync = FileSync(self._kernel_config, self._session_id)
        self._initialized = True
        return True

    def _sync_files(self) -> bool:
        """Synchronize files if needed.

        Returns:
            True if sync succeeded, was not needed, or raised (an exception
            is reported to the user but deliberately does not block
            execution). False only when the remote setup code did not
            finish with status "ok".
        """
        if self.file_sync is None or self.executor is None:
            return True

        if not self.file_sync.needs_sync():
            return True

        try:
            self._emit_stderr("Syncing files to Databricks...\n")

            # Upload files
            stats = self.file_sync.sync()
            self._last_dbfs_path = stats.dbfs_path

            # Execute setup code on remote
            setup_code = self.file_sync.get_setup_code(stats.dbfs_path)
            result = self.executor.execute(setup_code, allow_reconnect=False)

            if result.status != "ok":
                self._emit_stderr(f"Sync setup failed: {result.error}\n")
                return False

            self._emit_stderr("Files synced successfully.\n")
            return True

        except Exception as e:
            self._emit_stderr(f"Sync failed: {e}\n")
            # Continue execution even if sync fails
            return True

    def _handle_reconnection(self) -> None:
        """Handle session reconnection.

        Re-runs the setup code to restore sys.path and notifies the user.
        Failures are reported as warnings and never raised.
        """
        # NOTE(review): this runs after the executor has already re-run the
        # user's code on the new context, so that run did not have the
        # restored sys.path. Confirm this ordering is acceptable.
        self._emit_stderr("Session reconnected. Variables have been reset.\n")

        # Re-run setup code only if we have synced files before
        if not (self.file_sync and self._last_dbfs_path and self.executor):
            return

        try:
            setup_code = self.file_sync.get_setup_code(self._last_dbfs_path)
            result = self.executor.execute(setup_code, allow_reconnect=False)
            if result.status != "ok":
                err = result.error
                self._emit_stderr(f"Warning: Failed to restore sys.path: {err}\n")
        except Exception as e:
            # Notify user but don't fail the main execution
            self._emit_stderr(f"Warning: Failed to restore sys.path: {e}\n")

    async def do_execute(
        self,
        code: Any,
        silent: Any,
        store_history: Any = True,
        user_expressions: Any = None,
        allow_stdin: Any = False,
        *,
        cell_meta: Any = None,
        cell_id: Any = None,
    ) -> dict[str, Any]:
        """Execute code on the Databricks cluster.

        Args:
            code: The code to execute.
            silent: Whether to suppress output.
            store_history: Whether to store the code in history.
            user_expressions: User expressions to evaluate.
            allow_stdin: Whether to allow stdin.
            cell_meta: The cell metadata (unused here).
            cell_id: The cell ID.

        Returns:
            Execution result dictionary.
        """
        # Skip empty code
        code_str = str(code).strip()
        if not code_str:
            return {
                "status": "ok",
                "execution_count": self.execution_count,
                "payload": [],
                "user_expressions": {},
            }

        # Initialize on first execution. The individual configuration
        # errors were already streamed to stderr by _initialize().
        if not self._initialize():
            return {
                "status": "error",
                "execution_count": self.execution_count,
                "ename": "ConfigurationError",
                "evalue": "Failed to initialize Databricks connection",
                "traceback": [],
            }

        # Sync files before execution; a failed sync is reported to the
        # user but does not prevent the code from running.
        self._sync_files()

        # Execute on Databricks
        assert self.executor is not None
        try:
            result = self.executor.execute(code_str)

            # Handle reconnection: re-run setup code and notify user
            if result.reconnected:
                self._handle_reconnection()

            if result.status != "ok":
                return self._report_error(
                    "ExecutionError",
                    result.error or "Unknown error",
                    result.traceback or [],
                    silent,
                )

            if not silent and result.output:
                self.send_response(
                    self.iopub_socket,
                    "stream",
                    {"name": "stdout", "text": result.output},
                )
            return {
                "status": "ok",
                "execution_count": self.execution_count,
                "payload": [],
                "user_expressions": {},
            }

        except Exception as e:
            error_msg = str(e)
            return self._report_error(
                type(e).__name__, error_msg, [error_msg], silent
            )

    async def do_shutdown(self, restart: bool) -> dict[str, Any]:
        """Shutdown the kernel.

        Cleanup is best-effort: failures are logged and never prevent the
        shutdown reply.

        Args:
            restart: Whether this is a restart.

        Returns:
            Shutdown result dictionary.
        """
        if restart:
            # On restart, keep the execution context alive for session continuity
            # Only reset the initialized flag so we can re-initialize on next execute
            self._initialized = False
            return {"status": "ok", "restart": restart}

        # Full shutdown: clean up everything
        # Clean up file sync
        if self.file_sync:
            try:
                self.file_sync.cleanup()
            except Exception as e:
                self.log.warning("File sync cleanup failed during shutdown: %s", e)
            self.file_sync = None

        # Destroy execution context
        if self.executor:
            try:
                self.executor.destroy_context()
            except Exception as e:
                self.log.warning(
                    "Failed to destroy execution context during shutdown: %s", e
                )
            self.executor = None

        self._initialized = False
        self._last_dbfs_path = None
        return {"status": "ok", "restart": restart}
|