flyteplugins-papermill 2.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,15 @@
1
+ __all__ = [
2
+ "NotebookTask",
3
+ "load_dataframe",
4
+ "load_dir",
5
+ "load_file",
6
+ "record_outputs",
7
+ ]
8
+
9
+ from flyteplugins.papermill.notebook import (
10
+ load_dataframe,
11
+ load_dir,
12
+ load_file,
13
+ record_outputs,
14
+ )
15
+ from flyteplugins.papermill.task import NotebookTask
@@ -0,0 +1,97 @@
1
+ """Notebook kernel setup — called from the injected setup cell.
2
+
3
+ This module is imported inside the notebook kernel subprocess. Keep imports
4
+ at function scope so the module itself has no side-effects on import.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+
10
def initialize_context() -> None:
    """Rebuild the Flyte task context inside the notebook kernel process.

    The serialized context is read as JSON from the ``_FLYTE_NB_CTX``
    environment variable. From it an ``ActionID``, an optional
    ``CodeBundle``, an optional ``ImageCache`` and finally a ``TaskContext``
    are constructed, and the resulting ``Context`` is stored in
    ``root_context_var``.

    Before the context is built, Flyte itself is initialized:

    * ``mode == "local"``: ``flyte._initialize._init_config`` is assigned
      directly (under ``_init_lock``) if it is still ``None``.
    * any other mode: ``init_in_cluster()`` is called and its result is
      forwarded to ``create_controller(ct="remote", ...)``.

    The function is best-effort: when the variable is unset or empty it
    returns silently, and any exception raised while decoding or building
    the context is reported with a warning line plus a traceback instead of
    being propagated.
    """
    import json
    import os

    serialized = os.environ.get("_FLYTE_NB_CTX")
    if not serialized:
        # Nothing was handed over by the parent process; leave the kernel as is.
        return

    try:
        payload = json.loads(serialized)

        import flyte.report as _report
        from flyte._context import Context, ContextData, root_context_var
        from flyte._internal.imagebuild.image_builder import ImageCache
        from flyte.models import ActionID, CodeBundle, RawDataPath, TaskContext

        if payload.get("mode") != "local":
            # Remote mode: use the same init + controller pattern as runtime.py.
            from flyte._initialize import init_in_cluster
            from flyte._internal.controllers import create_controller

            create_controller(ct="remote", **init_in_cluster())
        else:
            # Local mode: no controller or remote connection needed.
            # The init config is assigned directly to avoid @syncify running
            # in a background thread, which can cause visibility issues with
            # module-level globals in the kernel's main thread.
            import flyte._initialize as _init_mod

            with _init_mod._init_lock:
                if _init_mod._init_config is None:
                    from pathlib import Path

                    _init_mod._init_config = _init_mod._InitConfig(root_dir=Path.cwd())

        action_id = ActionID(
            name=payload["action_name"],
            run_name=payload["run_name"],
            project=payload["project"],
            domain=payload["domain"],
            org=payload["org"],
        )

        # The code bundle is optional; missing fields fall back to defaults.
        if "code_bundle" in payload:
            bundle_spec = payload["code_bundle"]
            bundle = CodeBundle(
                tgz=bundle_spec.get("tgz"),
                pkl=bundle_spec.get("pkl"),
                destination=bundle_spec.get("destination", "."),
                computed_version=bundle_spec.get("computed_version", ""),
            )
        else:
            bundle = None

        # The image cache is optional as well.
        image_cache = ImageCache.from_transport(payload["image_cache"]) if "image_cache" in payload else None

        task_ctx = TaskContext(
            action=action_id,
            version=payload["version"],
            output_path=payload["output_path"],
            run_base_dir=payload["run_base_dir"],
            raw_data_path=RawDataPath(path=payload["raw_data_path"]),
            mode=payload["mode"],
            interactive_mode=payload.get("interactive_mode", False),
            report=_report.Report(name=action_id.name),
            code_bundle=bundle,
            compiled_image_cache=image_cache,
        )
        root_context_var.set(Context(data=ContextData(task_context=task_ctx)))
        print("[flyte-notebook] Context initialized successfully")
    except Exception as exc:
        # Never break the notebook because of a context problem: report and go on.
        import traceback

        print(f"[flyte-notebook] WARNING: Failed to initialize context: {exc}")
        traceback.print_exc()
@@ -0,0 +1,122 @@
1
+ """Helpers for use inside Jupyter notebooks executed by NotebookTask.
2
+
3
+ These functions are meant to be called from within a notebook that is
4
+ being run as a Flyte task via papermill.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from typing import Any
10
+
11
+
12
def record_outputs(**kwargs: Any) -> str:
    """Serialize notebook outputs so ``NotebookTask`` can pick them up.

    Use this as the **last expression** of the cell tagged ``"outputs"``:
    the string it returns becomes that cell's text/plain output, which is
    what ``NotebookTask`` later extracts.

    Each value is converted to a Flyte Literal using the literal type
    derived from ``type(value)``, so anything Flyte's type system can
    handle for that runtime type is accepted — primitives, ``File``,
    ``Dir``, ``DataFrame``, dataclasses, etc.

    Example (cell tagged ``"outputs"``)::

        from flyteplugins.papermill import record_outputs

        record_outputs(result=42, summary="done")

    Args:
        **kwargs: Output name/value pairs.

    Returns:
        The protobuf text-format rendering of a ``LiteralMap`` holding one
        literal per keyword argument.
    """
    from flyte.syncify import syncify
    from flyte.types import TypeEngine
    from google.protobuf import text_format

    @syncify
    async def _to_literal_map(values: dict[str, Any]):
        from flyteidl2.core.literals_pb2 import LiteralMap

        # Values are converted one after another, in keyword order.
        converted = {
            key: await TypeEngine.to_literal(item, type(item), TypeEngine.to_literal_type(type(item)))
            for key, item in values.items()
        }
        return LiteralMap(literals=converted)

    return text_format.MessageToString(_to_literal_map(kwargs))
