openrewrite-remote 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openrewrite_remote-0.2.0.dist-info/METADATA +16 -0
- openrewrite_remote-0.2.0.dist-info/RECORD +18 -0
- openrewrite_remote-0.2.0.dist-info/WHEEL +4 -0
- rewrite/remote/__init__.py +7 -0
- rewrite/remote/client.py +40 -0
- rewrite/remote/event.py +20 -0
- rewrite/remote/java/extensions.py +176 -0
- rewrite/remote/java/receiver.py +1376 -0
- rewrite/remote/java/sender.py +659 -0
- rewrite/remote/python/extensions.py +179 -0
- rewrite/remote/python/receiver.py +1873 -0
- rewrite/remote/python/sender.py +894 -0
- rewrite/remote/receiver.py +386 -0
- rewrite/remote/remote_utils.py +280 -0
- rewrite/remote/remoting.py +342 -0
- rewrite/remote/sender.py +424 -0
- rewrite/remote/server.py +188 -0
- rewrite/remote/type_utils.py +60 -0
@@ -0,0 +1,386 @@
|
|
1
|
+
from collections import OrderedDict
|
2
|
+
from enum import Enum
|
3
|
+
from pathlib import Path
|
4
|
+
from typing import Protocol, TypeVar, Optional, Type, Dict, Callable, List, cast, Iterable, Any, Tuple, \
|
5
|
+
get_args, TYPE_CHECKING
|
6
|
+
from uuid import UUID
|
7
|
+
|
8
|
+
import cbor2
|
9
|
+
from cbor2 import CBORDecoder
|
10
|
+
from rewrite import Markers, Marker, ParseErrorVisitor
|
11
|
+
|
12
|
+
from rewrite import Tree, TreeVisitor, Cursor, FileAttributes
|
13
|
+
from rewrite.remote.event import DiffEvent, EventType
|
14
|
+
from . import remote_utils, type_utils
|
15
|
+
|
16
|
+
if TYPE_CHECKING:
|
17
|
+
from .remoting import RemotingContext
|
18
|
+
|
19
|
+
# Unbounded type variable used for generic "node" payloads.
# The name passed to TypeVar must match the variable it is assigned to.
A = TypeVar('A')
|
20
|
+
T = TypeVar('T', bound=Tree)
|
21
|
+
V = TypeVar('V')
|
22
|
+
I = TypeVar('I')
|
23
|
+
P = TypeVar('P')
|
24
|
+
|
25
|
+
|
26
|
+
class Receiver(Protocol):
    """Structural interface implemented by the tree receivers in this module."""

    def fork(self, context: 'ReceiverContext') -> 'ReceiverContext':
        """Return a ``ReceiverContext`` derived from ``context`` for this receiver."""
        ...

    def receive(self, before: Optional[T], ctx: 'ReceiverContext') -> object:
        """Receive the new state of ``before`` (or a new object when it is ``None``)."""
        ...
|
32
|
+
|
33
|
+
|
34
|
+
class OmniReceiver(Receiver):
    """Receiver that defers to ``ReceiverContext.polymorphic_receive_tree``.

    It cannot be forked; it only serves as an entry point that lets the
    context pick the concrete receiver for each incoming tree.
    """

    def fork(self, ctx: 'ReceiverContext') -> 'ReceiverContext':
        """Always raise ``NotImplementedError``; this receiver has no forked form."""
        raise NotImplementedError("Cannot fork OmniReceiver")

    def receive(self, before: Optional['Tree'], ctx: 'ReceiverContext') -> 'Tree':
        """Receive ``before`` through a fresh ``Visitor`` instance."""
        return self.Visitor().visit(before, ctx)

    class Visitor(TreeVisitor[Tree, 'ReceiverContext']):
        """Visitor whose ``visit`` hands the tree straight to the context."""

        def visit(self, tree: Optional[Tree], ctx: 'ReceiverContext',
                  parent: Optional[Cursor] = None) -> Optional[Tree]:
            """Push a cursor for ``tree``, receive it polymorphically, pop the cursor.

            ``parent`` is accepted for signature compatibility and is not used.
            """
            self.cursor = Cursor(self.cursor, tree)
            try:
                return ctx.polymorphic_receive_tree(tree)
            finally:
                # Restore the cursor even when receiving fails, so the visitor
                # is not left pointing at a half-received tree.
                self.cursor = self.cursor.parent
|
49
|
+
|
50
|
+
|
51
|
+
class TreeReceiver(Protocol):
    """Structural interface for the source of ``DiffEvent`` objects."""

    def receive_node(self) -> DiffEvent:
        """Read the next node-level event."""
        ...

    def receive_value(self, expected_type: Type) -> DiffEvent:
        """Read the next value-level event, decoding its payload as ``expected_type``."""
        ...
