openrewrite-remote 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,386 @@
1
+ from collections import OrderedDict
2
+ from enum import Enum
3
+ from pathlib import Path
4
+ from typing import Protocol, TypeVar, Optional, Type, Dict, Callable, List, cast, Iterable, Any, Tuple, \
5
+ get_args, TYPE_CHECKING
6
+ from uuid import UUID
7
+
8
+ import cbor2
9
+ from cbor2 import CBORDecoder
10
+ from rewrite import Markers, Marker, ParseErrorVisitor
11
+
12
+ from rewrite import Tree, TreeVisitor, Cursor, FileAttributes
13
+ from rewrite.remote.event import DiffEvent, EventType
14
+ from . import remote_utils, type_utils
15
+
16
+ if TYPE_CHECKING:
17
+ from .remoting import RemotingContext
18
+
19
# Unconstrained type variable used by receive_node/receive_nodes. The string
# passed to TypeVar must match the variable it is assigned to; it was 'T',
# which clashed with the Tree-bound ``T`` declared right below.
A = TypeVar('A')
20
+ T = TypeVar('T', bound=Tree)
21
+ V = TypeVar('V')
22
+ I = TypeVar('I')
23
+ P = TypeVar('P')
24
+
25
+
26
class Receiver(Protocol):
    """Structural interface for objects that receive a tree through a ``ReceiverContext``."""

    def fork(self, context: 'ReceiverContext') -> 'ReceiverContext':
        """Return a ``ReceiverContext`` derived from *context* for this receiver."""

    def receive(self, before: Optional[T], ctx: 'ReceiverContext') -> object:
        """Receive the new state for *before* (which may be ``None``) using *ctx*."""
32
+
33
+
34
class OmniReceiver(Receiver):
    """Receiver that hands every tree to ``ReceiverContext.polymorphic_receive_tree``."""

    def fork(self, ctx: 'ReceiverContext') -> 'ReceiverContext':
        """Always raises: this receiver cannot be forked."""
        raise NotImplementedError("Cannot fork OmniReceiver")

    def receive(self, before: Optional['Tree'], ctx: 'ReceiverContext') -> 'Tree':
        """Visit *before* with a freshly created ``Visitor`` and return the result."""
        return self.Visitor().visit(before, ctx)

    class Visitor(TreeVisitor[Tree, 'ReceiverContext']):
        """Visitor whose ``visit`` delegates to the context instead of traversing."""

        def visit(self, tree: Optional[Tree], ctx: 'ReceiverContext',
                  parent: Optional[Cursor] = None) -> Optional[Tree]:
            """Push a cursor for *tree*, receive it polymorphically, then pop the cursor.

            ``parent`` is accepted but not used by this implementation.
            """
            self.cursor = Cursor(self.cursor, tree)
            received = ctx.polymorphic_receive_tree(tree)
            self.cursor = self.cursor.parent
            return received
49
+
50
+
51
class TreeReceiver(Protocol):
    """Structural interface for the source of ``DiffEvent`` objects."""

    def receive_node(self) -> DiffEvent:
        """Return the next event describing a node."""

    def receive_value(self, expected_type: Type) -> DiffEvent:
        """Return the next event describing a value expected to be of *expected_type*."""
57
+
58
+
59
class ReceiverFactory(Protocol):
    """Structural interface for creating a new tree from its type name."""

    def create(self, type_name: str, ctx: 'ReceiverContext') -> Tree:
        """Create the tree identified by *type_name*, reading its state through *ctx*."""
62
+
63
+
64
class DetailsReceiver(Protocol[T]):
    """Structural interface for receiving the details of a single element."""

    def receive_details(self, before: Optional[T], type: Optional[Type[T]], ctx: 'ReceiverContext') -> T:
        """Return the received element; *before* is the previous state, if any."""