53
+
54
+
55
def load_file(path: str):
    """Rebuild a ``flyte.io.File`` from the path string papermill injected.

    A ``File`` input of a ``NotebookTask`` reaches the notebook as its
    remote path string. Wrap that string again to get a ``File`` object::

        from flyteplugins.papermill import load_file

        f = load_file(my_file_path)  # my_file_path injected by papermill
        with f.open_sync() as fh:
            data = fh.read()

    Args:
        path: The remote path string (injected as a papermill parameter).

    Returns:
        A ``flyte.io.File`` instance constructed with ``path=path``.
    """
    import flyte.io as _flyte_io

    file_ref = _flyte_io.File(path=path)
    return file_ref
77
+
78
+
79
def load_dir(path: str):
    """Rebuild a ``flyte.io.Dir`` from the path string papermill injected.

    A ``Dir`` input of a ``NotebookTask`` reaches the notebook as its
    remote path string. Wrap that string again to get a ``Dir`` object::

        from flyteplugins.papermill import load_dir

        d = load_dir(my_dir_path)

    Args:
        path: The remote path string (injected as a papermill parameter).

    Returns:
        A ``flyte.io.Dir`` instance constructed with ``path=path``.
    """
    import flyte.io as _flyte_io

    dir_ref = _flyte_io.Dir(path=path)
    return dir_ref
