jupyter-databricks-kernel 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,3 @@
1
+ """Jupyter kernel for Databricks remote execution."""
2
+
3
# Package version string, exposed as the package-level ``__version__`` attribute.
__version__ = "0.1.0"
@@ -0,0 +1,8 @@
1
+ """Entry point for the Databricks kernel."""
2
+
3
+ from ipykernel.kernelapp import IPKernelApp
4
+
5
+ from .kernel import DatabricksKernel
6
+
7
# Start the kernel application only when this module is executed as the main
# module; a plain import does not launch anything.
if __name__ == "__main__":
    # Hand the custom kernel class to ipykernel's application launcher.
    IPKernelApp.launch_instance(kernel_class=DatabricksKernel)
@@ -0,0 +1,134 @@
1
+ """Configuration management for Databricks kernel."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ import sys
7
+ import tomllib
8
+ from dataclasses import dataclass, field
9
+ from pathlib import Path
10
+
11
+
12
+ @dataclass
13
+ class SyncConfig:
14
+ """Configuration for file synchronization.
15
+
16
+ The sync module applies default exclusion patterns automatically.
17
+ When use_gitignore is True, .gitignore rules are also applied.
18
+ User-specified exclude patterns are applied in addition to those defaults.
19
+ """
20
+
21
+ enabled: bool = True
22
+ source: str = "."
23
+ exclude: list[str] = field(default_factory=list)
24
+ max_size_mb: float | None = None
25
+ max_file_size_mb: float | None = None
26
+ use_gitignore: bool = True
27
+
28
+
29
@dataclass
class Config:
    """Main configuration for the Databricks kernel.

    Values are resolved from the ``DATABRICKS_CLUSTER_ID`` environment
    variable first, then from the ``[tool.jupyter-databricks-kernel]`` table
    of ``pyproject.toml``, and finally from the dataclass defaults.
    """

    cluster_id: str | None = None
    sync: SyncConfig = field(default_factory=SyncConfig)

    # Keys of the ``sync`` table that map one-to-one onto SyncConfig fields.
    _SYNC_KEYS = (
        "enabled",
        "source",
        "exclude",
        "max_size_mb",
        "max_file_size_mb",
        "use_gitignore",
    )

    @classmethod
    def load(cls, config_path: Path | None = None) -> Config:
        """Load configuration from environment variables and pyproject.toml.

        Priority order:
        1. Environment variables (highest priority)
        2. pyproject.toml [tool.jupyter-databricks-kernel]
        3. Default values

        Args:
            config_path: Optional path to the config file.
                Defaults to pyproject.toml in current directory.

        Returns:
            Loaded configuration.
        """
        config = cls()

        # An empty variable is treated as unset; otherwise "" would be
        # stored, which is not None and would block the file fallback.
        config.cluster_id = os.environ.get("DATABRICKS_CLUSTER_ID") or None

        if config_path is None:
            config_path = Path.cwd() / "pyproject.toml"

        # is_file() (rather than exists()) skips a directory of that name.
        if config_path.is_file():
            config._load_from_pyproject(config_path)

        return config

    @staticmethod
    def _warn(message: str) -> None:
        """Write a configuration warning to stderr."""
        print(message, file=sys.stderr)

    def _load_from_pyproject(self, config_path: Path) -> None:
        """Load configuration from pyproject.toml.

        Problems with the file (unreadable, invalid TOML, wrongly typed
        tables) are reported on stderr and leave the current values in
        place; they never raise.

        Args:
            config_path: Path to pyproject.toml.
        """
        try:
            with open(config_path, "rb") as f:
                data = tomllib.load(f)
        except tomllib.TOMLDecodeError as e:
            self._warn(
                f"Warning: Failed to parse {config_path}: {e}. "
                "Using default configuration."
            )
            return
        except OSError as e:
            # e.g. permission denied or the file vanished after the check.
            self._warn(
                f"Warning: Failed to read {config_path}: {e}. "
                "Using default configuration."
            )
            return

        # Get [tool.jupyter-databricks-kernel] section, tolerating a
        # ``tool`` key that is not a table.
        tool_section = data.get("tool", {})
        if not isinstance(tool_section, dict):
            return
        tool_config = tool_section.get("jupyter-databricks-kernel", {})
        if not tool_config:
            return
        if not isinstance(tool_config, dict):
            self._warn(
                f"Warning: [tool.jupyter-databricks-kernel] in {config_path} "
                "is not a table. Using default configuration."
            )
            return

        # The file only supplies cluster_id when the environment did not.
        if "cluster_id" in tool_config and not self.cluster_id:
            self.cluster_id = tool_config["cluster_id"]

        # Load sync configuration
        if "sync" in tool_config:
            sync_data = tool_config["sync"]
            if not isinstance(sync_data, dict):
                self._warn(
                    f"Warning: [tool.jupyter-databricks-kernel] sync in "
                    f"{config_path} is not a table. Ignoring it."
                )
                return
            for key in self._SYNC_KEYS:
                if key in sync_data:
                    setattr(self.sync, key, sync_data[key])

    def validate(self) -> list[str]:
        """Validate the configuration.

        Note: Authentication is handled by the Databricks SDK, which
        automatically resolves credentials from environment variables,
        CLI config, or cloud provider authentication.

        Returns:
            List of validation error messages. Empty if valid.
        """
        errors: list[str] = []

        if not self.cluster_id:
            errors.append(
                "DATABRICKS_CLUSTER_ID environment variable is not set. "
                "Please set it to your Databricks cluster ID."
            )

        # Validate sync size limits. A non-numeric value (for example a
        # TOML string) is reported as an error instead of raising TypeError
        # from the comparison.
        for name in ("max_size_mb", "max_file_size_mb"):
            limit = getattr(self.sync, name)
            if limit is None:
                continue
            if not isinstance(limit, (int, float)) or limit <= 0:
                errors.append(f"{name} must be a positive number.")

        return errors
@@ -0,0 +1,244 @@
1
+ """Databricks execution context management."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import re
7
+ import time
8
+ from dataclasses import dataclass, replace
9
+ from datetime import timedelta
10
+ from typing import TYPE_CHECKING
11
+
12
+ from databricks.sdk import WorkspaceClient
13
+ from databricks.sdk.service import compute
14
+
15
+ if TYPE_CHECKING:
16
+ from .config import Config
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
# Retry and timeout configuration
# The delay is a number of seconds; the two timeouts are timedelta objects.
RECONNECT_DELAY_SECONDS = 1.0  # Delay before reconnection attempt
CONTEXT_CREATION_TIMEOUT = timedelta(minutes=5)  # Timeout for context creation
COMMAND_EXECUTION_TIMEOUT = timedelta(minutes=10)  # Timeout for command execution