67
+
68
+
69
class ReceiverContext:
    """State for receiving trees: an event source plus the visitor/factory pair in use.

    ``receiver`` supplies ``DiffEvent`` objects, ``visitor`` is applied when an
    existing tree is updated and ``factory`` is used when a tree is added.
    ``Registry`` is class-level and therefore shared by all instances.
    """

    # Tree base type -> zero-argument factory for the Receiver handling it.
    # Ordered, so ``new_receiver`` returns the first registered match.
    Registry: Dict[Type, Callable[[], Receiver]] = OrderedDict()

    def __init__(self, receiver: TreeReceiver, visitor: Optional[TreeVisitor] = None,
                 factory: Optional[ReceiverFactory] = None):
        self.receiver = receiver
        self.visitor = visitor
        self.factory = factory

    def fork(self, visitor: TreeVisitor, factory: ReceiverFactory) -> 'ReceiverContext':
        """Return a new context sharing this receiver but using *visitor* and *factory*."""
        return ReceiverContext(self.receiver, visitor, factory)

    def receive_any_tree(self, before) -> T:
        """Receive a tree of any registered type via a fresh ``OmniReceiver``."""
        return OmniReceiver().receive(before, self)

    def receive_tree(self, before: Optional[Tree], tree_type: Optional[str], ctx: 'ReceiverContext') -> Tree:
        """Update *before* with ``self.visitor``, or create a new tree when it is ``None``."""
        # Explicit None check: a tree that happens to be falsy must still be
        # updated in place rather than re-created through the factory.
        if before is not None:
            return before.accept(self.visitor, ctx)
        return self.factory.create(tree_type, ctx)

    def polymorphic_receive_tree(self, before: Optional[Tree]) -> Optional[Tree]:
        """Read one node event and dispatch it to the receiver registered for its type.

        Add/Update fork a type-specific context and receive through it, Delete
        yields ``None`` and any other event returns *before* unchanged.
        """
        diff_event = self.receiver.receive_node()
        if diff_event.event_type in (EventType.Add, EventType.Update):
            # Fall back to the class name of the existing tree when the event
            # carries no concrete type.
            tree_receiver = self.new_receiver(diff_event.concrete_type or type(before).__name__)
            forked = tree_receiver.fork(self)
            return forked.receive_tree(
                None if diff_event.event_type == EventType.Add else before,
                diff_event.concrete_type, forked
            )
        elif diff_event.event_type == EventType.Delete:
            return None
        else:
            return before

    def new_receiver(self, type_name: str) -> Receiver:
        """Instantiate the first registered receiver whose key type is a base of *type_name*.

        Raises:
            ValueError: if no registered type matches.
        """
        type_ = type_utils.get_type(type_name)
        for entry_type, factory in self.Registry.items():
            if issubclass(type_, entry_type):
                return factory()
        raise ValueError(f"Unsupported receiver type: {type_name}")

    def receive_node(self, before: Optional[A],
                     details: Callable[[Optional[A], Optional[str], 'ReceiverContext'], A]) -> Optional[A]:
        """Read one node event and apply *details* for Add (no previous state) or Update."""
        evt = self.receiver.receive_node()
        if evt.event_type == EventType.Delete:
            return None
        elif evt.event_type == EventType.Add:
            return details(None, evt.concrete_type, self)
        elif evt.event_type == EventType.Update:
            return details(before, evt.concrete_type, self)
        return before

    def receive_markers(self, before: Optional[Markers], type: Optional[str], ctx: 'ReceiverContext') -> Markers:
        """Receive the id and marker list of a ``Markers`` object.

        ``type`` and ``ctx`` are accepted so the method matches the ``details``
        callback signature; they are not used here.
        """
        id_ = self.receive_value(getattr(before, 'id', None), UUID)
        after_markers = self.receive_values(getattr(before, 'markers', None), Marker)
        # Explicit None check so an existing but falsy Markers is still updated.
        if before is not None:
            return before.with_id(id_).with_markers(after_markers)
        return Markers(id_, after_markers)

    def receive_nodes(self, before: Optional[List[A]],
                      details: Callable[[Optional[A], Optional[str], 'ReceiverContext'], A]) -> Optional[List[A]]:
        """Receive a list of nodes; delegates to ``remote_utils.receive_nodes``."""
        return remote_utils.receive_nodes(before, details, self)

    def receive_values(self, before: Optional[List[V]], type: Type) -> Optional[List[V]]:
        """Receive a list of values; delegates to ``remote_utils.receive_values``."""
        return remote_utils.receive_values(before, type, self)

    def receive_value(self, before: Optional[V], type: Type) -> Optional[V]:
        """Receive a single value expected to be of *type*."""
        return self.receive_value0(before, type)

    def receive_value0(self, before: Optional[V], type: Type) -> Optional[V]:
        """Read one value event: Add/Update return its payload, Delete ``None``, otherwise *before*."""
        evt = self.receiver.receive_value(type)
        if evt.event_type in (EventType.Update, EventType.Add):
            return evt.msg
        elif evt.event_type == EventType.Delete:
            return None
        return before

    @staticmethod
    def register(type_: Type, receiver_factory: Callable[[], Receiver]):
        """Register *receiver_factory* for *type_* in the shared ``Registry``."""
        ReceiverContext.Registry[type_] = receiver_factory