|
57
|
+
|
58
|
+
|
59
|
+
class ReceiverFactory(Protocol):
    """Structural interface for objects that build a tree that has no prior state."""

    def create(self, type_name: str, ctx: 'ReceiverContext') -> Tree:
        """Create the tree identified by ``type_name``, reading its parts via ``ctx``."""
        ...
|
62
|
+
|
63
|
+
|
64
|
+
class DetailsReceiver(Protocol[T]):
    """Structural interface for objects exposing a ``receive_details`` method."""

    def receive_details(self, before: Optional[T], type: Optional[Type[T]], ctx: 'ReceiverContext') -> T:
        """Receive the details of one element; ``before`` is ``None`` for new elements."""
        ...
|
67
|
+
|
68
|
+
|
69
|
+
class ReceiverContext:
    """State used while receiving one tree: event source, visitor and factory.

    ``receiver`` supplies the ``DiffEvent`` objects, ``visitor`` is applied to
    trees that already exist on this side, and ``factory`` creates trees that
    have no previous state.
    """

    # Class-level registry shared by every context: maps a type to a
    # zero-argument callable producing the Receiver for it. ``new_receiver``
    # walks it in insertion order and returns the first matching entry.
    Registry: Dict[Type, Callable[[], Receiver]] = OrderedDict()

    def __init__(self, receiver: TreeReceiver, visitor: Optional[TreeVisitor] = None,
                 factory: Optional[ReceiverFactory] = None):
        self.receiver = receiver
        self.visitor = visitor
        self.factory = factory

    def fork(self, visitor: TreeVisitor, factory: ReceiverFactory) -> 'ReceiverContext':
        """Return a new context sharing this receiver but using ``visitor``/``factory``."""
        return ReceiverContext(self.receiver, visitor, factory)

    def receive_any_tree(self, before) -> T:
        """Receive a tree of any registered type through an ``OmniReceiver``."""
        return OmniReceiver().receive(before, self)

    def receive_tree(self, before: Optional[Tree], tree_type: Optional[str], ctx: 'ReceiverContext') -> Tree:
        """Update ``before`` with the visitor, or create a new tree with the factory.

        The decision is made on ``before is not None`` rather than truthiness:
        a tree object that happens to be falsy must still be updated in place,
        otherwise the factory would read a whole new tree from the stream.
        """
        if before is not None:
            return before.accept(self.visitor, ctx)
        return self.factory.create(tree_type, ctx)

    def polymorphic_receive_tree(self, before: Optional[Tree]) -> Optional[Tree]:
        """Read one node event and dispatch to the receiver registered for its type.

        Returns ``None`` for a delete event, ``before`` unchanged for any event
        that is neither add nor update, and the received tree otherwise.
        """
        diff_event = self.receiver.receive_node()
        event_type = diff_event.event_type
        if event_type == EventType.Delete:
            return None
        if event_type not in (EventType.Add, EventType.Update):
            return before

        # Fall back to the runtime class name of ``before`` when the event
        # carries no concrete type.
        tree_receiver = self.new_receiver(diff_event.concrete_type or type(before).__name__)
        forked = tree_receiver.fork(self)
        previous = None if event_type == EventType.Add else before
        return forked.receive_tree(previous, diff_event.concrete_type, forked)

    def new_receiver(self, type_name: str) -> Receiver:
        """Instantiate the first registered receiver whose type is a base of ``type_name``.

        Raises:
            ValueError: when no registry entry matches.
        """
        type_ = type_utils.get_type(type_name)
        for entry_type, factory in self.Registry.items():
            if issubclass(type_, entry_type):
                return factory()
        raise ValueError(f"Unsupported receiver type: {type_name}")

    def receive_node(self, before: Optional[A],
                     details: Callable[[Optional[A], Optional[str], 'ReceiverContext'], A]) -> Optional[A]:
        """Read one node event and apply ``details`` for add/update events.

        Delete yields ``None``; any other event type yields ``before``.
        """
        evt = self.receiver.receive_node()
        if evt.event_type == EventType.Delete:
            return None
        if evt.event_type == EventType.Add:
            return details(None, evt.concrete_type, self)
        if evt.event_type == EventType.Update:
            return details(before, evt.concrete_type, self)
        return before

    def receive_markers(self, before: Optional[Markers], type: Optional[str], ctx: 'ReceiverContext') -> Markers:
        """Receive the id and marker list of a ``Markers`` object.

        ``type`` and ``ctx`` are part of the details-callback signature and are
        not used here.
        """
        id_ = self.receive_value(getattr(before, 'id', None), UUID)
        after_markers = self.receive_values(getattr(before, 'markers', None), Marker)
        # Identity check: an existing Markers instance is always updated in place.
        if before is not None:
            return before.with_id(id_).with_markers(after_markers)
        return Markers(id_, after_markers)

    def receive_nodes(self, before: Optional[List[A]],
                      details: Callable[[Optional[A], Optional[str], 'ReceiverContext'], A]) -> Optional[List[A]]:
        """Receive a list of nodes; delegates to ``remote_utils.receive_nodes``."""
        return remote_utils.receive_nodes(before, details, self)

    def receive_values(self, before: Optional[List[V]], type: Type) -> Optional[List[V]]:
        """Receive a list of values; delegates to ``remote_utils.receive_values``."""
        return remote_utils.receive_values(before, type, self)

    def receive_value(self, before: Optional[V], type: Type) -> Optional[V]:
        """Receive a single value of ``type``; see ``receive_value0``."""
        return self.receive_value0(before, type)

    def receive_value0(self, before: Optional[V], type: Type) -> Optional[V]:
        """Read one value event.

        Add/update yield the event payload, delete yields ``None`` and any
        other event type yields ``before``.
        """
        evt = self.receiver.receive_value(type)
        if evt.event_type in (EventType.Update, EventType.Add):
            return evt.msg
        if evt.event_type == EventType.Delete:
            return None
        return before

    @staticmethod
    def register(type_: Type, receiver_factory: Callable[[], Receiver]):
        """Add or replace the registry entry for ``type_``."""
        ReceiverContext.Registry[type_] = receiver_factory