99
+
100
+
101
def load_dataframe(uri: str, fmt: str = "parquet"):
    """Rebuild a ``flyte.io.DataFrame`` from the URI papermill injected.

    A ``DataFrame`` input of a ``NotebookTask`` reaches the notebook as its
    remote URI string. Wrap that string again to get a ``DataFrame``
    reference::

        from flyteplugins.papermill import load_dataframe

        df = load_dataframe(my_df_uri)
        pandas_df = df.all()  # materializes as pandas DataFrame

    Args:
        uri: The remote URI string (injected as a papermill parameter).
        fmt: The storage format (default ``"parquet"``).

    Returns:
        A ``flyte.io.DataFrame`` instance constructed with ``uri=uri`` and
        ``format=fmt``.
    """
    import flyte.io as _flyte_io

    frame_ref = _flyte_io.DataFrame(uri=uri, format=fmt)
    return frame_ref
@@ -0,0 +1,164 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ import pathlib
6
+ import sys
7
+
8
+ from flyte._task import TaskTemplate
9
+
10
+
11
class NotebookTaskResolver:
    """Resolver for NotebookTask instances.

    Serializes all task state (notebook path, type schemas, execution config)
    into the loader args at serialization time so the task can be reconstructed
    in the container without importing the user's module. This lets NotebookTask
    be defined inline (e.g. inside a task) rather than only at module level.

    Loader args format (flat list of alternating keys and values):
        notebook        <relative-or-absolute notebook path>
        name            <task name>
        input-schema    <JSON: {field: LiteralType dict}>
        output-schema   <JSON: {field: LiteralType dict}>
        config          <JSON: execution params>
    """

    @property
    def import_path(self) -> str:
        """Dotted import path under which this resolver class can be located."""
        return "flyteplugins.papermill.resolver.NotebookTaskResolver"

    def load_task(self, loader_args: list[str]) -> TaskTemplate:
        """Reconstruct a ``NotebookTask`` from serialized loader args.

        Args:
            loader_args: Flat ``[key, value, key, value, ...]`` list as
                produced by :meth:`loader_args`.

        Returns:
            A ``NotebookTask`` built from the stored notebook path, name,
            input/output schemas and execution config.

        Raises:
            ValueError: If ``loader_args`` has an odd number of elements.
            KeyError: If the ``notebook`` or ``name`` key is missing.
        """
        # Validate and parse the flat key-value list first, so malformed args
        # fail fast before any heavy import. The length check replaces the
        # former ``next()``/``StopIteration`` dance, which leaked an
        # irrelevant chained exception into the traceback.
        flat_args = list(loader_args)
        if len(flat_args) % 2:
            raise ValueError(f"Odd number of loader args — missing value for key '{flat_args[-1]}'")
        args_dict: dict[str, str] = dict(zip(flat_args[::2], flat_args[1::2]))

        notebook_path = args_dict["notebook"]
        name = args_dict["name"]

        # Ensure all IO type transformers are registered before attempting
        # guess_python_type — they are registered as side-effects of import.
        import flyte.io  # noqa: F401
        from flyte.types import TypeEngine
        from flyteidl2.core.types_pb2 import LiteralType
        from google.protobuf import json_format

        from flyteplugins.papermill.task import NotebookTask

        # Relative paths are stored relative to the bundle root (root_dir at
        # serialization time). At execution time, the bundle extraction dir is
        # prepended to sys.path by download_code_bundle, so we search there.
        if not os.path.isabs(notebook_path):
            for search_dir in sys.path:
                candidate = os.path.abspath(os.path.join(search_dir, notebook_path))
                if os.path.exists(candidate):
                    notebook_path = candidate
                    break
            else:
                # CWD fallback: covers the common destination="." case where
                # the bundle is extracted into the current working directory.
                notebook_path = os.path.abspath(notebook_path)

        def _schema_to_types(schema_json: str) -> dict | None:
            """Turn a JSON ``{field: LiteralType dict}`` schema into Python types."""
            schema = json.loads(schema_json)
            if not schema:
                # An empty schema is passed on as "no interface" (None).
                return None
            return {
                field_name: TypeEngine.guess_python_type(json_format.ParseDict(lt_dict, LiteralType()))
                for field_name, lt_dict in schema.items()
            }

        inputs = _schema_to_types(args_dict.get("input-schema", "{}"))
        outputs = _schema_to_types(args_dict.get("output-schema", "{}"))
        config: dict = json.loads(args_dict.get("config", "{}"))

        return NotebookTask(
            name=name,
            notebook_path=notebook_path,
            task_environment=None,
            inputs=inputs,
            outputs=outputs,
            **config,
        )

    def loader_args(self, task: TaskTemplate, root_dir: pathlib.Path | None) -> list[str]:
        """Serialize a ``NotebookTask`` into the flat loader-args list.

        Args:
            task: The task to serialize; must be a ``NotebookTask``.
            root_dir: Bundle root used to relativize the notebook path, or
                ``None`` when there is no bundle root.

        Returns:
            ``["notebook", <path>, "name", <name>, "input-schema", <json>,
            "output-schema", <json>, "config", <json>]``.

        Raises:
            TypeError: If ``task`` is not a ``NotebookTask``.
        """
        from flyte.types import TypeEngine
        from google.protobuf import json_format

        from flyteplugins.papermill.task import NotebookTask

        if not isinstance(task, NotebookTask):
            raise TypeError(f"NotebookTaskResolver only handles NotebookTask, got {type(task)}")

        # If the notebook is inside the bundle root, store a relative path so
        # it resolves correctly wherever the bundle is extracted in the container.
        # If outside root_dir (or no root_dir), preserve the path exactly as the
        # user wrote it: absolute paths are expected to exist in the container
        # image; relative paths remain relative and must be resolvable from CWD
        # or sys.path at execution time.
        nb_path = pathlib.Path(task._resolved_notebook_path)
        if root_dir is not None:
            try:
                # as_posix(): the relative path is consumed inside the
                # container, so it must use "/" even when serialization runs
                # on a platform whose native separator is "\".
                notebook_arg = nb_path.relative_to(pathlib.Path(root_dir).resolve()).as_posix()
            except ValueError:
                notebook_arg = str(nb_path)
        else:
            # No bundle root: use the original path as the user wrote it.
            notebook_arg = task.notebook_path

        def _types_to_schema(types_dict: dict) -> dict:
            """Map each field's Python type to its LiteralType as a JSON-able dict."""
            return {
                field_name: json_format.MessageToDict(TypeEngine.to_literal_type(typ))
                for field_name, typ in types_dict.items()
            }

        # Inputs: interface entries are (type, default) pairs; only the type is stored.
        input_types = {name: typ for name, (typ, _) in (task.interface.inputs or {}).items()}
        input_schema = _types_to_schema(input_types)

        # Outputs: strip the auto-added notebook File outputs so the reconstructed
        # task can re-add them via output_notebooks=True in the config.
        skip_outputs = {"output_notebook", "output_notebook_executed"} if task.output_notebooks else set()
        output_types = {name: typ for name, typ in (task.interface.outputs or {}).items() if name not in skip_outputs}
        output_schema = _types_to_schema(output_types)

        # Execution config (plugin_config is serialization-only; not needed at execution time)
        config: dict = {
            "log_output": task.log_output,
            "start_timeout": task.start_timeout,
            "report_mode": task.report_mode,
            "request_save_on_cell_execute": task.request_save_on_cell_execute,
            "progress_bar": task.progress_bar,
            "output_notebooks": task.output_notebooks,
        }
        # Optional settings are emitted only when set, in a fixed order so the
        # serialized config stays stable.
        for attr in ("kernel_name", "engine_name", "execution_timeout", "language"):
            value = getattr(task, attr)
            if value is not None:
                config[attr] = value
        if task.engine_kwargs:
            config["engine_kwargs"] = task.engine_kwargs

        return [
            "notebook",
            notebook_arg,
            "name",
            task.name,
            "input-schema",
            json.dumps(input_schema),
            "output-schema",
            json.dumps(output_schema),
            "config",
            json.dumps(config),
        ]