150
+
151
+
152
class ValueDeserializer(Protocol):
    """Structural interface for deserializing a value from a CBOR decoder."""

    def deserialize(self, type_: Optional[Type], reader: CBORDecoder,
                    context: 'DeserializationContext') -> Optional[Any]:
        """Read a value of *type_* from *reader*."""
155
+
156
+
157
class DefaultValueDeserializer(ValueDeserializer):
    """Fallback deserializer: reads a map and always raises, listing what it read."""

    def deserialize(self, expected_type: Optional[Type], reader: CBORDecoder,
                    context: 'DeserializationContext') -> Any:
        """Decode a map from *reader* and raise ``NotImplementedError`` describing its entries."""
        # NOTE(review): decode_map is called without arguments here — confirm
        # that matches the cbor2 version in use.
        cbor_map = reader.decode_map()
        entries = [f"{k}: {v}" for k, v in cbor_map.items()]
        raise NotImplementedError("No deserializer found for: " + ", ".join(entries))
162
+
163
+
164
class DeserializationContext:
    """Deserializes values from a CBOR decoder using an ordered list of type-specific deserializers."""

    # Shared fallback instance; not referenced by the methods of this class.
    DefaultDeserializer = DefaultValueDeserializer()

    def __init__(self, remoting_context: 'RemotingContext',
                 value_deserializers: Optional[List[Tuple[Type, ValueDeserializer]]] = None):
        self.remoting_context = remoting_context
        self.value_deserializers = value_deserializers or []

    def deserialize(self, expected_type: Type, decoder: CBORDecoder) -> Any:
        """Decode one value from *decoder* and convert it according to *expected_type*.

        ``None`` is returned as is. UUID, Path and Enum subclasses are built
        from the decoded value; str/bool/list/int/float are returned unchanged.
        Anything else goes through the reader-state branches below.

        Raises:
            NotImplementedError: when no branch can handle the value.
        """
        value = decoder.decode()
        if value is None:
            return None

        if expected_type == UUID:
            return UUID(bytes=value)
        if expected_type == Path:
            return Path(value)
        if expected_type in (str, bool, list, int, float):
            # Already in its final form after decoding.
            return value
        if issubclass(expected_type, Enum):
            return expected_type(value)

        # NOTE(review): peek_state/read_*/CborReaderState below do not look like
        # part of cbor2's documented CBORDecoder API — confirm they exist.
        state = decoder.peek_state()
        reader_state = cbor2.CborReaderState

        if state in {reader_state.BOOLEAN, reader_state.UNSIGNED_INT, reader_state.NEGATIVE_INT}:
            result = decoder.read_int()
            # The integer may be a reference to an object the remoting context
            # already knows; otherwise it is returned as a plain number.
            obj = self.remoting_context.get_object(result)
            return obj if obj is not None else result

        if state in {reader_state.HALF_FLOAT, reader_state.FLOAT, reader_state.DOUBLE}:
            return decoder.read_double()

        if state == reader_state.TEXT_STRING:
            str_value = decoder.read_text_string()
            if decoder.peek_state() == reader_state.END_ARRAY:
                return str_value

            if decoder.peek_state() != reader_state.END_ARRAY:
                concrete_type = str_value
            else:
                concrete_type = expected_type.__name__

            if concrete_type == "org.openrewrite.FileAttributes":
                # The map is consumed but its content is not used.
                _attributes = decoder.read_cbor_map()
                return FileAttributes(None, None, None, False, False, False, 0)

            if concrete_type == "java.lang.Character":
                return decoder.decode()[0]
            if concrete_type in (
                    "java.lang.String", "java.lang.Boolean", "java.lang.Integer",
                    "java.lang.Long", "java.lang.Double", "java.lang.Float",
                    "java.math.BigInteger", "java.math.BigDecimal"):
                return decoder.decode()

            raise NotImplementedError(f"No deserialization implemented for: {concrete_type}")

        if state == reader_state.ARRAY:
            decoder.read_array_start()
            concrete_type = decoder.read_text_string()
            actual_type = type_utils.get_type(concrete_type)
            for type_, deserializer in self.value_deserializers:
                if issubclass(actual_type, type_):
                    return deserializer.deserialize(actual_type, decoder, self)

        if state == reader_state.MAP:
            if issubclass(expected_type, Marker):
                if decoder.peek_state() == reader_state.UNSIGNED_INT:
                    obj_id = decoder.read_int()
                    return self.remoting_context.get_object(obj_id)

                # NOTE(review): ValueDeserializer declares no read_cbor_map, and
                # SearchResult/UnknownJavaMarker are not imported in the visible
                # import block — confirm where these come from.
                marker_map = ValueDeserializer.read_cbor_map(decoder)

                if "@c" not in marker_map:
                    raise NotImplementedError("Expected @c key")

                concrete_type = marker_map["@c"]
                marker_id = UUID(bytes=marker_map["id"])

                if concrete_type in {"org.openrewrite.marker.SearchResult", "Rewrite.Core.Marker.SearchResult"}:
                    desc = marker_map.get("description", None)
                    marker = SearchResult(marker_id, desc)
                else:
                    marker = UnknownJavaMarker(marker_id, marker_map)

                if "@ref" in marker_map:
                    self.remoting_context.add(marker_map["@ref"], marker)

                return marker

            decoder.read_map_start()
            if decoder.read_text_string() != "@c":
                raise NotImplementedError("Expected @c key")
            concrete_type = decoder.read_text_string()
            actual_type = type_utils.get_type(concrete_type)
            for type_, deserializer in self.value_deserializers:
                if issubclass(actual_type, type_):
                    return deserializer.deserialize(actual_type, decoder, self)

        raise NotImplementedError(f"No deserialization implemented for: {expected_type}")