# Pre-compiled pattern for context error detection
# Matches errors that specifically relate to execution context invalidation
# The alternatives, all matched case-insensitively with optional whitespace
# between words, are:
#   - "context" followed by "not found", "does not exist", "is invalid"
#     or "expired"
#   - "invalid context"
#   - the whole word "context_id"
#   - "execution context"
CONTEXT_ERROR_PATTERN = re.compile(
    r"context\s*(not\s*found|does\s*not\s*exist|is\s*invalid|expired)|"
    r"invalid\s*context|"
    r"\bcontext_id\b|"
    r"execution\s*context",
    re.IGNORECASE,
)
34
+
35
+
36
+ @dataclass
37
+ class ExecutionResult:
38
+ """Result of a command execution."""
39
+
40
+ status: str
41
+ output: str | None = None
42
+ error: str | None = None
43
+ traceback: list[str] | None = None
44
+ reconnected: bool = False
45
+
46
+
47
class DatabricksExecutor:
    """Manages Databricks execution context and command execution.

    Both the WorkspaceClient and the remote execution context are created
    lazily, the first time they are needed.
    """

    def __init__(self, config: Config) -> None:
        """Initialize the executor.

        Args:
            config: Kernel configuration.
        """
        self.config = config
        self.client: WorkspaceClient | None = None
        self.context_id: str | None = None

    def _ensure_client(self) -> WorkspaceClient:
        """Ensure the WorkspaceClient is initialized.

        Returns:
            The WorkspaceClient instance.
        """
        if self.client is None:
            self.client = WorkspaceClient()
        return self.client

    def create_context(self) -> None:
        """Create an execution context on the Databricks cluster.

        Does nothing when a context already exists. If the API response
        carries no id, ``context_id`` is left as None.

        Raises:
            ValueError: If no cluster ID is configured.
        """
        if self.context_id is not None:
            return  # Context already exists

        if not self.config.cluster_id:
            raise ValueError("Cluster ID is not configured")

        client = self._ensure_client()
        response = client.command_execution.create(
            cluster_id=self.config.cluster_id,
            language=compute.Language.PYTHON,
        ).result(timeout=CONTEXT_CREATION_TIMEOUT)

        if response and response.id:
            self.context_id = response.id

    def reconnect(self) -> None:
        """Recreate the execution context.

        Destroys the old context (if any) and creates a new one.
        Used when the existing context becomes invalid.
        """
        logger.info("Reconnecting: creating new execution context")
        # Try to destroy old context to avoid resource leak on cluster
        # Ignore errors since context may already be invalid
        try:
            self.destroy_context()
        except Exception as e:
            logger.debug("Failed to destroy old context: %s", e)
        self.context_id = None
        self.create_context()

    def _is_context_invalid_error(self, error: Exception) -> bool:
        """Check if an error indicates the context is invalid.

        Only matches errors that specifically relate to execution context,
        not general errors like "File not found" or "Variable not found".

        Args:
            error: The exception to check.

        Returns:
            True if the error indicates context invalidation.
        """
        error_str = str(error)

        # Must contain "context" to be considered a context error (case-insensitive)
        if "context" not in error_str.lower():
            return False

        # Use pre-compiled pattern for efficient matching
        return CONTEXT_ERROR_PATTERN.search(error_str) is not None

    def execute(self, code: str, *, allow_reconnect: bool = True) -> ExecutionResult:
        """Execute code on the Databricks cluster.

        Failures, including a failure to create the execution context, are
        reported through the returned result rather than raised.

        Args:
            code: The Python code to execute.
            allow_reconnect: If True, attempt to reconnect on context errors.

        Returns:
            Execution result containing output or error.
        """
        # Check the cluster ID first: create_context() raises when it is
        # missing, which would make a later check unreachable.
        if not self.config.cluster_id:
            return ExecutionResult(
                status="error",
                error="Cluster ID is not configured",
            )

        if self.context_id is None:
            try:
                self.create_context()
            except Exception as e:
                logger.error("Failed to create execution context: %s", e)
                return ExecutionResult(
                    status="error",
                    error=f"Failed to create execution context: {e}",
                )

        # create_context() leaves context_id unset when the response has no id.
        if self.context_id is None:
            return ExecutionResult(
                status="error",
                error="Failed to create execution context",
            )

        try:
            return self._execute_internal(code)
        except Exception as e:
            if not (allow_reconnect and self._is_context_invalid_error(e)):
                logger.error("Execution failed: %s", e)
                return ExecutionResult(
                    status="error",
                    error=str(e),
                )

            logger.warning("Context invalid, attempting reconnection: %s", e)
            try:
                # Wait before reconnection to avoid hammering the API
                time.sleep(RECONNECT_DELAY_SECONDS)
                self.reconnect()
                if self.context_id is None:
                    # Do not send a command with a missing context id.
                    raise RuntimeError("Failed to create execution context")
                result = self._execute_internal(code)
                return replace(result, reconnected=True)
            except Exception as retry_error:
                logger.error("Reconnection failed: %s", retry_error)
                return ExecutionResult(
                    status="error",
                    error=f"Reconnection failed: {retry_error}",
                )

    def _execute_internal(self, code: str) -> ExecutionResult:
        """Internal execution without reconnection logic.

        Args:
            code: The Python code to execute.

        Returns:
            Execution result containing output or error.

        Raises:
            Exception: If execution fails due to API errors.
        """
        client = self._ensure_client()
        response = client.command_execution.execute(
            cluster_id=self.config.cluster_id,
            context_id=self.context_id,
            language=compute.Language.PYTHON,
            command=code,
        ).result(timeout=COMMAND_EXECUTION_TIMEOUT)

        if response is None:
            return ExecutionResult(
                status="error",
                error="No response from Databricks",
            )

        results = response.results
        if results:
            # A populated cause marks the command as failed.
            if results.cause:
                return ExecutionResult(
                    status="error",
                    error=results.cause,
                    traceback=results.summary.split("\n") if results.summary else None,
                )

            # Prefer the data payload; fall back to the summary text.
            output = None
            if results.data is not None:
                output = str(results.data)
            elif results.summary:
                output = results.summary

            return ExecutionResult(
                status="ok",
                output=output,
            )

        # No results payload. A finished command is still a success; callers
        # compare status against "ok", so the stringified enum must not be
        # returned for that case.
        if response.status == compute.CommandStatus.FINISHED:
            return ExecutionResult(status="ok")

        status = str(response.status) if response.status else "unknown"
        return ExecutionResult(
            status=status,
            error=f"Command ended with status {status}",
        )

    def destroy_context(self) -> None:
        """Destroy the execution context.

        Does nothing when there is no context or no cluster ID. The local
        ``context_id`` is cleared even if the API call raises.
        """
        if self.context_id is None:
            return

        if not self.config.cluster_id:
            return

        try:
            client = self._ensure_client()
            client.command_execution.destroy(
                cluster_id=self.config.cluster_id,
                context_id=self.context_id,
            )
        finally:
            self.context_id = None
