flyteplugins-omegaconf 2.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,24 @@
1
+ """OmegaConf DictConfig/ListConfig support for Flyte."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import functools
6
+
7
+ from flyte.types._type_engine import TypeEngine
8
+
9
+ from .base_transformer import DictConfigTransformer, ListConfigTransformer
10
+
11
+
12
@functools.cache
def register_omegaconf_transformers() -> None:
    """Install the OmegaConf transformers into Flyte's ``TypeEngine``.

    One ``DictConfigTransformer`` and one ``ListConfigTransformer`` are
    instantiated and registered, in that order.  The function is wrapped in an
    unbounded ``functools`` cache, so once a call has completed, later calls
    return the cached ``None`` without registering again.

    It is the target of the package's ``flyte.plugins.types`` entry point and
    is also invoked when this package is imported.
    """
    for transformer_cls in (DictConfigTransformer, ListConfigTransformer):
        TypeEngine.register(transformer_cls())


# Importing the package registers the transformers too, so code that relies
# on the import side effect keeps working (backwards compatibility).
register_omegaconf_transformers()
@@ -0,0 +1,68 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Generic, Type, TypeVar
4
+
5
+ import msgpack
6
+ from flyte.types._type_engine import (
7
+ CACHE_KEY_METADATA,
8
+ MESSAGEPACK,
9
+ SERIALIZATION_FORMAT,
10
+ TypeTransformer,
11
+ )
12
+ from flyteidl2.core.literals_pb2 import Binary, Literal, Scalar
13
+ from flyteidl2.core.types_pb2 import LiteralType, SimpleType, TypeAnnotation
14
+ from google.protobuf import struct_pb2
15
+
16
+ from omegaconf import DictConfig, ListConfig
17
+
18
+ from .codec import deserialize_omegaconf, serialize_omegaconf
19
+
20
# Constrained to the two OmegaConf container types handled by this module.
T = TypeVar("T", DictConfig, ListConfig)


class OmegaConfTransformerBase(TypeTransformer[T], Generic[T]):
    """Shared Flyte type transformer for OmegaConf containers.

    A value is converted with ``serialize_omegaconf``, packed with msgpack and
    carried as a binary scalar tagged ``MESSAGEPACK``.  Decoding reverses those
    steps and checks that the result is an instance of the container type this
    transformer was created for.
    """

    def __init__(self, name: str, container_type: Type[T]):
        """Create a transformer called ``name`` for ``container_type``."""
        super().__init__(name, container_type)
        # Kept so decoded values can be checked against the expected container.
        self._container_type = container_type

    def get_literal_type(self, t: Type[T]) -> LiteralType:
        """Return a STRUCT literal type annotated with the msgpack format.

        The argument ``t`` is not consulted; every call yields the same type.
        """
        annotations = struct_pb2.Struct()
        annotations.update({CACHE_KEY_METADATA: {SERIALIZATION_FORMAT: MESSAGEPACK}})
        type_annotation = TypeAnnotation(annotations=annotations)
        return LiteralType(simple=SimpleType.STRUCT, annotation=type_annotation)

    async def to_literal(
        self,
        python_val: T,
        python_type: Type[T],
        expected: LiteralType,
    ) -> Literal:
        """Encode ``python_val`` as a msgpack binary scalar literal.

        ``python_type`` and ``expected`` are accepted for the transformer
        interface but are not used here.
        """
        encoded = msgpack.dumps(serialize_omegaconf(python_val))
        binary = Binary(value=encoded, tag=MESSAGEPACK)
        return Literal(scalar=Scalar(binary=binary))

    def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> T:
        """Decode a msgpack-tagged binary into this transformer's container.

        Raises:
            TypeError: if the binary's tag is not ``MESSAGEPACK``, or if the
                decoded config is not an instance of the container type given
                at construction.
        """
        if binary_idl_object.tag != MESSAGEPACK:
            raise TypeError(f"Unsupported binary format: `{binary_idl_object.tag}`")

        # strict_map_key=False lets msgpack accept map keys other than str/bytes.
        unpacked = msgpack.loads(binary_idl_object.value, strict_map_key=False)
        config = deserialize_omegaconf(unpacked)
        if isinstance(config, self._container_type):
            return config
        raise TypeError(f"Expected {self._container_type.__name__} payload, got {type(config).__name__}")

    async def to_python_value(self, lv: Literal, expected_python_type: Type[T]) -> T:
        """Convert a literal holding a binary scalar back to a container.

        Raises:
            TypeError: if ``lv`` is falsy or does not carry ``scalar.binary``.
        """
        carries_binary = lv and lv.HasField("scalar") and lv.scalar.HasField("binary")
        if not carries_binary:
            raise TypeError(f"Cannot convert literal to {self._container_type.__name__}: {lv}")
        return self.from_binary_idl(lv.scalar.binary, expected_python_type)
59
+
60
+
61
class DictConfigTransformer(OmegaConfTransformerBase[DictConfig]):
    """``OmegaConfTransformerBase`` specialised for ``DictConfig`` values."""

    def __init__(self):
        display_name = "OmegaConf DictConfig Transformer"
        super().__init__(display_name, DictConfig)
64
+
65
+
66
class ListConfigTransformer(OmegaConfTransformerBase[ListConfig]):
    """``OmegaConfTransformerBase`` specialised for ``ListConfig`` values."""

    def __init__(self):
        display_name = "OmegaConf ListConfig Transformer"
        super().__init__(display_name, ListConfig)
@@ -0,0 +1,169 @@
1
+ from __future__ import annotations
2
+
3
+ import importlib
4
+ from enum import Enum
5
+ from pathlib import Path, PurePath
6
+ from typing import Any
7
+
8
+ from omegaconf import MISSING, DictConfig, ListConfig, OmegaConf
9
+
10
# Keys used inside the tagged payload dicts produced by the serializer.
PAYLOAD_MARKER = "__flyte_omegaconf__"  # set to True on every tagged payload dict
PAYLOAD_KIND = "kind"  # holds one of the KIND_* strings below
PAYLOAD_VALUES = "values"  # children of dict, list and tuple payloads
PAYLOAD_SCHEMA = "schema"  # qualified name of a DictConfig's backing type, when recorded
PAYLOAD_TYPE = "type"  # qualified name of an enum member's class
PAYLOAD_NAME = "name"  # enum member name
PAYLOAD_VALUE = "value"  # enum member value, or a path rendered with str()

# Values stored under PAYLOAD_KIND.
KIND_DICT = "dict"
KIND_LIST = "list"
KIND_MISSING = "missing"
KIND_ENUM = "enum"
KIND_PATH = "path"
KIND_TUPLE = "tuple"
24
+
25
+
26
def serialize_omegaconf(node: DictConfig | ListConfig) -> dict[str, Any]:
    """Convert an OmegaConf container into its tagged payload form.

    This is the public entry point; the work is delegated to
    ``_serialize_value``.
    """
    encoded = _serialize_value(node)
    return encoded
28
+
29
+
30
def deserialize_omegaconf(payload: dict[str, Any]) -> DictConfig | ListConfig:
    """Rebuild an OmegaConf container from a payload.

    Raises:
        TypeError: if decoding ``payload`` yields anything other than a
            ``DictConfig`` or ``ListConfig``.
    """
    restored = _deserialize_value(payload)
    if not isinstance(restored, (DictConfig, ListConfig)):
        raise TypeError(f"Expected OmegaConf container payload, got {type(restored).__name__}")
    return restored
35
+
36
+
37
def _tagged_payload(kind: str, fields: dict[str, Any]) -> dict[str, Any]:
    """Return a marker dict of ``kind`` followed by ``fields`` in their order."""
    return {PAYLOAD_MARKER: True, PAYLOAD_KIND: kind, **fields}


def _serialize_value(value: Any) -> Any:
    """Recursively convert ``value`` into plain data plus tagged payload dicts.

    ``DictConfig``, ``ListConfig``, ``Enum``, ``PurePath`` and ``tuple`` values
    each become a dict carrying ``PAYLOAD_MARKER`` and a ``PAYLOAD_KIND``.
    Plain dicts and lists are rebuilt with their items serialized.  Any other
    value is returned unchanged.
    """
    if isinstance(value, DictConfig):
        schema = _schema_name(value)
        # keys() plus _get_node() hands the raw node to _serialize_child, which
        # decides whether the entry is missing before reading its value.
        children = {
            name: _serialize_child(value, name, value._get_node(name))
            for name in value.keys()
        }
        encoded = _tagged_payload(KIND_DICT, {PAYLOAD_VALUES: children})
        if schema is not None:
            encoded[PAYLOAD_SCHEMA] = schema
        return encoded

    if isinstance(value, ListConfig):
        items = [
            _serialize_child(value, position, value._get_node(position))
            for position in range(len(value))
        ]
        return _tagged_payload(KIND_LIST, {PAYLOAD_VALUES: items})

    if isinstance(value, Enum):
        return _tagged_payload(
            KIND_ENUM,
            {
                PAYLOAD_TYPE: _qualified_name(type(value)),
                PAYLOAD_NAME: value.name,
                PAYLOAD_VALUE: value.value,
            },
        )

    if isinstance(value, PurePath):
        return _tagged_payload(KIND_PATH, {PAYLOAD_VALUE: str(value)})

    if isinstance(value, tuple):
        members = [_serialize_value(member) for member in value]
        return _tagged_payload(KIND_TUPLE, {PAYLOAD_VALUES: members})

    if isinstance(value, dict):
        return {name: _serialize_value(entry) for name, entry in value.items()}

    if isinstance(value, list):
        return [_serialize_value(entry) for entry in value]

    return value
94
+
95
+
96
def _serialize_child(parent: DictConfig | ListConfig, key: str | int, child_node: Any) -> Any:
    """Serialize one entry of ``parent``.

    If ``child_node`` reports itself as missing, a ``KIND_MISSING`` marker dict
    is returned and ``parent[key]`` is never read.  Otherwise the value is read
    through ``parent[key]`` and passed to ``_serialize_value``.
    """
    if not child_node._is_missing():
        return _serialize_value(parent[key])
    return {PAYLOAD_MARKER: True, PAYLOAD_KIND: KIND_MISSING}
103
+
104
+
105
def _deserialize_value(payload: Any) -> Any:
    """Recursively rebuild Python/OmegaConf values from serialized data.

    Lists and untagged dicts are rebuilt with their items deserialized, and
    non-container values are returned unchanged.  A dict whose
    ``PAYLOAD_MARKER`` entry is ``True`` is decoded according to its
    ``PAYLOAD_KIND``.

    Enum and structured-config reconstruction is best effort.  If the recorded
    type cannot be imported (or, for enums, is not an ``Enum`` class or lacks
    the recorded member), the raw enum value or a schema-less config is
    returned instead.

    Raises:
        TypeError: if a tagged payload has no kind or an unsupported kind.
    """
    if isinstance(payload, list):
        return [_deserialize_value(item) for item in payload]

    if not isinstance(payload, dict):
        return payload

    if payload.get(PAYLOAD_MARKER) is not True:
        return {key: _deserialize_value(value) for key, value in payload.items()}

    # A tagged payload without a kind is malformed. Report it with the same
    # exception type as an unknown kind instead of leaking a KeyError.
    if PAYLOAD_KIND not in payload:
        raise TypeError("Malformed OmegaConf payload: missing kind")
    kind = payload[PAYLOAD_KIND]

    if kind == KIND_MISSING:
        return MISSING

    if kind == KIND_ENUM:
        enum_value = payload[PAYLOAD_VALUE]
        enum_name = payload[PAYLOAD_NAME]
        enum_type_name = payload[PAYLOAD_TYPE]
        try:
            enum_type = _import_class(enum_type_name)
            # The dotted name may resolve to something that is not an Enum in
            # this environment. Subscripting it would raise an uncaught
            # TypeError or return an unrelated object, so fall back instead.
            if not (isinstance(enum_type, type) and issubclass(enum_type, Enum)):
                return enum_value
            return enum_type[enum_name]
        except (ImportError, AttributeError, KeyError, ValueError):
            return enum_value

    if kind == KIND_PATH:
        return Path(payload[PAYLOAD_VALUE])

    if kind == KIND_TUPLE:
        return tuple(_deserialize_value(item) for item in payload[PAYLOAD_VALUES])

    if kind == KIND_LIST:
        values = [_deserialize_value(item) for item in payload[PAYLOAD_VALUES]]
        return OmegaConf.create(values)

    if kind == KIND_DICT:
        values = {key: _deserialize_value(value) for key, value in payload[PAYLOAD_VALUES].items()}
        schema_name = payload.get(PAYLOAD_SCHEMA)
        if schema_name is None:
            return OmegaConf.create(values)

        try:
            schema_type = _import_class(schema_name)
            # Merging onto the structured schema restores the typed config.
            return OmegaConf.merge(OmegaConf.structured(schema_type), values)
        except (ImportError, AttributeError):
            # Schema class not importable here: degrade to a plain config.
            return OmegaConf.create(values)

    raise TypeError(f"Unsupported OmegaConf payload kind: {kind}")
153
+
154
+
155
def _schema_name(config: DictConfig) -> str | None:
    """Return the qualified name of ``config``'s backing type, if it has one.

    ``None`` is returned when ``OmegaConf.get_type`` reports ``None`` or plain
    ``dict``.
    """
    backing_type = OmegaConf.get_type(config)
    return None if backing_type in (None, dict) else _qualified_name(backing_type)
160
+
161
+
162
def _qualified_name(value_type: type) -> str:
    """Return ``<module>.<qualname>`` for ``value_type``."""
    module_name = value_type.__module__
    qualified = value_type.__qualname__
    return "{}.{}".format(module_name, qualified)
164
+
165
+
166
def _import_class(fully_qualified_name: str) -> type:
    """Import and return the object named by a dotted ``module.qualname`` path.

    The longest importable module prefix is used, and the remaining components
    are resolved with ``getattr``.  Nested classes therefore resolve too: for
    ``pkg.mod.Outer.Inner``, ``pkg.mod`` is imported and ``Outer.Inner`` is
    looked up on it.  Treating everything before the last dot as the module
    would try to import ``pkg.mod.Outer`` and fail.

    NOTE(review): the name is taken from deserialized payloads, so this imports
    whatever module the payload names. Confirm payloads are always trusted.

    Raises:
        ValueError: if the name contains no dot.
        ImportError: if no prefix of the name is an importable module; the
            error for the longest prefix is the one raised.
        AttributeError: if an attribute is missing on the imported module or
            on an intermediate object.
    """
    parts = fully_qualified_name.split(".")
    if len(parts) < 2:
        raise ValueError(f"Expected a dotted '<module>.<name>' path, got {fully_qualified_name!r}")

    import_error: ImportError | None = None
    # Try the most specific module path first, then progressively shorter ones.
    for split_at in range(len(parts) - 1, 0, -1):
        module_name = ".".join(parts[:split_at])
        try:
            target: Any = importlib.import_module(module_name)
        except ImportError as exc:
            if import_error is None:
                import_error = exc
            continue
        for attribute in parts[split_at:]:
            target = getattr(target, attribute)
        return target

    # Reached only when every candidate module failed to import.
    raise import_error
@@ -0,0 +1,193 @@
1
+ Metadata-Version: 2.4
2
+ Name: flyteplugins-omegaconf
3
+ Version: 2.1.6
4
+ Summary: OmegaConf DictConfig/ListConfig support for Flyte
5
+ Author-email: Samhita Alla <samhita@union.ai>
6
+ Requires-Python: >=3.10
7
+ Description-Content-Type: text/markdown
8
+ Requires-Dist: flyte
9
+ Requires-Dist: omegaconf
10
+
11
+ # flyteplugins-omegaconf
12
+
13
+ Enables [OmegaConf](https://omegaconf.readthedocs.io/) `DictConfig` and `ListConfig` as typed inputs and outputs for Flyte tasks.
14
+
15
+ ## Installation
16
+
17
+ ```bash
18
+ pip install flyteplugins-omegaconf
19
+ ```
20
+
21
+ Installing the package automatically registers the `DictConfig` and `ListConfig` transformers with Flyte's TypeEngine via the `flyte.plugins.types` entry point.
22
+
23
+ ## Usage
24
+
25
+ ### DictConfig as task inputs and outputs
26
+
27
+ ```python
28
+ import flyte
29
+ from omegaconf import DictConfig, OmegaConf
30
+
31
+ env = flyte.TaskEnvironment(name="training", image=...)
32
+
33
+ @env.task
34
+ async def preprocess(cfg: DictConfig) -> DictConfig:
35
+ return OmegaConf.merge(cfg, {"data": {"normalized": True}})
36
+
37
+ @env.task
38
+ async def train(cfg: DictConfig) -> float:
39
+ return run_experiment(cfg.optimizer.lr, cfg.training.epochs)
40
+
41
+ @env.task
42
+ async def pipeline() -> float:
43
+ cfg = OmegaConf.create({"optimizer": {"lr": 0.001}, "training": {"epochs": 10}})
44
+ processed = await preprocess(cfg)
45
+ return await train(processed)
46
+ ```
47
+
48
+ ### ListConfig as task inputs and outputs
49
+
50
+ ```python
51
+ from omegaconf import ListConfig, OmegaConf
52
+
53
+ @env.task
54
+ async def build_lr_schedule(base_lr: float, num_stages: int) -> ListConfig:
55
+ return OmegaConf.create([base_lr * (0.5 ** i) for i in range(num_stages)])
56
+
57
+ @env.task
58
+ async def train_with_schedule(cfg: DictConfig, lr_schedule: ListConfig) -> float:
59
+ final_lr = float(lr_schedule[-1])
60
+ ...
61
+ ```
62
+
63
+ ## Ways to construct a DictConfig
64
+
65
+ All of the following are valid ways to create a `DictConfig` to pass to a task:
66
+
67
+ ### 1. From a plain dict
68
+
69
+ ```python
70
+ cfg = OmegaConf.create({"optimizer": {"lr": 0.001}, "training": {"epochs": 10}})
71
+ flyte.run(train, cfg=cfg)
72
+ ```
73
+
74
+ ### 2. From a YAML file
75
+
76
+ ```python
77
+ cfg = OmegaConf.load("config.yaml")
78
+ flyte.run(train, cfg=cfg)
79
+ ```
80
+
81
+ ### 3. From a typed dataclass (structured config)
82
+
83
+ ```python
84
+ from dataclasses import dataclass, field
85
+ from omegaconf import OmegaConf
86
+
87
+ @dataclass
88
+ class OptimizerConf:
89
+ lr: float = 0.001
90
+ weight_decay: float = 1e-4
91
+
92
+ @dataclass
93
+ class TrainConf:
94
+ optimizer: OptimizerConf = field(default_factory=OptimizerConf)
95
+ epochs: int = 10
96
+
97
+ cfg = OmegaConf.structured(TrainConf())
98
+ flyte.run(train, cfg=cfg)
99
+ ```
100
+
101
+ Structured configs provide **type validation at assignment time**: `cfg.optimizer.lr = "oops"` raises `omegaconf.ValidationError`.
102
+
103
+ ### 4. Merging base config with overrides
104
+
105
+ ```python
106
+ base = OmegaConf.load("config.yaml")
107
+ override = OmegaConf.create({"optimizer": {"lr": 0.01}})
108
+ cfg = OmegaConf.merge(base, override)
109
+ flyte.run(train, cfg=cfg)
110
+ ```
111
+
112
+ ### 5. Structured config with MISSING required fields
113
+
114
+ ```python
115
+ from omegaconf import MISSING
116
+
117
+ @dataclass
118
+ class TrainConf:
119
+ data_path: str = MISSING # must be set before accessing
120
+ epochs: int = 10
121
+
122
+ # Pass with MISSING still unset — serialization succeeds
123
+ cfg = OmegaConf.structured(TrainConf())
124
+ flyte.run(train, cfg=cfg)
125
+
126
+ # Or fill it before passing
127
+ cfg = OmegaConf.structured(TrainConf(data_path="/data/imagenet"))
128
+ flyte.run(train, cfg=cfg)
129
+ ```
130
+
131
+ A config with an unset `MISSING` field serializes and deserializes successfully — the `MISSING` sentinel is preserved through the wire format. Accessing the field raises `MissingMandatoryValue`, so the task will fail if it tries to read an unfilled field.
132
+
133
+ ## Structured config deserialization
134
+
135
+ When a `DictConfig` is deserialized in a receiving task, the plugin uses **Auto mode**: it attempts to reconstruct the original dataclass-backed config, and falls back to a plain `DictConfig` if the class is not importable in the receiving task's environment.
136
+
137
+ ```python
138
+ # Task A produces a structured config
139
+ cfg = OmegaConf.structured(TrainConf(epochs=20))
140
+ # serialized payload: {"__flyte_omegaconf__": True, "kind": "dict", "schema": "mymodule.TrainConf", "values": {...}}
141
+
142
+ # Task B receives it
143
+ async def task_b(cfg: DictConfig) -> ...:
144
+ # If TrainConf is importable: cfg is a TrainConf-backed DictConfig (type-validated)
145
+ # If TrainConf is not importable: cfg is a plain DictConfig (no schema)
146
+ OmegaConf.get_type(cfg) # TrainConf or dict
147
+ ```
148
+
149
+ To ensure structured configs survive task hops, make sure the dataclass is defined in a module importable by all tasks in the pipeline.
150
+
151
+ ## Wire format
152
+
153
+ Both `DictConfig` and `ListConfig` are serialized as MessagePack binaries with tag `"msgpack"`:
154
+
155
+ ```
156
+ Literal(scalar=Scalar(binary=Binary(value=<msgpack bytes>, tag="msgpack")))
157
+ ```
158
+
159
+ **DictConfig payload** (msgpack-encoded dict):
160
+
161
+ ```json
162
+ {
163
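+ "__flyte_omegaconf__": true, "kind": "dict", "schema": "mymodule.TrainConf",
164
+ "values": { "optimizer": { "__flyte_omegaconf__": true, "kind": "dict", "schema": "mymodule.OptimizerConf", "values": { "lr": 0.001, "weight_decay": 0.0001 } }, "epochs": 10 }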
165
+ }
166
+ ```
167
+
168
+ For plain dict-backed configs, no dataclass name is stored in the payload (the codec omits the `schema` key).
169
+
170
+ **ListConfig payload** (msgpack-encoded dict wrapping the list items):
171
+
172
+ ```json
173
+ { "__flyte_omegaconf__": true, "kind": "list", "values": [0.001, 0.01, 0.1] }
174
+ ```
175
+
176
+ OmegaConf variable interpolations are **resolved** at serialization time (`resolve=True`). The wire representation always contains concrete values.
177
+
178
+ ## Limitations
179
+
180
+ - **Structured config schema strictness**: merging keys that don't exist as dataclass fields raises an error. Only declare structured configs when all possible keys are known upfront.
181
+ - **`MISSING` fields**: a `DictConfig` with unset `MISSING` fields serializes fine — the sentinel is preserved on the wire and accessing it still raises `MissingMandatoryValue`. However, in plain dict mode (when the originating dataclass is not importable in the receiving task), the field's type annotation is lost: the node becomes an `AnyNode` instead of the declared type (e.g. `StringNode`). In Auto mode, the schema is recovered from the dataclass, so the annotation is preserved.
182
+ - **ListConfig structured configs**: `ListConfig` always round-trips as a plain `ListConfig` — there is no structured (typed-element) ListConfig support.
183
+ - **Key types**: OmegaConf enforces string keys for `DictConfig`; integer-keyed dicts are not supported.
184
+ - **Class importability**: structured config reconstruction requires the dataclass to be importable in the receiving task. If it is not, the config falls back to a plain `DictConfig` (Auto mode).
185
+
186
+ ## Examples
187
+
188
+ See the [`examples/`](examples/) directory:
189
+
190
+ - [`example_dictconfig.py`](examples/example_dictconfig.py) — plain dict configs, nested, interpolation, merging
191
+ - [`example_structured_config.py`](examples/example_structured_config.py) — structured configs, type validation, MISSING fields, config resolution
192
+ - [`example_listconfig.py`](examples/example_listconfig.py) — numeric lists, nested lists, list of dicts, LR schedules
193
+ - [`example_pipeline.py`](examples/example_pipeline.py) — multi-task pipeline with DictConfig and ListConfig flowing between tasks
@@ -0,0 +1,8 @@
1
+ flyteplugins/omegaconf/__init__.py,sha256=NZcB8g3E1FUspUWh8fqdtdWGTen85qgTHm1bHfsev5o,696
2
+ flyteplugins/omegaconf/base_transformer.py,sha256=zPz2hMcH1mFlJCYD_iZS2g34GBv0LE3EtF4ZITCtA3w,2642
3
+ flyteplugins/omegaconf/codec.py,sha256=CtGCZl7s04weRXXQkKUDdE1XC51AO9CpAj3Zgq_5ksI,5183
4
+ flyteplugins_omegaconf-2.1.6.dist-info/METADATA,sha256=R6T6dyJvdGb1CJyGNBtVjYSgY-yxHgHHP3c_QJPiaJo,6838
5
+ flyteplugins_omegaconf-2.1.6.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
6
+ flyteplugins_omegaconf-2.1.6.dist-info/entry_points.txt,sha256=F992e_5r188DCSEp6CBXmVL1gMUQ6EXrw9R_7rhAB38,89
7
+ flyteplugins_omegaconf-2.1.6.dist-info/top_level.txt,sha256=cgd779rPu9EsvdtuYgUxNHHgElaQvPn74KhB5XSeMBE,13
8
+ flyteplugins_omegaconf-2.1.6.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (82.0.1)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,2 @@
1
+ [flyte.plugins.types]
2
+ omegaconf = flyteplugins.omegaconf:register_omegaconf_transformers
@@ -0,0 +1 @@
1
+ flyteplugins