284
+
285
+
286
class JsonReceiver(TreeReceiver):
    """``TreeReceiver`` that reads diff events from a stream through ``cbor2.CBORDecoder``."""

    # When true, every event is printed together with a running counter.
    DEBUG = False

    # Event types that carry neither a concrete type nor a payload.
    _BARE_EVENTS = frozenset({EventType.Delete, EventType.NoChange, EventType.StartList, EventType.EndList})

    def __init__(self, stream, context: DeserializationContext):
        super().__init__()
        self._stream = stream
        self._decoder = cbor2.CBORDecoder(self._stream)
        self._context = context
        self._count = 0

    def receive_node(self):
        """Decode the next node event.

        The decoded item must be a list whose first element is the event type.
        For Add, a string second element is taken as the concrete type.

        Raises:
            NotImplementedError: for a non-list item or an unsupported event type.
        """
        array = self._decoder.decode()
        if not isinstance(array, list):
            raise NotImplementedError(f"Unexpected state: {type(array)}")

        event_type = EventType(cast(int, array[0]))
        concrete_type = None

        if event_type == EventType.Add:
            # Guard the index: an Add event without a type element used to
            # fail with IndexError.
            if len(array) > 1 and isinstance(array[1], str):
                concrete_type = array[1]
        elif event_type != EventType.Update and event_type not in self._BARE_EVENTS:
            raise NotImplementedError(event_type)

        return self._emit(DiffEvent(event_type, concrete_type, None))

    def receive_value(self, expected_type: Type):
        """Decode the next value event, deserializing the payload for Add/Update.

        Raises:
            NotImplementedError: for an unsupported event type.
        """
        from typing import get_origin  # local import: not in the file's import block

        length = remote_utils.decode_array_start(self._decoder)
        event_type = EventType(self._decoder.decode())
        msg = None

        if event_type in {EventType.Add, EventType.Update}:
            # issubclass() rejects parameterized generics such as List[str]
            # with TypeError, so test the generic's origin class instead.
            origin = get_origin(expected_type)
            is_generic_collection = (
                bool(get_args(expected_type))
                and isinstance(origin, type)
                and issubclass(origin, (list, Iterable))
            )
            if is_generic_collection:
                # special case for list events
                msg = self._decoder.decode()
            else:
                msg = self._context.deserialize(expected_type, self._decoder)
        elif event_type not in self._BARE_EVENTS:
            raise NotImplementedError(event_type)

        if length is None:
            # Indefinite-length array: consume the trailing item (expected to
            # be the break marker) so the next event starts cleanly.
            self._decoder.decode()

        return self._emit(DiffEvent(event_type, None, msg))

    def _emit(self, event):
        """Print and count *event* when ``DEBUG`` is set, then return it."""
        if self.DEBUG:
            print(f"[{self._count}] {event}")
            self._count += 1
        return event