|
150
|
+
|
151
|
+
|
152
|
+
class ValueDeserializer(Protocol):
    """Structural interface for pluggable deserializers of non-primitive values."""

    def deserialize(self, type_: Optional[Type], reader: CBORDecoder, context: 'DeserializationContext') -> Optional[Any]:
        """Decode a value of ``type_`` from ``reader``."""
        ...
|
155
|
+
|
156
|
+
|
157
|
+
class DefaultValueDeserializer(ValueDeserializer):
    """Fallback deserializer: decodes a map only to report that nothing handles it."""

    def deserialize(self, expected_type: Optional[Type], reader: CBORDecoder, context: 'DeserializationContext') -> Any:
        """Always raise ``NotImplementedError`` listing the decoded map entries."""
        # NOTE(review): cbor2 documents ``decode_map`` as taking a ``subtype``
        # argument - confirm that this zero-argument call works as intended.
        entries = reader.decode_map()
        listing = ", ".join(f"{k}: {v}" for k, v in entries.items())
        raise NotImplementedError("No deserializer found for: " + listing)
|
162
|
+
|
163
|
+
|
164
|
+
class DeserializationContext:
    """Decodes CBOR payloads into Python values for a ``JsonReceiver``.

    Holds the owning remoting context (used to resolve object references) and
    an ordered list of ``(type, deserializer)`` pairs for non-primitive values.
    """

    # Shared fallback deserializer instance.
    DefaultDeserializer = DefaultValueDeserializer()

    def __init__(self, remoting_context: 'RemotingContext', value_deserializers: Optional[List[Tuple[Type, ValueDeserializer]]] = None):
        self.remoting_context = remoting_context
        self.value_deserializers = value_deserializers or []

    def deserialize(self, expected_type: Type, decoder: CBORDecoder) -> Any:
        """Decode the next item from ``decoder`` as ``expected_type``.

        A decoded ``None`` is returned as-is. UUIDs are built from raw bytes,
        paths from their decoded value, enums from their value; ``str``,
        ``bool``, ``list``, ``int`` and ``float`` are returned unchanged.

        Raises:
            NotImplementedError: when no branch handles the value.
        """
        value = decoder.decode()
        if value is None:
            return None

        if expected_type == UUID:
            return UUID(bytes=value)

        # These types need no conversion after CBOR decoding.
        if expected_type in (str, bool, list, int, float):
            return value

        if expected_type == Path:
            return Path(value)

        if issubclass(expected_type, Enum):
            return expected_type(value)

        # NOTE(review): everything below looks like an unfinished port of a
        # pull-style CBOR reader. ``value`` was already consumed above, and
        # ``peek_state`` / ``read_*`` / ``cbor2.CborReaderState`` do not appear
        # in the cbor2 API documentation - TODO confirm before relying on it.
        state = decoder.peek_state()
        reader_state = cbor2.CborReaderState

        if state in {reader_state.BOOLEAN, reader_state.UNSIGNED_INT,
                     reader_state.NEGATIVE_INT}:
            result = decoder.read_int()
            # An integer may be a reference to an already-known remote object.
            obj = self.remoting_context.get_object(result)
            return obj if obj is not None else result

        if state in {reader_state.HALF_FLOAT, reader_state.FLOAT, reader_state.DOUBLE}:
            return decoder.read_double()

        if state == reader_state.TEXT_STRING:
            str_value = decoder.read_text_string()
            if decoder.peek_state() == reader_state.END_ARRAY:
                return str_value

            if decoder.peek_state() != reader_state.END_ARRAY:
                concrete_type = str_value
            else:
                concrete_type = expected_type.__name__

            if concrete_type == "org.openrewrite.FileAttributes":
                # The map is read and discarded; a placeholder object is returned.
                _discarded = decoder.read_cbor_map()
                return FileAttributes(None, None, None, False, False, False, 0)

            if concrete_type == "java.lang.Character":
                return decoder.decode()[0]

            if concrete_type in ("java.lang.String", "java.lang.Boolean", "java.lang.Integer",
                                 "java.lang.Long", "java.lang.Double", "java.lang.Float",
                                 "java.math.BigInteger", "java.math.BigDecimal"):
                return decoder.decode()

            raise NotImplementedError(f"No deserialization implemented for: {concrete_type}")

        if state == reader_state.ARRAY:
            decoder.read_array_start()
            concrete_type = decoder.read_text_string()
            actual_type = type_utils.get_type(concrete_type)
            for type_, deserializer in self.value_deserializers:
                if issubclass(actual_type, type_):
                    return deserializer.deserialize(actual_type, decoder, self)

        if state == reader_state.MAP:
            if issubclass(expected_type, Marker):
                if decoder.peek_state() == reader_state.UNSIGNED_INT:
                    obj_id = decoder.read_int()
                    return self.remoting_context.get_object(obj_id)

                # NOTE(review): ``ValueDeserializer`` as declared in this module
                # has no ``read_cbor_map``; ``SearchResult`` and
                # ``UnknownJavaMarker`` are not imported in the visible part of
                # the file - confirm where these are meant to come from.
                marker_map = ValueDeserializer.read_cbor_map(decoder)

                if "@c" not in marker_map:
                    raise NotImplementedError("Expected @c key")

                concrete_type = marker_map["@c"]
                marker_id = UUID(bytes=marker_map["id"])

                if concrete_type in {"org.openrewrite.marker.SearchResult", "Rewrite.Core.Marker.SearchResult"}:
                    desc = marker_map.get("description", None)
                    marker = SearchResult(marker_id, desc)
                else:
                    marker = UnknownJavaMarker(marker_id, marker_map)

                if "@ref" in marker_map:
                    self.remoting_context.add(marker_map["@ref"], marker)

                return marker

            decoder.read_map_start()
            if decoder.read_text_string() != "@c":
                raise NotImplementedError("Expected @c key")
            concrete_type = decoder.read_text_string()
            actual_type = type_utils.get_type(concrete_type)
            for type_, deserializer in self.value_deserializers:
                if issubclass(actual_type, type_):
                    return deserializer.deserialize(actual_type, decoder, self)

        raise NotImplementedError(f"No deserialization implemented for: {expected_type}")