@@ -0,0 +1,306 @@
1
+ """Databricks Session Kernel for Jupyter."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import uuid
6
+ from typing import Any
7
+
8
+ from ipykernel.kernelbase import Kernel
9
+
10
+ from . import __version__
11
+ from .config import Config
12
+ from .executor import DatabricksExecutor
13
+ from .sync import FileSync
14
+
15
+
16
class DatabricksKernel(Kernel):
    """Jupyter kernel that executes code on a remote Databricks cluster."""

    implementation = "databricks-session-kernel"
    implementation_version = __version__
    language = "python"
    language_version = "3.11"
    language_info = {
        "name": "python",
        "mimetype": "text/x-python",
        "file_extension": ".py",
    }
    banner = "Databricks Session Kernel - Execute Python on Databricks clusters"

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the Databricks kernel."""
        super().__init__(**kwargs)
        self._kernel_config = Config.load()
        # Short identifier: the first 8 characters of a random UUID.
        self._session_id = str(uuid.uuid4())[:8]
        self.executor: DatabricksExecutor | None = None
        self.file_sync: FileSync | None = None
        self._initialized = False
        self._last_dbfs_path: str | None = None

    def _send_stream(self, name: str, text: str) -> None:
        """Send a ``stream`` message on the iopub socket.

        Args:
            name: Stream name, "stdout" or "stderr".
            text: Text to send.
        """
        self.send_response(
            self.iopub_socket,
            "stream",
            {"name": name, "text": text},
        )

    def _ok_reply(self) -> dict[str, Any]:
        """Build the reply dictionary for a successful execution."""
        return {
            "status": "ok",
            "execution_count": self.execution_count,
            "payload": [],
            "user_expressions": {},
        }

    def _error_reply(
        self,
        ename: str,
        evalue: str,
        traceback: list[str],
        *,
        silent: Any,
    ) -> dict[str, Any]:
        """Publish an error (unless silent) and build the error reply.

        Args:
            ename: Error name.
            evalue: Error message.
            traceback: Traceback lines.
            silent: When truthy, no ``error`` message is published.

        Returns:
            Error reply dictionary.
        """
        if not silent:
            self.send_response(
                self.iopub_socket,
                "error",
                {"ename": ename, "evalue": evalue, "traceback": traceback},
            )
        return {
            "status": "error",
            "execution_count": self.execution_count,
            "ename": ename,
            "evalue": evalue,
            "traceback": traceback,
        }

    def _initialize(self) -> bool:
        """Initialize the Databricks connection.

        Returns:
            True if initialization succeeded, False otherwise.
        """
        if self._initialized:
            return True

        # Validate configuration; every problem is reported on stderr.
        errors = self._kernel_config.validate()
        if errors:
            for error in errors:
                self._send_stream("stderr", f"Configuration error: {error}\n")
            return False

        # Initialize executor and file sync (reuse existing if available)
        if self.executor is None:
            self.executor = DatabricksExecutor(self._kernel_config)
        if self.file_sync is None:
            self.file_sync = FileSync(self._kernel_config, self._session_id)
        self._initialized = True
        return True

    def _sync_files(self) -> bool:
        """Synchronize files if needed.

        Returns:
            True if sync succeeded, was not needed, or raised an exception
            (execution continues in that case); False when the remote setup
            code reported a failure.
        """
        if self.file_sync is None or self.executor is None:
            return True

        if not self.file_sync.needs_sync():
            return True

        try:
            self._send_stream("stderr", "Syncing files to Databricks...\n")

            # Upload files
            stats = self.file_sync.sync()
            self._last_dbfs_path = stats.dbfs_path

            # Execute setup code on remote
            setup_code = self.file_sync.get_setup_code(stats.dbfs_path)
            result = self.executor.execute(setup_code, allow_reconnect=False)

            if result.status != "ok":
                self._send_stream("stderr", f"Sync setup failed: {result.error}\n")
                return False

            self._send_stream("stderr", "Files synced successfully.\n")
            return True

        except Exception as e:
            self._send_stream("stderr", f"Sync failed: {e}\n")
            # Continue execution even if sync fails
            return True

    def _handle_reconnection(self) -> None:
        """Handle session reconnection.

        Re-runs the setup code to restore sys.path and notifies the user.
        """
        # Notify user about reconnection
        self._send_stream(
            "stderr",
            "Session reconnected. Variables have been reset.\n",
        )

        # Re-run setup code only if we have synced files before
        if not (self.file_sync and self._last_dbfs_path and self.executor):
            return

        try:
            setup_code = self.file_sync.get_setup_code(self._last_dbfs_path)
            result = self.executor.execute(setup_code, allow_reconnect=False)
            if result.status != "ok":
                err = result.error
                self._send_stream(
                    "stderr",
                    f"Warning: Failed to restore sys.path: {err}\n",
                )
        except Exception as e:
            # Notify user but don't fail the main execution
            self._send_stream(
                "stderr",
                f"Warning: Failed to restore sys.path: {e}\n",
            )

    async def do_execute(
        self,
        code: Any,
        silent: Any,
        store_history: Any = True,
        user_expressions: Any = None,
        allow_stdin: Any = False,
        *,
        cell_meta: Any = None,
        cell_id: Any = None,
    ) -> dict[str, Any]:
        """Execute code on the Databricks cluster.

        Args:
            code: The code to execute.
            silent: Whether to suppress output.
            store_history: Whether to store the code in history.
            user_expressions: User expressions to evaluate.
            allow_stdin: Whether to allow stdin.
            cell_meta: Cell metadata (not used here).
            cell_id: The cell ID.

        Returns:
            Execution result dictionary.
        """
        # Skip empty code
        code_str = str(code).strip()
        if not code_str:
            return self._ok_reply()

        # Initialize on first execution. No "error" message is published
        # here; _initialize() has already reported the problems on stderr.
        if not self._initialize():
            return {
                "status": "error",
                "execution_count": self.execution_count,
                "ename": "ConfigurationError",
                "evalue": "Failed to initialize Databricks connection",
                "traceback": [],
            }

        # Sync files before execution. The result is not checked: execution
        # proceeds whether or not the sync succeeded.
        self._sync_files()

        # Execute on Databricks
        assert self.executor is not None
        try:
            result = self.executor.execute(code_str)

            # Handle reconnection: re-run setup code and notify user
            if result.reconnected:
                self._handle_reconnection()

            if result.status == "ok":
                if not silent and result.output:
                    self._send_stream("stdout", result.output)
                return self._ok_reply()

            # Any status other than "ok" is treated as an error.
            return self._error_reply(
                "ExecutionError",
                result.error or "Unknown error",
                result.traceback or [],
                silent=silent,
            )

        except Exception as e:
            error_msg = str(e)
            return self._error_reply(
                type(e).__name__,
                error_msg,
                [error_msg],
                silent=silent,
            )

    async def do_shutdown(self, restart: bool) -> dict[str, Any]:
        """Shutdown the kernel.

        Cleanup is best-effort: failures are logged and never prevent the
        shutdown reply from being returned.

        Args:
            restart: Whether this is a restart.

        Returns:
            Shutdown result dictionary.
        """
        if restart:
            # On restart, keep the execution context alive for session continuity
            # Only reset the initialized flag so we can re-initialize on next execute
            self._initialized = False
            return {"status": "ok", "restart": restart}

        # Full shutdown: clean up everything
        # Clean up file sync
        if self.file_sync:
            try:
                self.file_sync.cleanup()
            except Exception as e:
                # Log instead of silently discarding the failure.
                self.log.warning("File sync cleanup failed: %s", e)
            self.file_sync = None

        # Destroy execution context
        if self.executor:
            try:
                self.executor.destroy_context()
            except Exception as e:
                self.log.warning("Failed to destroy execution context: %s", e)
            self.executor = None

        self._initialized = False
        self._last_dbfs_path = None
        return {"status": "ok", "restart": restart}