343
+
344
+
345
class ParseErrorReceiver(Receiver):
    """Receiver for ``ParseError`` trees."""

    def fork(self, ctx):
        """Return a context forked from *ctx* that uses this receiver's visitor and factory."""
        return ctx.fork(self.Visitor(), self.Factory())

    def receive(self, before, ctx):
        """Fork *ctx* and visit *before* with the forked visitor."""
        forked = self.fork(ctx)
        return forked.visitor.visit(before, forked)

    class Visitor(ParseErrorVisitor):
        """Updates an existing ``ParseError`` field by field."""

        def visit(self, tree, ctx, parent: Optional[Cursor] = None):
            """Push a cursor for *tree*, receive it as a node, then pop the cursor."""
            self.cursor = Cursor(self.cursor, tree)
            tree = ctx.receive_node(tree, ctx.receive_tree)
            self.cursor = self.cursor.parent
            return tree

        def visit_parse_error(self, parse_error, ctx):
            """Receive every field of *parse_error* in declaration order.

            ``ReceiverContext.receive_value`` requires the expected type as its
            second argument; it was missing on every call, which raised
            ``TypeError``.
            """
            # NOTE(review): Checksum is assumed to be exported by ``rewrite``
            # next to FileAttributes — confirm.
            from rewrite import Checksum

            parse_error = parse_error.with_id(ctx.receive_value(parse_error.id, UUID))
            parse_error = parse_error.with_markers(ctx.receive_node(parse_error.markers, ctx.receive_markers))
            parse_error = parse_error.with_source_path(ctx.receive_value(parse_error.source_path, Path))
            parse_error = parse_error.with_file_attributes(
                ctx.receive_value(parse_error.file_attributes, FileAttributes))
            parse_error = parse_error.with_charset_name(ctx.receive_value(parse_error.charset_name, str))
            parse_error = parse_error.with_charset_bom_marked(
                ctx.receive_value(parse_error.charset_bom_marked, bool))
            parse_error = parse_error.with_checksum(ctx.receive_value(parse_error.checksum, Checksum))
            parse_error = parse_error.with_text(ctx.receive_value(parse_error.text, str))
            # parse_error = parse_error.with_erroneous(ctx.receive_tree(parse_error.erroneous))
            return parse_error

    class Factory(ReceiverFactory):
        """Creates a new ``ParseError`` from received values."""

        def create(self, type_, ctx):
            """Build a ``ParseError`` when *type_* names one.

            Raises:
                NotImplementedError: for any other type name.
            """
            if type_ in ["rewrite.parser.ParseError", "org.openrewrite.tree.ParseError"]:
                # ParseError is not in the file's import block; the module path
                # follows the type name matched above.
                # NOTE(review): Checksum is assumed to be exported by ``rewrite`` — confirm.
                from rewrite import Checksum
                from rewrite.parser import ParseError

                return ParseError(
                    ctx.receive_value(None, UUID),
                    ctx.receive_node(None, ctx.receive_markers),
                    ctx.receive_value(None, Path),
                    ctx.receive_value(None, FileAttributes),
                    ctx.receive_value(None, str),
                    ctx.receive_value(None, bool),
                    ctx.receive_value(None, Checksum),
                    ctx.receive_value(None, str),
                    None  # ctx.receive_tree(None)
                )
            # f-string so a missing (None) type name cannot fail the concatenation.
            raise NotImplementedError(f"No factory method for type: {type_}")
@@ -0,0 +1,280 @@
1
+ import socket
2
+ import struct
3
+ from enum import Enum, auto
4
+ from io import BytesIO
5
+ from typing import Callable, List, Optional, TypeVar, Dict, TYPE_CHECKING, Iterable, cast, Type
6
+
7
+ from cbor2 import CBORDecoder, CBORDecodeValueError
8
+
9
+ from .event import EventType
10
+
11
+ if TYPE_CHECKING:
12
+ from .receiver import ReceiverContext, DetailsReceiver
13
+
14
+ T = TypeVar('T')
15
+ I = TypeVar('I')
16
+
17
+
18
class Operation(Enum):
    """Kinds of step that ``calculate_list_diff`` reports to its consumer."""

    # Explicit values equal to what auto() assigns (1, 2, 3, ... in order).
    Add = 1
    Delete = 2
    NoChange = 3
    Update = 4
    Move = 5