|
284
|
+
|
285
|
+
|
286
|
+
class JsonReceiver(TreeReceiver):
    """``TreeReceiver`` that reads ``DiffEvent`` objects from a CBOR stream."""

    # When true, every received event is printed with a running counter.
    DEBUG = False

    def __init__(self, stream, context: DeserializationContext):
        super().__init__()
        self._stream = stream
        self._decoder = cbor2.CBORDecoder(self._stream)
        self._context = context
        self._count = 0

    def receive_node(self):
        """Decode one node event (a CBOR array whose first item is the event type).

        Only add events carry a concrete type, taken from the second array item
        when it is a string.

        Raises:
            NotImplementedError: if the decoded item is not a list or the event
                type is not one this receiver understands.
        """
        array = self._decoder.decode()
        if not isinstance(array, list):
            raise NotImplementedError(f"Unexpected state: {type(array)}")

        event_type = EventType(cast(int, array[0]))
        concrete_type = None

        if event_type in {EventType.Add, EventType.Update}:
            # Guard the length so a bare ``[Add]`` event does not raise IndexError.
            if event_type == EventType.Add and len(array) > 1 and isinstance(array[1], str):
                concrete_type = array[1]
        elif event_type not in {EventType.Delete, EventType.NoChange, EventType.StartList, EventType.EndList}:
            raise NotImplementedError(event_type)

        return self._trace(DiffEvent(event_type, concrete_type, None))

    def receive_value(self, expected_type: Type):
        """Decode one value event, deserializing its payload as ``expected_type``.

        Raises:
            NotImplementedError: if the event type is not one this receiver
                understands.
        """
        length = remote_utils.decode_array_start(self._decoder)
        event_type = EventType(self._decoder.decode())
        msg = None

        if event_type in {EventType.Add, EventType.Update}:
            if self._is_parameterized_iterable(expected_type):
                # special case for list events
                msg = self._decoder.decode()
            else:
                msg = self._context.deserialize(expected_type, self._decoder)
        elif event_type not in {EventType.Delete, EventType.NoChange, EventType.StartList, EventType.EndList}:
            raise NotImplementedError(event_type)

        if length is None:
            # Indefinite-length array: consume the trailing break marker.
            self._decoder.decode()

        return self._trace(DiffEvent(event_type, None, msg))

    @staticmethod
    def _is_parameterized_iterable(expected_type) -> bool:
        """Return True for parameterized generics (e.g. ``List[int]``) whose origin is iterable.

        ``issubclass`` cannot be applied to the generic alias itself (it raises
        ``TypeError``), so the check is made on the alias' origin class.
        """
        from typing import get_origin

        if not get_args(expected_type):
            return False
        origin = get_origin(expected_type)
        return isinstance(origin, type) and issubclass(origin, (list, Iterable))

    def _trace(self, event):
        """Print ``event`` when ``DEBUG`` is set, bump the counter and return the event."""
        if self.DEBUG:
            print(f"[{self._count}] {event}")
        self._count += 1
        return event
