flwr 1.18.0__py3-none-any.whl → 1.20.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (174) hide show
  1. flwr/app/__init__.py +15 -0
  2. flwr/app/error.py +68 -0
  3. flwr/app/metadata.py +223 -0
  4. flwr/cli/build.py +94 -59
  5. flwr/cli/log.py +3 -3
  6. flwr/cli/login/login.py +3 -7
  7. flwr/cli/ls.py +15 -36
  8. flwr/cli/new/new.py +12 -4
  9. flwr/cli/new/templates/app/README.flowertune.md.tpl +2 -0
  10. flwr/cli/new/templates/app/README.md.tpl +5 -0
  11. flwr/cli/new/templates/app/code/client.baseline.py.tpl +1 -1
  12. flwr/cli/new/templates/app/code/model.baseline.py.tpl +1 -1
  13. flwr/cli/new/templates/app/code/server.baseline.py.tpl +2 -3
  14. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +25 -17
  15. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +13 -1
  16. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +21 -2
  17. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +18 -1
  18. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +19 -2
  19. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +18 -1
  20. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +20 -3
  21. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +18 -1
  22. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +18 -1
  23. flwr/cli/run/run.py +48 -49
  24. flwr/cli/stop.py +2 -2
  25. flwr/cli/utils.py +38 -5
  26. flwr/client/__init__.py +2 -2
  27. flwr/client/client_app.py +1 -1
  28. flwr/client/clientapp/__init__.py +0 -7
  29. flwr/client/grpc_adapter_client/connection.py +15 -8
  30. flwr/client/grpc_rere_client/connection.py +142 -97
  31. flwr/client/grpc_rere_client/grpc_adapter.py +34 -6
  32. flwr/client/message_handler/message_handler.py +1 -1
  33. flwr/client/mod/comms_mods.py +36 -17
  34. flwr/client/rest_client/connection.py +176 -103
  35. flwr/clientapp/__init__.py +15 -0
  36. flwr/common/__init__.py +2 -2
  37. flwr/common/auth_plugin/__init__.py +2 -0
  38. flwr/common/auth_plugin/auth_plugin.py +29 -3
  39. flwr/common/constant.py +39 -8
  40. flwr/common/event_log_plugin/event_log_plugin.py +3 -3
  41. flwr/common/exit/exit_code.py +16 -1
  42. flwr/common/exit_handlers.py +30 -0
  43. flwr/common/grpc.py +12 -1
  44. flwr/common/heartbeat.py +165 -0
  45. flwr/common/inflatable.py +290 -0
  46. flwr/common/inflatable_protobuf_utils.py +141 -0
  47. flwr/common/inflatable_utils.py +508 -0
  48. flwr/common/message.py +110 -242
  49. flwr/common/record/__init__.py +2 -1
  50. flwr/common/record/array.py +402 -0
  51. flwr/common/record/arraychunk.py +59 -0
  52. flwr/common/record/arrayrecord.py +103 -225
  53. flwr/common/record/configrecord.py +59 -4
  54. flwr/common/record/conversion_utils.py +1 -1
  55. flwr/common/record/metricrecord.py +55 -4
  56. flwr/common/record/recorddict.py +69 -1
  57. flwr/common/recorddict_compat.py +2 -2
  58. flwr/common/retry_invoker.py +5 -1
  59. flwr/common/serde.py +59 -211
  60. flwr/common/serde_utils.py +175 -0
  61. flwr/common/typing.py +5 -3
  62. flwr/compat/__init__.py +15 -0
  63. flwr/compat/client/__init__.py +15 -0
  64. flwr/{client → compat/client}/app.py +28 -185
  65. flwr/compat/common/__init__.py +15 -0
  66. flwr/compat/server/__init__.py +15 -0
  67. flwr/compat/server/app.py +174 -0
  68. flwr/compat/simulation/__init__.py +15 -0
  69. flwr/proto/appio_pb2.py +43 -0
  70. flwr/proto/appio_pb2.pyi +151 -0
  71. flwr/proto/appio_pb2_grpc.py +4 -0
  72. flwr/proto/appio_pb2_grpc.pyi +4 -0
  73. flwr/proto/clientappio_pb2.py +12 -19
  74. flwr/proto/clientappio_pb2.pyi +23 -101
  75. flwr/proto/clientappio_pb2_grpc.py +269 -28
  76. flwr/proto/clientappio_pb2_grpc.pyi +114 -20
  77. flwr/proto/fleet_pb2.py +24 -27
  78. flwr/proto/fleet_pb2.pyi +19 -35
  79. flwr/proto/fleet_pb2_grpc.py +117 -13
  80. flwr/proto/fleet_pb2_grpc.pyi +47 -6
  81. flwr/proto/heartbeat_pb2.py +33 -0
  82. flwr/proto/heartbeat_pb2.pyi +66 -0
  83. flwr/proto/heartbeat_pb2_grpc.py +4 -0
  84. flwr/proto/heartbeat_pb2_grpc.pyi +4 -0
  85. flwr/proto/message_pb2.py +28 -11
  86. flwr/proto/message_pb2.pyi +125 -0
  87. flwr/proto/recorddict_pb2.py +16 -28
  88. flwr/proto/recorddict_pb2.pyi +46 -64
  89. flwr/proto/run_pb2.py +24 -32
  90. flwr/proto/run_pb2.pyi +4 -52
  91. flwr/proto/serverappio_pb2.py +9 -23
  92. flwr/proto/serverappio_pb2.pyi +0 -110
  93. flwr/proto/serverappio_pb2_grpc.py +177 -72
  94. flwr/proto/serverappio_pb2_grpc.pyi +75 -33
  95. flwr/proto/simulationio_pb2.py +12 -11
  96. flwr/proto/simulationio_pb2_grpc.py +35 -0
  97. flwr/proto/simulationio_pb2_grpc.pyi +14 -0
  98. flwr/server/__init__.py +1 -1
  99. flwr/server/app.py +69 -187
  100. flwr/server/compat/app_utils.py +50 -28
  101. flwr/server/fleet_event_log_interceptor.py +6 -2
  102. flwr/server/grid/grpc_grid.py +148 -41
  103. flwr/server/grid/inmemory_grid.py +5 -4
  104. flwr/server/serverapp/app.py +45 -17
  105. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +21 -3
  106. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +102 -8
  107. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -5
  108. flwr/server/superlink/fleet/message_handler/message_handler.py +130 -19
  109. flwr/server/superlink/fleet/rest_rere/rest_api.py +73 -13
  110. flwr/server/superlink/fleet/vce/vce_api.py +6 -3
  111. flwr/server/superlink/linkstate/in_memory_linkstate.py +138 -43
  112. flwr/server/superlink/linkstate/linkstate.py +53 -20
  113. flwr/server/superlink/linkstate/sqlite_linkstate.py +149 -55
  114. flwr/server/superlink/linkstate/utils.py +33 -29
  115. flwr/server/superlink/serverappio/serverappio_grpc.py +4 -1
  116. flwr/server/superlink/serverappio/serverappio_servicer.py +230 -84
  117. flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
  118. flwr/server/superlink/simulation/simulationio_servicer.py +26 -2
  119. flwr/server/superlink/utils.py +9 -2
  120. flwr/server/utils/validator.py +2 -2
  121. flwr/serverapp/__init__.py +15 -0
  122. flwr/simulation/app.py +25 -0
  123. flwr/simulation/run_simulation.py +17 -0
  124. flwr/supercore/__init__.py +15 -0
  125. flwr/{server/superlink → supercore}/ffs/__init__.py +2 -0
  126. flwr/{server/superlink → supercore}/ffs/disk_ffs.py +1 -1
  127. flwr/supercore/grpc_health/__init__.py +22 -0
  128. flwr/supercore/grpc_health/simple_health_servicer.py +38 -0
  129. flwr/supercore/license_plugin/__init__.py +22 -0
  130. flwr/supercore/license_plugin/license_plugin.py +26 -0
  131. flwr/supercore/object_store/__init__.py +24 -0
  132. flwr/supercore/object_store/in_memory_object_store.py +229 -0
  133. flwr/supercore/object_store/object_store.py +170 -0
  134. flwr/supercore/object_store/object_store_factory.py +44 -0
  135. flwr/supercore/object_store/utils.py +43 -0
  136. flwr/supercore/scheduler/__init__.py +22 -0
  137. flwr/supercore/scheduler/plugin.py +71 -0
  138. flwr/{client/nodestate/nodestate.py → supercore/utils.py} +14 -13
  139. flwr/superexec/deployment.py +7 -4
  140. flwr/superexec/exec_event_log_interceptor.py +8 -4
  141. flwr/superexec/exec_grpc.py +25 -5
  142. flwr/superexec/exec_license_interceptor.py +82 -0
  143. flwr/superexec/exec_servicer.py +135 -24
  144. flwr/superexec/exec_user_auth_interceptor.py +45 -8
  145. flwr/superexec/executor.py +5 -1
  146. flwr/superexec/simulation.py +8 -3
  147. flwr/superlink/__init__.py +15 -0
  148. flwr/{client/supernode → supernode}/__init__.py +0 -7
  149. flwr/supernode/cli/__init__.py +24 -0
  150. flwr/{client/supernode/app.py → supernode/cli/flower_supernode.py} +3 -19
  151. flwr/supernode/cli/flwr_clientapp.py +88 -0
  152. flwr/supernode/nodestate/in_memory_nodestate.py +199 -0
  153. flwr/supernode/nodestate/nodestate.py +227 -0
  154. flwr/supernode/runtime/__init__.py +15 -0
  155. flwr/{client/clientapp/app.py → supernode/runtime/run_clientapp.py} +135 -89
  156. flwr/supernode/scheduler/__init__.py +22 -0
  157. flwr/supernode/scheduler/simple_clientapp_scheduler_plugin.py +49 -0
  158. flwr/supernode/servicer/__init__.py +15 -0
  159. flwr/supernode/servicer/clientappio/__init__.py +22 -0
  160. flwr/supernode/servicer/clientappio/clientappio_servicer.py +303 -0
  161. flwr/supernode/start_client_internal.py +589 -0
  162. {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/METADATA +6 -4
  163. {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/RECORD +171 -123
  164. {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/WHEEL +1 -1
  165. {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/entry_points.txt +2 -2
  166. flwr/client/clientapp/clientappio_servicer.py +0 -244
  167. flwr/client/heartbeat.py +0 -74
  168. flwr/client/nodestate/in_memory_nodestate.py +0 -38
  169. /flwr/{client → compat/client}/grpc_client/__init__.py +0 -0
  170. /flwr/{client → compat/client}/grpc_client/connection.py +0 -0
  171. /flwr/{server/superlink → supercore}/ffs/ffs.py +0 -0
  172. /flwr/{server/superlink → supercore}/ffs/ffs_factory.py +0 -0
  173. /flwr/{client → supernode}/nodestate/__init__.py +0 -0
  174. /flwr/{client → supernode}/nodestate/nodestate_factory.py +0 -0
@@ -0,0 +1,290 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """InflatableObject base class."""
16
+
17
+
18
+ from __future__ import annotations
19
+
20
+ import hashlib
21
+ import threading
22
+ from collections.abc import Iterator
23
+ from contextlib import contextmanager
24
+ from typing import TypeVar, cast
25
+
26
+ from flwr.proto.message_pb2 import ObjectTree # pylint: disable=E0611
27
+
28
+ from .constant import HEAD_BODY_DIVIDER, HEAD_VALUE_DIVIDER
29
+
30
+
31
class UnexpectedObjectContentError(Exception):
    """Raised when deflated object content does not have the InflatableObject layout.

    The expected layout is a head (with its individual values) followed by a body.
    """

    def __init__(self, object_id: str, reason: str):
        message = f"Object with ID '{object_id}' has an unexpected structure. {reason}"
        super().__init__(message)
39
+
40
+
41
+ _ctx = threading.local()
42
+
43
+
44
+ def _is_recompute_enabled() -> bool:
45
+ """Check if recomputing object IDs is enabled."""
46
+ return getattr(_ctx, "recompute_object_id_enabled", True)
47
+
48
+
49
+ def _get_computed_object_ids() -> set[str]:
50
+ """Get the set of computed object IDs."""
51
+ return getattr(_ctx, "computed_object_ids", set())
52
+
53
+
54
+ @contextmanager
55
+ def no_object_id_recompute() -> Iterator[None]:
56
+ """Context manager to disable recomputing object IDs."""
57
+ old_value = _is_recompute_enabled()
58
+ old_set = _get_computed_object_ids()
59
+ _ctx.recompute_object_id_enabled = False
60
+ _ctx.computed_object_ids = set()
61
+ try:
62
+ yield
63
+ finally:
64
+ _ctx.recompute_object_id_enabled = old_value
65
+ _ctx.computed_object_ids = old_set
66
+
67
+
68
class InflatableObject:
    """Base class for objects that can be deflated to bytes and inflated back.

    Subclasses implement ``deflate``/``inflate`` and, when they contain other
    inflatable objects, expose them via ``children``. In this base class the
    ``object_id`` is the SHA-256 hex digest of the deflated content.
    """

    def deflate(self) -> bytes:
        """Deflate the object into bytes.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses must override this method.
        """
        raise NotImplementedError()

    @classmethod
    def inflate(
        cls, object_content: bytes, children: dict[str, InflatableObject] | None = None
    ) -> InflatableObject:
        """Inflate the object from bytes.

        Parameters
        ----------
        object_content : bytes
            The deflated object content.

        children : Optional[dict[str, InflatableObject]] (default: None)
            Dictionary of child InflatableObjects mapped to their object IDs. These
            children enable the full inflation of the parent InflatableObject.

        Returns
        -------
        InflatableObject
            The inflated object.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses must override this method.
        """
        raise NotImplementedError()

    @property
    def object_id(self) -> str:
        """Get the object ID, i.e. the SHA-256 hex digest of the deflated object.

        The ID is cached in ``self.__dict__["_object_id"]`` and recomputed only when
        the object reports itself as dirty or no ID has been cached yet. Inside a
        ``no_object_id_recompute()`` context, an ID that was already computed within
        that context is returned from the cache without checking ``is_dirty``.
        """
        # If recomputing object ID is disabled and the object ID is already computed,
        # return the cached object ID.
        if (
            not _is_recompute_enabled()
            and (obj_id := self.__dict__.get("_object_id"))
            in _get_computed_object_ids()
        ):
            return cast(str, obj_id)

        if self.is_dirty or "_object_id" not in self.__dict__:
            obj_id = get_object_id(self.deflate())
            self.__dict__["_object_id"] = obj_id

        # If recomputing object ID is disabled, add the object ID to the set of
        # computed object IDs to avoid recomputing it within the context.
        # ``obj_id`` is bound on this path: the walrus above is evaluated whenever
        # recomputing is disabled, and the branch above rebinds it if recomputed.
        if not _is_recompute_enabled():
            _get_computed_object_ids().add(obj_id)
        return cast(str, self.__dict__["_object_id"])

    @property
    def children(self) -> dict[str, InflatableObject] | None:
        """Get all child objects keyed by object ID, or None if there are no children.

        The base class has no children and always returns None.
        """
        return None

    @property
    def is_dirty(self) -> bool:
        """Check whether the object has changed since its object ID was last computed.

        The base class always returns True, so ``object_id`` recomputes the ID on
        every access (outside a ``no_object_id_recompute()`` context) unless a
        subclass overrides this property.
        """
        return True


# Type variable for InflatableObject subclasses (used by ``get_object_body``).
T = TypeVar("T", bound=InflatableObject)
135
+
136
+
137
def get_object_id(object_content: bytes) -> str:
    """Compute an object ID as the hex-encoded SHA-256 digest of ``object_content``."""
    hasher = hashlib.sha256()
    hasher.update(object_content)
    return hasher.hexdigest()
140
+
141
+
142
def get_object_body(object_content: bytes, cls: type[T]) -> bytes:
    """Extract the body of ``object_content`` after checking its declared type.

    Raises
    ------
    ValueError
        If the object type recorded in the head differs from ``cls.__qualname__``.
    """
    expected_type = cls.__qualname__
    actual_type = get_object_type_from_object_content(object_content)
    if actual_type != expected_type:
        raise ValueError(
            f"Class name ({expected_type}) and object type "
            f"({actual_type}) do not match."
        )
    return _get_object_body(object_content)
154
+
155
+
156
def add_header_to_object_body(object_body: bytes, obj: InflatableObject) -> bytes:
    """Prepend the object head to ``object_body`` and return the full content.

    The head is the UTF-8 encoding of three ``HEAD_VALUE_DIVIDER``-separated values:
    the class ``__qualname__`` of ``obj``, the comma-joined IDs of its children, and
    the body length in bytes. ``HEAD_BODY_DIVIDER`` separates head and body.
    """
    child_ids = ",".join(obj.children or {})
    head_values = (obj.__class__.__qualname__, child_ids, str(len(object_body)))
    head = HEAD_VALUE_DIVIDER.join(head_values).encode(encoding="utf-8")
    return b"".join((head, HEAD_BODY_DIVIDER, object_body))
171
+
172
+
173
def _get_object_head(object_content: bytes) -> bytes:
    """Return the object head, i.e. all bytes before the first HEAD_BODY_DIVIDER.

    Raises
    ------
    ValueError
        If ``object_content`` does not contain ``HEAD_BODY_DIVIDER``.
    """
    index = object_content.find(HEAD_BODY_DIVIDER)
    # ``find`` returns -1 when the divider is missing; slicing with -1 would
    # silently drop the last byte instead of signalling malformed content.
    if index == -1:
        raise ValueError("Object content does not contain a head/body divider.")
    return object_content[:index]
177
+
178
+
179
def _get_object_body(object_content: bytes) -> bytes:
    """Return the object body, i.e. all bytes after the first HEAD_BODY_DIVIDER.

    Raises
    ------
    ValueError
        If ``object_content`` does not contain ``HEAD_BODY_DIVIDER``.
    """
    index = object_content.find(HEAD_BODY_DIVIDER)
    # ``find`` returns -1 when the divider is missing; without this check the
    # slice below would silently return a meaningless part of the content.
    if index == -1:
        raise ValueError("Object content does not contain a head/body divider.")
    return object_content[index + len(HEAD_BODY_DIVIDER) :]
183
+
184
+
185
def is_valid_sha256_hash(object_id: str) -> bool:
    """Check if the given string is a valid SHA-256 hash.

    A valid hash consists of exactly 64 ASCII hexadecimal characters
    (``0-9``, ``a-f``, ``A-F``).

    Parameters
    ----------
    object_id : str
        The string to check.

    Returns
    -------
    bool
        ``True`` if the string is a valid SHA-256 hash, ``False`` otherwise.
    """
    if len(object_id) != 64:
        return False
    # Check characters explicitly: ``int(s, 16)`` would also accept a leading
    # sign, a ``0x`` prefix, surrounding whitespace, underscores between digits,
    # and non-ASCII Unicode digits, none of which can occur in a hex digest.
    return all(char in "0123456789abcdefABCDEF" for char in object_id)
206
+
207
+
208
def get_object_type_from_object_content(object_content: bytes) -> str:
    """Extract the object type (first head value) from ``object_content``."""
    obj_type, _, _ = get_object_head_values_from_object_content(object_content)
    return obj_type
211
+
212
+
213
def get_object_children_ids_from_object_content(object_content: bytes) -> list[str]:
    """Extract the list of child object IDs (second head value) from the content."""
    _, children_ids, _ = get_object_head_values_from_object_content(object_content)
    return children_ids
216
+
217
+
218
def get_object_body_len_from_object_content(object_content: bytes) -> int:
    """Extract the declared body length (third head value) from the content."""
    _, _, body_len = get_object_head_values_from_object_content(object_content)
    return body_len
221
+
222
+
223
def get_object_head_values_from_object_content(
    object_content: bytes,
) -> tuple[str, list[str], int]:
    """Return the object type, child object IDs, and body length from object content.

    The head is decoded as UTF-8 and must consist of exactly three values separated
    by ``HEAD_VALUE_DIVIDER``: the object type, a comma-separated list of child
    object IDs (empty when there are no children), and the body length.

    Parameters
    ----------
    object_content : bytes
        The deflated object content.

    Returns
    -------
    tuple[str, list[str], int]
        A tuple containing:
        - The object type as a string.
        - A list of child object IDs as strings.
        - The length of the object body as an integer.

    Raises
    ------
    ValueError
        If the head is not valid UTF-8, does not contain exactly three values, or
        the body length is not an integer.
    """
    head = _get_object_head(object_content).decode(encoding="utf-8")
    head_values = head.split(HEAD_VALUE_DIVIDER)
    # Fail with an explicit message instead of an opaque unpacking error.
    if len(head_values) != 3:
        raise ValueError(
            f"Object head must contain exactly 3 values, got {len(head_values)}."
        )
    obj_type, children_str, body_len = head_values
    children_ids = children_str.split(",") if children_str else []
    return obj_type, children_ids, int(body_len)
245
+
246
+
247
def get_descendant_object_ids(obj: InflatableObject) -> set[str]:
    """Return the IDs of all objects nested under ``obj``, excluding ``obj`` itself."""
    nested = get_all_nested_objects(obj)
    own_id = obj.object_id
    return {object_id for object_id in nested if object_id != own_id}
253
+
254
+
255
def get_all_nested_objects(obj: InflatableObject) -> dict[str, InflatableObject]:
    """Collect ``obj`` and every object nested under it, keyed by object ID.

    Entries are inserted in post-order: each child subtree is collected (in the
    iteration order of ``children``) before the parent itself is added.
    """
    collected: dict[str, InflatableObject] = {}
    for child in (obj.children or {}).values():
        collected |= get_all_nested_objects(child)
    collected[obj.object_id] = obj
    return collected
269
+
270
+
271
def get_object_tree(obj: InflatableObject) -> ObjectTree:
    """Build an ``ObjectTree`` mirroring the parent/child structure of ``obj``."""
    subtrees = [get_object_tree(child) for child in (obj.children or {}).values()]
    return ObjectTree(object_id=obj.object_id, children=subtrees)
278
+
279
+
280
def iterate_object_tree(tree: ObjectTree) -> Iterator[ObjectTree]:
    """Yield every node of ``tree`` in post-order.

    Each node is yielded only after all nodes of its child subtrees, so the root
    ``tree`` itself is yielded last.
    """
    for subtree in tree.children:
        for node in iterate_object_tree(subtree):
            yield node
    yield tree
@@ -0,0 +1,141 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """InflatableObject gRPC utils."""
16
+
17
+
18
+ from typing import Callable
19
+
20
+ from flwr.proto.message_pb2 import ( # pylint: disable=E0611
21
+ ConfirmMessageReceivedRequest,
22
+ ConfirmMessageReceivedResponse,
23
+ PullObjectRequest,
24
+ PullObjectResponse,
25
+ PushObjectRequest,
26
+ PushObjectResponse,
27
+ )
28
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
29
+
30
+ from .inflatable_utils import ObjectIdNotPreregisteredError, ObjectUnavailableError
31
+
32
# Callable (typically a gRPC stub method) that sends a ConfirmMessageReceivedRequest
# and returns the matching ConfirmMessageReceivedResponse.
ConfirmMessageReceivedProtobuf = Callable[
    [ConfirmMessageReceivedRequest], ConfirmMessageReceivedResponse
]


def make_pull_object_fn_protobuf(
    pull_object_protobuf: Callable[[PullObjectRequest], PullObjectResponse],
    node: Node,
    run_id: int,
) -> Callable[[str], bytes]:
    """Build a function that pulls object content through a protobuf callable.

    Parameters
    ----------
    pull_object_protobuf : Callable[[PullObjectRequest], PullObjectResponse]
        Callable that sends a `PullObjectRequest` and returns a `PullObjectResponse`,
        usually backed by a gRPC client stub.
    node : Node
        The node issuing the requests.
    run_id : int
        The run ID attached to every request.

    Returns
    -------
    Callable[[str], bytes]
        Function mapping an object ID to the object content. It raises
        `ObjectIdNotPreregisteredError` when the object ID is not pre-registered
        and `ObjectUnavailableError` when the object is not yet available.
    """

    def pull_object_fn(object_id: str) -> bytes:
        reply: PullObjectResponse = pull_object_protobuf(
            PullObjectRequest(node=node, run_id=run_id, object_id=object_id)
        )
        # "Not found" (not pre-registered) is checked before "not yet available".
        if not reply.object_found:
            raise ObjectIdNotPreregisteredError(object_id)
        if not reply.object_available:
            raise ObjectUnavailableError(object_id)
        return reply.object_content

    return pull_object_fn
72
+
73
+
74
def make_push_object_fn_protobuf(
    push_object_protobuf: Callable[[PushObjectRequest], PushObjectResponse],
    node: Node,
    run_id: int,
) -> Callable[[str, bytes], None]:
    """Build a function that pushes object content through a protobuf callable.

    Parameters
    ----------
    push_object_protobuf : Callable[[PushObjectRequest], PushObjectResponse]
        Callable that sends a `PushObjectRequest` and returns a `PushObjectResponse`,
        usually backed by a gRPC client stub.
    node : Node
        The node issuing the requests.
    run_id : int
        The run ID attached to every request.

    Returns
    -------
    Callable[[str, bytes], None]
        Function taking an object ID and its content and pushing them to the
        servicer. It raises `ObjectIdNotPreregisteredError` when the response
        reports that the object was not stored (i.e. the ID is not pre-registered).
    """

    def push_object_fn(object_id: str, object_content: bytes) -> None:
        push_request = PushObjectRequest(
            node=node,
            run_id=run_id,
            object_id=object_id,
            object_content=object_content,
        )
        push_reply: PushObjectResponse = push_object_protobuf(push_request)
        if not push_reply.stored:
            raise ObjectIdNotPreregisteredError(object_id)

    return push_object_fn
108
+
109
+
110
def make_confirm_message_received_fn_protobuf(
    confirm_message_received_protobuf: ConfirmMessageReceivedProtobuf,
    node: Node,
    run_id: int,
) -> Callable[[str], None]:
    """Build a function that confirms message receipt through a protobuf callable.

    Parameters
    ----------
    confirm_message_received_protobuf : ConfirmMessageReceivedProtobuf
        Callable that sends a `ConfirmMessageReceivedRequest` and returns a
        `ConfirmMessageReceivedResponse`, usually backed by a gRPC client stub.
    node : Node
        The node issuing the requests.
    run_id : int
        The run ID of the message being confirmed.

    Returns
    -------
    Callable[[str], None]
        Function taking the message's object ID and sending the confirmation; the
        response is not inspected.
    """

    def confirm_message_received_fn(object_id: str) -> None:
        confirm_message_received_protobuf(
            ConfirmMessageReceivedRequest(
                node=node, run_id=run_id, message_object_id=object_id
            )
        )

    return confirm_message_received_fn