24
+
25
+
26
def receive_nodes(before: Optional[List[T]],
                  details: Callable[[Optional[T], Optional[str], 'ReceiverContext'], T],
                  ctx: 'ReceiverContext') -> Optional[List[T]]:
    """Receive a list of nodes, driven by one list-level value event.

    NoChange returns *before*, Delete returns ``None``, Add builds a new list
    of the announced size and Update applies per-element events to *before*.

    Args:
        before: previous list, or ``None``.
        details: plain callable ``(before, concrete_type, ctx)`` or an object
            exposing ``receive_details`` with that signature.
        ctx: context whose ``receiver`` supplies the events.

    Raises:
        NotImplementedError: for an unexpected list or element event.
    """
    # The annotation (and ReceiverContext.receive_node) treat ``details`` as a
    # plain callable, while this function used to require ``.receive_details``.
    # Accept both forms.
    receive_details = getattr(details, 'receive_details', details)

    list_event = ctx.receiver.receive_value(list)
    if list_event.event_type == EventType.NoChange:
        return before
    elif list_event.event_type == EventType.Delete:
        return None
    elif list_event.event_type == EventType.Add:
        after_size = list_event.msg
        after = [None] * after_size  # filled by index below
        for i in range(after_size):
            diff_event = ctx.receiver.receive_node()
            if diff_event.event_type == EventType.Add:
                after[i] = receive_details(None, diff_event.concrete_type, ctx)
            elif diff_event.event_type != EventType.NoChange:
                raise NotImplementedError(f"Unexpected operation: {diff_event.event_type}")
            # NoChange inside a freshly added list leaves the slot as None.
        return after
    elif list_event.event_type == EventType.Update:
        return _receive_updated_nodes(before, list_event.msg, details, ctx)
    else:
        raise NotImplementedError(f"Unexpected operation: {list_event.event_type}")
50
+
51
+
52
def _receive_updated_nodes(before: List[T],
                           after_size: int,
                           details: 'DetailsReceiver[T]',
                           ctx: 'ReceiverContext') -> List[T]:
    """Apply a StartList ... EndList sequence of node events to *before*.

    *before* is never mutated: it is copied at the first Delete/Update/Add,
    and returned as is when every element is NoChange.

    Args:
        before: previous list.
        after_size: announced size of the result; a longer result is truncated.
        details: plain callable ``(before, concrete_type, ctx)`` or an object
            exposing ``receive_details`` with that signature.
        ctx: context whose ``receiver`` supplies the events.

    Raises:
        ValueError: if the first event is not StartList.
    """
    # Accept both a plain callable and an object with ``receive_details``.
    receive_details = getattr(details, 'receive_details', details)

    evt = ctx.receiver.receive_node()
    if evt.event_type != EventType.StartList:
        raise ValueError(f"Expected start list event: {evt.event_type}")

    modified = False
    after_list = before
    before_idx = 0
    while True:
        evt = ctx.receiver.receive_node()
        kind = evt.event_type

        # Only EndList ends the list. The loop used to stop at the first
        # NoChange too, which dropped every later element event and made the
        # NoChange branch below unreachable.
        if kind == EventType.EndList:
            break

        if kind in (EventType.Delete, EventType.Update, EventType.Add) and not modified:
            # First real change: continue on a copy of the unchanged prefix.
            after_list = _copy_range(before, before_idx)
            modified = True

        if kind == EventType.NoChange:
            if modified:
                after_list.append(before[before_idx])
            before_idx += 1
        elif kind == EventType.Delete:
            before_idx += 1
        elif kind == EventType.Update:
            after_list.append(receive_details(before[before_idx], evt.concrete_type, ctx))
            before_idx += 1
        elif kind == EventType.Add:
            after_list.append(receive_details(None, evt.concrete_type, ctx))

    return after_list[:after_size] if len(after_list) > after_size else after_list