|
343
|
+
|
344
|
+
|
345
|
+
class ParseErrorReceiver(Receiver):
    """Receiver for ``ParseError`` trees."""

    def fork(self, ctx):
        """Fork ``ctx`` with this receiver's visitor and factory."""
        return ctx.fork(self.Visitor(), self.Factory())

    def receive(self, before, ctx):
        """Receive ``before`` through a forked context."""
        forked = self.fork(ctx)
        return forked.visitor.visit(before, forked)

    class Visitor(ParseErrorVisitor):
        """Updates an existing ``ParseError`` field by field."""

        def visit(self, tree, ctx, parent: Optional[Cursor] = None):
            """Push a cursor for ``tree``, receive it as a node, pop the cursor.

            ``parent`` is accepted for signature compatibility and is not used.
            """
            self.cursor = Cursor(self.cursor, tree)
            try:
                return ctx.receive_node(tree, ctx.receive_tree)
            finally:
                # Keep the cursor balanced even if receiving fails.
                self.cursor = self.cursor.parent

        def visit_parse_error(self, parse_error, ctx):
            """Receive each field of ``parse_error`` in wire order.

            ``ReceiverContext.receive_value`` requires the expected type of the
            value, so every call names it explicitly.
            """
            # NOTE(review): assumes the ``rewrite`` package exports ``Checksum``
            # next to ``FileAttributes`` - confirm.
            from rewrite import Checksum

            parse_error = parse_error.with_id(ctx.receive_value(parse_error.id, UUID))
            parse_error = parse_error.with_markers(ctx.receive_node(parse_error.markers, ctx.receive_markers))
            parse_error = parse_error.with_source_path(ctx.receive_value(parse_error.source_path, Path))
            parse_error = parse_error.with_file_attributes(
                ctx.receive_value(parse_error.file_attributes, FileAttributes))
            parse_error = parse_error.with_charset_name(ctx.receive_value(parse_error.charset_name, str))
            parse_error = parse_error.with_charset_bom_marked(
                ctx.receive_value(parse_error.charset_bom_marked, bool))
            parse_error = parse_error.with_checksum(ctx.receive_value(parse_error.checksum, Checksum))
            parse_error = parse_error.with_text(ctx.receive_value(parse_error.text, str))
            # The ``erroneous`` subtree is not received here:
            # parse_error = parse_error.with_erroneous(ctx.receive_tree(parse_error.erroneous))
            return parse_error

    class Factory(ReceiverFactory):
        """Creates a ``ParseError`` that has no previous state."""

        def create(self, type_, ctx):
            """Build a ``ParseError`` from the stream.

            Raises:
                NotImplementedError: if ``type_`` is not a known ParseError type name.
            """
            if type_ in ("rewrite.parser.ParseError", "org.openrewrite.tree.ParseError"):
                # Imported here because the module-level imports only bring in
                # ``ParseErrorVisitor``; the module path follows the type name above.
                # NOTE(review): ``Checksum`` export from ``rewrite`` assumed - confirm.
                from rewrite import Checksum
                from rewrite.parser import ParseError

                return ParseError(
                    ctx.receive_value(None, UUID),
                    ctx.receive_node(None, ctx.receive_markers),
                    ctx.receive_value(None, Path),
                    ctx.receive_value(None, FileAttributes),
                    ctx.receive_value(None, str),
                    ctx.receive_value(None, bool),
                    ctx.receive_value(None, Checksum),
                    ctx.receive_value(None, str),
                    None  # ctx.receive_tree(None)
                )
            # f-string so that a ``None`` type name still produces this error.
            raise NotImplementedError(f"No factory method for type: {type_}")