89
+
90
+
91
def receive_values(before: Optional[List[T]], type: Type, ctx: 'ReceiverContext') -> Optional[List[T]]:
    """Receive a list of plain values, driven by one list-level value event.

    NoChange returns *before*, Delete returns ``None``, Update delegates to
    ``_receive_updated_values`` and Add reads ``msg`` element events.

    Raises:
        NotImplementedError: for an unexpected list or element event.
    """
    list_event = ctx.receiver.receive_value(list)
    kind = list_event.event_type

    if kind == EventType.NoChange:
        return before
    if kind == EventType.Delete:
        return None
    if kind == EventType.Update:
        return _receive_updated_values(before, list_event.msg, type, ctx)
    if kind != EventType.Add:
        raise NotImplementedError(f"Unexpected operation: {list_event.event_type}")

    received = []
    for _ in range(list_event.msg):
        diff_event = ctx.receiver.receive_value(type)
        if diff_event.event_type == EventType.Add:
            received.append(diff_event.msg)
        elif diff_event.event_type == EventType.NoChange:
            received.append(None)  # no payload for this slot
        else:
            raise NotImplementedError(f"Unexpected operation: {diff_event.event_type}")
    return received
113
+
114
+
115
def _receive_updated_values(before: List[T], after_size: int, type: Type, ctx: 'ReceiverContext') -> List[T]:
    """Apply a StartList ... EndList sequence of value events to *before*.

    *before* is never mutated: it is copied at the first Delete/Update/Add,
    and returned as is when every element is NoChange.

    Args:
        before: previous list.
        after_size: announced size of the result; a longer result is truncated.
        type: expected element type, passed to ``receive_value`` for each element.
        ctx: context whose ``receiver`` supplies the events.

    Raises:
        ValueError: if the first event is not StartList.
    """
    # The list start is read as a node event; the elements are value events.
    evt = ctx.receiver.receive_node()
    if evt.event_type != EventType.StartList:
        raise ValueError(f"Expected start list event: {evt.event_type}")

    modified = False
    after_list = before
    before_idx = 0
    while True:
        evt = ctx.receiver.receive_value(type)
        kind = evt.event_type

        # Only EndList ends the list. The loop used to stop at the first
        # NoChange too, which dropped every later element event and made the
        # NoChange branch below unreachable.
        if kind == EventType.EndList:
            break

        if kind in (EventType.Delete, EventType.Update, EventType.Add) and not modified:
            # First real change: continue on a copy of the unchanged prefix.
            after_list = _copy_range(before, before_idx)
            modified = True

        if kind == EventType.NoChange:
            if modified:
                after_list.append(before[before_idx])
            before_idx += 1
        elif kind == EventType.Delete:
            before_idx += 1
        elif kind in (EventType.Update, EventType.Add):
            after_list.append(evt.msg)
            if kind == EventType.Update:
                # An update replaces an existing element; an add does not.
                before_idx += 1

    return after_list[:after_size] if len(after_list) > after_size else after_list
148
+
149
+
150
def _copy_range(before: Iterable[T], j: int) -> List[T]:
    """Return the first *j* elements of *before* as a separate sequence.

    Lists are sliced, objects exposing ``getrange`` are asked for
    ``getrange(0, j)``, and any other iterable is materialised first.
    """
    if not isinstance(before, list):
        if hasattr(before, 'getrange'):
            return before.getrange(0, j)
        before = list(before)
    return before[:j]
157
+
158
+
159
def calculate_list_diff(
        before: List[T],
        after: List[T],
        id_function: Callable[[T], I],
        consumer: Callable[[Operation, int, int, Optional[T], Optional[T]], None]
) -> None:
    """Walk *before* and *after* in parallel and report each step to *consumer*.

    *consumer* is called as ``(operation, before_index, after_index,
    before_item, after_item)``; the index is ``-1`` and the item ``None`` on
    the side that does not take part (Add has no before, Delete no after).
    Equal elements are NoChange; unequal elements with the same id are Update.
    """
    b, a = 0, 0
    b_end, a_end = len(before), len(after)
    after_ids = None  # id -> index, built lazily on the first id mismatch

    while b < b_end or a < a_end:
        if b >= b_end:
            # Only new elements are left.
            consumer(Operation.Add, -1, a, None, after[a])
            a += 1
            continue
        if a >= a_end:
            # Only old elements are left.
            consumer(Operation.Delete, b, -1, before[b], None)
            b += 1
            continue

        old, new = before[b], after[a]
        if old == new:
            step = Operation.NoChange
        else:
            old_id = id_function(old)
            new_id = id_function(new)
            if old_id == new_id:
                step = Operation.Update
            else:
                if after_ids is None:
                    after_ids = create_index_map(after, a, id_function)
                # An old element whose id is not in the map is deleted;
                # otherwise the new element is reported as added first.
                step = Operation.Delete if old_id not in after_ids else Operation.Add

        if step is Operation.Delete:
            consumer(Operation.Delete, b, -1, old, None)
            b += 1
        elif step is Operation.Add:
            consumer(Operation.Add, -1, a, None, new)
            a += 1
        else:
            consumer(step, b, a, old, new)
            b += 1
            a += 1
215
+
216
+
217
def create_index_map(
        lst: List[T],
        from_index: int,
        id_function: Callable[[T], I]
) -> Dict[I, int]:
    """Map the id of each element of ``lst[from_index:]`` to its index in *lst*.

    When several elements share an id, the highest index wins.
    """
    return {id_function(lst[i]): i for i in range(from_index, len(lst))}
226
+
227
+
228
def _decode_length(decoder: CBORDecoder, subtype: int, allow_indefinite: bool = False):
    """Return the length encoded by *subtype*, reading extra bytes from *decoder* when needed.

    Subtypes below 24 are the length itself; 24-27 are followed by a 1, 2, 4
    or 8 byte big-endian unsigned integer; 31 gives ``None`` when
    *allow_indefinite* is set.

    Raises:
        CBORDecodeValueError: for any other subtype.
    """
    if subtype < 24:
        return subtype
    if subtype == 24:
        return decoder.read(1)[0]

    # struct format and byte width of the multi-byte encodings.
    wide = {25: (">H", 2), 26: (">L", 4), 27: (">Q", 8)}.get(subtype)
    if wide is not None:
        fmt, width = wide
        (length,) = struct.unpack(fmt, decoder.read(width))
        return cast(int, length)

    if allow_indefinite and subtype == 31:
        return None
    raise CBORDecodeValueError(f"unknown unsigned integer subtype 0x{subtype:x}")
243
+
244
def decode_array_start(decoder: CBORDecoder):
    """Read a CBOR array header from *decoder* and return the array length.

    Returns:
        The element count, or ``None`` for an indefinite-length array.

    Raises:
        CBORDecodeValueError: if the next item is not an array (major type 4).
    """
    initial_byte = decoder.read(1)[0]
    major_type = initial_byte >> 5
    # Validate with an explicit raise: an ``assert`` is removed under
    # ``python -O`` and would let malformed input through unnoticed.
    if major_type != 4:
        raise CBORDecodeValueError(f"expected array (major type 4), got major type {major_type}")
    subtype = initial_byte & 31
    return _decode_length(decoder, subtype, allow_indefinite=True)
250
+
251
+
252
COPY_BUFFER_SIZE = 4096
# Kept for backward compatibility; read_to_command_end now uses a per-call buffer.
COPY_BUFFER = bytearray(COPY_BUFFER_SIZE)
COMMAND_END = bytes([0x81, 0x17])


def read_to_command_end(sock: socket.socket) -> BytesIO:
    """Read from *sock* until a chunk ends with ``COMMAND_END`` or the peer closes.

    The terminator may be split across two reads. A socket error is printed
    and whatever was read so far is returned (best effort, as before).

    Returns:
        A ``BytesIO`` holding every byte read, positioned at the start.
    """
    memory_stream = BytesIO()
    # Per-call buffer: a shared module-level one would be overwritten by
    # concurrent callers.
    buffer = bytearray(COPY_BUFFER_SIZE)
    view = memoryview(buffer)
    end_len = len(COMMAND_END)
    # Last bytes received so far. This replaces seeking back in the stream,
    # which left the position rewound when the terminator did not match (so
    # the next write overwrote data) and failed on a single buffered byte.
    tail = b""

    try:
        while True:
            bytes_read = sock.recv_into(buffer, len(buffer))
            if bytes_read == 0:
                break
            chunk = view[:bytes_read]
            memory_stream.write(chunk)
            tail = (tail + bytes(chunk[-end_len:]))[-end_len:]
            if tail == COMMAND_END:
                break
    except socket.error as e:
        print(f"Socket error: {e}")

    memory_stream.seek(0)
    return memory_stream