|
@@ -0,0 +1,280 @@
|
|
1
|
+
import socket
|
2
|
+
import struct
|
3
|
+
from enum import Enum, auto
|
4
|
+
from io import BytesIO
|
5
|
+
from typing import Callable, List, Optional, TypeVar, Dict, TYPE_CHECKING, Iterable, cast, Type
|
6
|
+
|
7
|
+
from cbor2 import CBORDecoder, CBORDecodeValueError
|
8
|
+
|
9
|
+
from .event import EventType
|
10
|
+
|
11
|
+
if TYPE_CHECKING:
|
12
|
+
from .receiver import ReceiverContext, DetailsReceiver
|
13
|
+
|
14
|
+
T = TypeVar('T')
|
15
|
+
I = TypeVar('I')
|
16
|
+
|
17
|
+
|
18
|
+
class Operation(Enum):
    """Kinds of change reported by ``calculate_list_diff``."""

    # Values are the ones ``auto()`` assigns: consecutive integers from 1.
    Add = 1
    Delete = 2
    NoChange = 3
    Update = 4
    Move = 5
|
24
|
+
|
25
|
+
|
26
|
+
def receive_nodes(before: Optional[List[T]],
                  details: Callable[[Optional[T], Optional[str], 'ReceiverContext'], T],
                  ctx: 'ReceiverContext') -> Optional[List[T]]:
    """Receive a list of nodes described by a leading list event.

    Args:
        before: the previous list, or ``None``.
        details: either a callable ``(before, concrete_type, ctx)`` as the
            annotation states, or an object exposing ``receive_details`` with
            that signature. Both forms are accepted.
        ctx: context whose ``receiver`` supplies the events.

    Returns:
        ``before`` for a no-change event, ``None`` for a delete event, and the
        received list for add/update events.

    Raises:
        NotImplementedError: on an unexpected list or element event.
    """
    # The original called ``details.receive_details`` although callers pass
    # plain callables; resolve whichever form was supplied.
    receive_details = getattr(details, 'receive_details', details)

    list_event = ctx.receiver.receive_value(list)
    kind = list_event.event_type

    if kind == EventType.NoChange:
        return before
    if kind == EventType.Delete:
        return None
    if kind == EventType.Update:
        return _receive_updated_nodes(before, list_event.msg, details, ctx)
    if kind != EventType.Add:
        raise NotImplementedError(f"Unexpected operation: {kind}")

    # Add: the event payload is the number of elements that follow.
    after = []
    for _ in range(list_event.msg):
        diff_event = ctx.receiver.receive_node()
        if diff_event.event_type == EventType.Add:
            after.append(receive_details(None, diff_event.concrete_type, ctx))
        elif diff_event.event_type == EventType.NoChange:
            after.append(None)  # Or some default value
        else:
            raise NotImplementedError(f"Unexpected operation: {diff_event.event_type}")
    return after
|
50
|
+
|
51
|
+
|
52
|
+
def _receive_updated_nodes(before: List[T],
                           after_size: int,
                           details: 'DetailsReceiver[T]',
                           ctx: 'ReceiverContext') -> List[T]:
    """Apply a stream of per-element events to ``before`` and return the result.

    The stream must start with a start-list event and is read until an
    end-list event. ``before`` itself is never mutated: it is copied on the
    first add/update/delete event and returned as-is when nothing changed.

    Args:
        before: the previous list.
        after_size: expected size of the result; a longer result is truncated.
        details: callable ``(before, concrete_type, ctx)`` or an object
            exposing ``receive_details`` with that signature.
        ctx: context whose ``receiver`` supplies the events.

    Raises:
        ValueError: if the first event is not a start-list event.
    """
    receive_details = getattr(details, 'receive_details', details)

    evt = ctx.receiver.receive_node()
    if evt.event_type != EventType.StartList:
        raise ValueError(f"Expected start list event: {evt.event_type}")

    modified = False
    after_list = before
    before_idx = 0

    while True:
        evt = ctx.receiver.receive_node()
        kind = evt.event_type

        # Only the end-list event terminates the loop. Breaking on NoChange as
        # well (as the previous code did) dropped everything after the first
        # unchanged element.
        if kind == EventType.EndList:
            break

        if kind == EventType.NoChange:
            if modified:
                after_list.append(before[before_idx])
            before_idx += 1
            continue

        if kind not in (EventType.Delete, EventType.Update, EventType.Add):
            continue  # any other event is ignored

        if not modified:
            # First real change: switch to a private copy of the unchanged prefix.
            after_list = _copy_range(before, before_idx)
            modified = True

        if kind == EventType.Delete:
            before_idx += 1
        elif kind == EventType.Update:
            after_list.append(receive_details(before[before_idx], evt.concrete_type, ctx))
            before_idx += 1
        else:  # Add
            after_list.append(receive_details(None, evt.concrete_type, ctx))

    return after_list[:after_size] if len(after_list) > after_size else after_list
|
89
|
+
|
90
|
+
|
91
|
+
def receive_values(before: Optional[List[T]], type: Type, ctx: 'ReceiverContext') -> Optional[List[T]]:
    """Receive a list of plain values described by a leading list event.

    Returns ``before`` for a no-change event, ``None`` for a delete event and
    the received list for add/update events. Raises ``NotImplementedError`` on
    an unexpected list or element event.
    """
    list_event = ctx.receiver.receive_value(list)
    kind = list_event.event_type

    if kind == EventType.NoChange:
        return before
    if kind == EventType.Delete:
        return None
    if kind == EventType.Update:
        return _receive_updated_values(before, list_event.msg, type, ctx)
    if kind != EventType.Add:
        raise NotImplementedError(f"Unexpected operation: {kind}")

    # Add: the event payload is the number of elements that follow.
    received = []
    for _ in range(list_event.msg):
        item_event = ctx.receiver.receive_value(type)
        if item_event.event_type == EventType.Add:
            received.append(item_event.msg)
        elif item_event.event_type == EventType.NoChange:
            received.append(None)  # Or some default value
        else:
            raise NotImplementedError(f"Unexpected operation: {item_event.event_type}")
    return received
|
113
|
+
|
114
|
+
|
115
|
+
def _receive_updated_values(before: List[T], after_size: int, type: Type, ctx: 'ReceiverContext') -> List[T]:
    """Apply a stream of per-element value events to ``before`` and return the result.

    The stream must start with a start-list node event; element events are
    then read as values of ``type`` until an end-list event. ``before`` is
    never mutated: it is copied on the first add/update/delete event and
    returned as-is when nothing changed. A result longer than ``after_size``
    is truncated.

    Raises:
        ValueError: if the first event is not a start-list event.
    """
    evt = ctx.receiver.receive_node()
    if evt.event_type != EventType.StartList:
        raise ValueError(f"Expected start list event: {evt.event_type}")

    modified = False
    after_list = before
    before_idx = 0

    while True:
        evt = ctx.receiver.receive_value(type)
        kind = evt.event_type

        # Only the end-list event terminates the loop. Breaking on NoChange as
        # well (as the previous code did) dropped everything after the first
        # unchanged element.
        if kind == EventType.EndList:
            break

        if kind == EventType.NoChange:
            if modified:
                after_list.append(before[before_idx])
            before_idx += 1
            continue

        if kind not in (EventType.Delete, EventType.Update, EventType.Add):
            continue  # any other event is ignored

        if not modified:
            # First real change: switch to a private copy of the unchanged prefix.
            after_list = _copy_range(before, before_idx)
            modified = True

        if kind != EventType.Delete:
            after_list.append(evt.msg)
        if kind != EventType.Add:
            # Delete and Update both consume one element of ``before``.
            before_idx += 1

    return after_list[:after_size] if len(after_list) > after_size else after_list
|
148
|
+
|
149
|
+
|
150
|
+
def _copy_range(before: Iterable[T], j: int) -> List[T]:
    """Return the first ``j`` elements of ``before`` as a separate sequence.

    Lists are sliced directly; objects offering ``getrange`` are asked for
    ``getrange(0, j)``; anything else is materialized into a list first.
    """
    if isinstance(before, list):
        return before[:j]
    if hasattr(before, 'getrange'):
        # e.g. an immutable list type providing its own range accessor
        return before.getrange(0, j)
    materialized = list(before)
    return materialized[:j]
|
157
|
+
|
158
|
+
|
159
|
+
def calculate_list_diff(
        before: List[T],
        after: List[T],
        id_function: Callable[[T], I],
        consumer: Callable[[Operation, int, int, Optional[T], Optional[T]], None]
) -> None:
    """Walk ``before`` and ``after`` in parallel and report each difference.

    ``consumer`` is called as ``(operation, before_index, after_index,
    before_item, after_item)``; the index is ``-1`` and the item ``None`` on
    the side that does not take part in an add or delete. Equal items are
    reported as no-change, items with equal ids as update.
    """
    before_size, after_size = len(before), len(after)
    before_idx = after_idx = 0
    after_map = None  # id -> index in ``after``; built lazily on first mismatch

    while before_idx < before_size or after_idx < after_size:
        # One side exhausted: the remainder of the other is all adds or deletes.
        if before_idx >= before_size:
            consumer(Operation.Add, -1, after_idx, None, after[after_idx])
            after_idx += 1
            continue
        if after_idx >= after_size:
            consumer(Operation.Delete, before_idx, -1, before[before_idx], None)
            before_idx += 1
            continue

        old_item, new_item = before[before_idx], after[after_idx]

        if old_item == new_item:
            consumer(Operation.NoChange, before_idx, after_idx, old_item, new_item)
            before_idx += 1
            after_idx += 1
            continue

        before_id = id_function(old_item)
        after_id = id_function(new_item)

        if before_id == after_id:
            consumer(Operation.Update, before_idx, after_idx, old_item, new_item)
            before_idx += 1
            after_idx += 1
            continue

        if after_map is None:
            after_map = create_index_map(after, after_idx, id_function)

        # Ids differ: an old item whose id is absent from the map is a delete,
        # otherwise the new item is reported as an add.
        if before_id in after_map:
            consumer(Operation.Add, -1, after_idx, None, new_item)
            after_idx += 1
        else:
            consumer(Operation.Delete, before_idx, -1, old_item, None)
            before_idx += 1
|
215
|
+
|
216
|
+
|
217
|
+
def create_index_map(
        lst: List[T],
        from_index: int,
        id_function: Callable[[T], I]
) -> Dict[I, int]:
    """Map the id of every element from ``from_index`` onward to its index.

    When several elements share an id, the highest index wins.
    """
    return {id_function(lst[pos]): pos for pos in range(from_index, len(lst))}
|
226
|
+
|
227
|
+
|
228
|
+
def _decode_length(decoder: CBORDecoder, subtype: int, allow_indefinite: bool = False):
    """Decode a length from ``subtype``, reading follow-up bytes when needed.

    Subtypes below 24 are the length itself; 24, 25, 26 and 27 mean the length
    follows as a 1-, 2-, 4- or 8-byte big-endian unsigned integer. Subtype 31
    yields ``None`` when ``allow_indefinite`` is set.

    Raises:
        CBORDecodeValueError: for any other subtype.
    """
    if subtype < 24:
        return subtype
    if subtype == 24:
        return decoder.read(1)[0]

    # struct format and byte width for the multi-byte encodings
    wide_encodings = {25: (">H", 2), 26: (">L", 4), 27: (">Q", 8)}
    if subtype in wide_encodings:
        fmt, width = wide_encodings[subtype]
        (length,) = struct.unpack(fmt, decoder.read(width))
        return cast(int, length)

    if subtype == 31 and allow_indefinite:
        return None
    raise CBORDecodeValueError(f"unknown unsigned integer subtype 0x{subtype:x}")
|
243
|
+
|
244
|
+
def decode_array_start(decoder: CBORDecoder):
    """Read a CBOR array header and return its length.

    Returns:
        The element count, or ``None`` for an indefinite-length array.

    Raises:
        CBORDecodeValueError: if the next item is not an array (major type 4).
    """
    initial_byte = decoder.read(1)[0]
    major_type = initial_byte >> 5
    # Validate with an explicit exception: an ``assert`` is stripped under
    # ``python -O`` and would let a non-array item be misread as a length.
    if major_type != 4:
        raise CBORDecodeValueError(f"expected array (major type 4), got major type {major_type}")
    subtype = initial_byte & 31
    return _decode_length(decoder, subtype, allow_indefinite=True)
|
250
|
+
|
251
|
+
|
252
|
+
COPY_BUFFER_SIZE = 4096
# Module-level receive buffer reused by every call to read_to_command_end.
COPY_BUFFER = bytearray(COPY_BUFFER_SIZE)
# Two-byte sequence that terminates a command on the wire.
COMMAND_END = bytes([0x81, 0x17])


def read_to_command_end(sock: socket.socket) -> BytesIO:
    """Read from ``sock`` until the data received so far ends with ``COMMAND_END``.

    Reading also stops when the peer closes the connection or a socket error
    occurs; the error is printed and whatever was received is returned.

    Returns:
        A ``BytesIO`` holding all received bytes, positioned at offset 0.
    """
    memory_stream = BytesIO()
    view = memoryview(COPY_BUFFER)
    # Last (up to) two bytes received so far. Tracking them separately detects
    # a terminator split across reads without seeking in the output stream;
    # the previous seek/read check could move before the start of the stream
    # or leave the write position one byte back and overwrite received data.
    tail = b""

    try:
        while True:
            bytes_read = sock.recv_into(COPY_BUFFER, len(COPY_BUFFER))
            if bytes_read == 0:
                break
            memory_stream.write(view[:bytes_read])
            tail = (tail + bytes(view[max(0, bytes_read - 2):bytes_read]))[-2:]
            if tail == COMMAND_END:
                break
    except socket.error as e:
        print(f"Socket error: {e}")

    memory_stream.seek(0)
    return memory_stream
|