openrewrite-remote 0.11.0__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,12 @@
1
+ Metadata-Version: 2.1
2
+ Name: openrewrite-remote
3
+ Version: 0.13.0
4
+ Summary: Remoting functionality for the OpenRewrite library.
5
+ Author-email: "Moderne Inc." <support@moderne.io>
6
+ License: Moderne, Inc. Commercial License
7
+ Requires-Python: <4,>=3.9
8
+ Requires-Dist: cbor2==5.6.5
9
+ Requires-Dist: openrewrite
10
+ Requires-Dist: pip>=24.3.1
11
+ Requires-Dist: pypi-simple>=1.6.1
12
+ Requires-Dist: toml>=0.10.2
@@ -0,0 +1,14 @@
1
+ rewrite_remote/__init__.py,sha256=uuLrPH--ewvE-5owXbNItXDfjCypMXQgsm-72hO_dtc,286
2
+ rewrite_remote/client.py,sha256=95ZCAtVOngF0ZqqKnOsrweUeGKruf3UKGPXNGTrNyy0,1853
3
+ rewrite_remote/event.py,sha256=texLJD1mcFkpBpiXAa-Rmip0Tgqm2OlBpRPHFZyWcBs,359
4
+ rewrite_remote/receiver.py,sha256=U50jtbhmAVy361Vm5Bwd-amv19KcS8ltDRBOy2nC51E,18983
5
+ rewrite_remote/remote_utils.py,sha256=wUo9WZoldgCLihFJGf6RaE1SufhDiEPCFlX74tcODVM,10552
6
+ rewrite_remote/remoting.py,sha256=83Wvvj8tMCkUjOam0wWevWJeN-uHW1k9lGdoCsI0u0g,13690
7
+ rewrite_remote/sender.py,sha256=pa68X6bjvCQW_25wuVtHq0ByjA6WgI1Rl3EbL8CcvX4,20003
8
+ rewrite_remote/server.py,sha256=mKI9_PVvBuShoQmfqk0EhNA70CUJOtyzHB3xv-cNBYs,9413
9
+ rewrite_remote/type_utils.py,sha256=oVrB0olWFSCqhmg2nTU2wrwiAU7kBCUscjwdHK7gf3Y,4219
10
+ openrewrite_remote-0.13.0.dist-info/METADATA,sha256=ufTVYfANQruG21mIlcj0tg0TcrcOsR4R2cV7Cva-Gzk,386
11
+ openrewrite_remote-0.13.0.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
12
+ openrewrite_remote-0.13.0.dist-info/entry_points.txt,sha256=SMukuF7TPjQr3IZIcH8f98-_QBCqYSbYXYrVv-5UzRI,69
13
+ openrewrite_remote-0.13.0.dist-info/top_level.txt,sha256=ansTioSZ-62aH3F2L3d1Bua0pJF4GOtgQ1PpG-CzcP0,15
14
+ openrewrite_remote-0.13.0.dist-info/RECORD,,
@@ -1,4 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: poetry-core 1.9.1
2
+ Generator: setuptools (75.6.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
+
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ start_python_remoting = rewrite_remote.server:main
@@ -0,0 +1 @@
1
+ rewrite_remote
@@ -0,0 +1,12 @@
1
+ __path__ = __import__("pkgutil").extend_path(__path__, __name__)
2
+ from typing import TypeVar
3
+
4
+ from .receiver import *
5
+ from .sender import *
6
+ from .remoting import *
7
+
8
+ __all__ = [
9
+ name
10
+ for name in dir()
11
+ if not name.startswith("_") and not isinstance(globals()[name], TypeVar)
12
+ ]
@@ -0,0 +1,70 @@
1
+ import socket
2
+ import tempfile
3
+ from pathlib import Path
4
+
5
+ import rewrite.java.tree as j
6
+ import rewrite.python.tree as py
7
+ from rewrite import Markers
8
+ from rewrite import random_id, Cursor, PrintOutputCapture
9
+ from rewrite.java import Space, JavaType
10
+ from rewrite.python import Py
11
+ from rewrite.python.remote.receiver import PythonReceiver
12
+ from rewrite.python.remote.sender import PythonSender
13
+
14
+ from .receiver import ReceiverContext
15
+ from .remoting import (
16
+ RemotingContext,
17
+ RemotePrinterFactory,
18
+ )
19
+ from .sender import SenderContext
20
+
21
+ SenderContext.register(Py, lambda: PythonSender())
22
+ ReceiverContext.register(Py, lambda: PythonReceiver())
23
+
24
+ # Path to the Unix domain socket (only used in the log message below; the connection itself is made over TCP)
25
+ SOCKET_PATH = tempfile.gettempdir() + "/rewrite-java.sock"
26
+
27
+ # Create a TCP (IPv4) stream socket
28
+ client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
29
+
30
+ # Connect the socket to the host and port where the server is listening
31
+ client.connect(("localhost", 65432))
32
+ print(f"Connected to {SOCKET_PATH}")
33
+
34
+ try:
35
+ remoting = RemotingContext()
36
+ remoting.connect(client)
37
+ RemotePrinterFactory(remoting.client).set_current()
38
+
39
+ literal = j.Literal(
40
+ random_id(),
41
+ Space.SINGLE_SPACE,
42
+ Markers.EMPTY,
43
+ True,
44
+ "True",
45
+ None,
46
+ JavaType.Primitive(),
47
+ )
48
+ assert_ = py.AssertStatement(
49
+ random_id(),
50
+ Space.EMPTY,
51
+ Markers.EMPTY,
52
+ [j.JRightPadded(literal, Space.EMPTY, Markers.EMPTY)],
53
+ )
54
+ cu = py.CompilationUnit(
55
+ random_id(),
56
+ Space.EMPTY,
57
+ Markers.EMPTY,
58
+ Path("/foo.py"),
59
+ None,
60
+ None,
61
+ False,
62
+ None,
63
+ [],
64
+ [j.JRightPadded(assert_, Space.EMPTY, Markers.EMPTY)],
65
+ Space.EMPTY,
66
+ )
67
+ printed = cu.print(Cursor(None, Cursor.ROOT_VALUE), PrintOutputCapture(0))
68
+ assert printed == "assert True"
69
+ finally:
70
+ client.close()
@@ -1,6 +1,6 @@
1
1
  from dataclasses import dataclass
2
2
  from enum import Enum
3
- from typing import Optional
3
+ from typing import Optional, Any
4
4
 
5
5
 
6
6
  class EventType(Enum):
@@ -17,4 +17,4 @@ class EventType(Enum):
17
17
  class DiffEvent:
18
18
  event_type: EventType
19
19
  concrete_type: Optional[str] = None
20
- msg: Optional[any] = None
20
+ msg: Optional[Any] = None
@@ -1,100 +1,136 @@
1
+ # type: ignore
2
+ # Ignoring type checking for this file because there are too many errors for now
3
+
1
4
  from collections import OrderedDict
2
5
  from enum import Enum
3
6
  from pathlib import Path
4
- from typing import Protocol, TypeVar, Optional, Type, Dict, Callable, List, cast, Iterable, Any, Tuple, \
5
- get_args, TYPE_CHECKING
7
+ from typing import (
8
+ Protocol,
9
+ TypeVar,
10
+ Optional,
11
+ Type,
12
+ Dict,
13
+ Callable,
14
+ List,
15
+ cast,
16
+ Iterable,
17
+ Any,
18
+ Tuple,
19
+ get_args,
20
+ TYPE_CHECKING,
21
+ )
6
22
  from uuid import UUID
7
23
 
8
24
  import cbor2
9
25
  from cbor2 import CBORDecoder
10
- from rewrite import Markers, Marker, ParseErrorVisitor
11
-
26
+ from cbor2._decoder import major_decoders
27
+ from rewrite import (
28
+ Markers,
29
+ Marker,
30
+ ParseError,
31
+ ParseErrorVisitor,
32
+ SearchResult,
33
+ UnknownJavaMarker,
34
+ )
12
35
  from rewrite import Tree, TreeVisitor, Cursor, FileAttributes
13
- from rewrite.remote.event import DiffEvent, EventType
36
+
37
+ from .event import DiffEvent, EventType
14
38
  from . import remote_utils, type_utils
15
39
 
16
40
  if TYPE_CHECKING:
17
41
  from .remoting import RemotingContext
18
42
 
19
- A = TypeVar('T')
20
- T = TypeVar('T', bound=Tree)
21
- V = TypeVar('V')
22
- I = TypeVar('I')
23
- P = TypeVar('P')
43
+ A = TypeVar("A")
44
+ T = TypeVar("T", bound=Tree)
45
+ V = TypeVar("V")
46
+ I = TypeVar("I")
47
+ P = TypeVar("P")
24
48
 
25
49
 
26
50
  class Receiver(Protocol):
27
- def fork(self, context: 'ReceiverContext') -> 'ReceiverContext':
28
- ...
51
+ def fork(self, context: "ReceiverContext") -> "ReceiverContext": ...
29
52
 
30
- def receive(self, before: Optional[T], ctx: 'ReceiverContext') -> object:
31
- ...
53
+ def receive(self, before: Optional[T], ctx: "ReceiverContext") -> object: ...
32
54
 
33
55
 
34
56
  class OmniReceiver(Receiver):
35
- def fork(self, ctx: 'ReceiverContext') -> 'ReceiverContext':
57
+ def fork(self, ctx: "ReceiverContext") -> "ReceiverContext":
36
58
  raise NotImplementedError("Cannot fork OmniReceiver")
37
59
 
38
- def receive(self, before: Optional['Tree'], ctx: 'ReceiverContext') -> 'Tree':
60
+ def receive(self, before: Optional["Tree"], ctx: "ReceiverContext") -> "Tree":
39
61
  visitor = self.Visitor()
40
62
  return visitor.visit(before, ctx)
41
63
 
42
- class Visitor(TreeVisitor[Tree, 'ReceiverContext']):
43
- def visit(self, tree: Optional[Tree], ctx: 'ReceiverContext', parent: Optional[Cursor] = None) -> Optional[
44
- Tree]:
45
- self.cursor = Cursor(self.cursor, tree)
64
+ class Visitor(TreeVisitor[Tree, "ReceiverContext"]): # type: ignore
65
+ def visit(
66
+ self,
67
+ tree: Optional[Tree],
68
+ ctx: "ReceiverContext",
69
+ parent: Optional[Cursor] = None,
70
+ ) -> Optional[Tree]:
71
+ self.cursor = Cursor(self.cursor, tree) # type: ignore
46
72
  tree = ctx.polymorphic_receive_tree(tree)
47
73
  self.cursor = self.cursor.parent
48
74
  return tree
49
75
 
50
76
 
51
77
  class TreeReceiver(Protocol):
52
- def receive_node(self) -> DiffEvent:
53
- ...
78
+ def receive_node(self) -> DiffEvent: ...
54
79
 
55
- def receive_value(self, expected_type: Type) -> DiffEvent:
56
- ...
80
+ def receive_value(self, expected_type: Type[Any]) -> DiffEvent: ...
57
81
 
58
82
 
59
83
  class ReceiverFactory(Protocol):
60
- def create(self, type_name: str, ctx: 'ReceiverContext') -> Tree:
61
- ...
84
+ def create(self, type_name: Optional[str], ctx: "ReceiverContext") -> Tree: ...
62
85
 
63
86
 
64
87
  class DetailsReceiver(Protocol[T]):
65
- def receive_details(self, before: Optional[T], type: Optional[Type[T]], ctx: 'ReceiverContext') -> T:
88
+ def receive_details(
89
+ self, before: Optional[T], type: Optional[Type[T]], ctx: "ReceiverContext"
90
+ ) -> T:
66
91
  pass
67
92
 
68
93
 
69
94
  class ReceiverContext:
70
- Registry: Dict[Type, Callable[[], Receiver]] = OrderedDict()
71
-
72
- def __init__(self, receiver: TreeReceiver, visitor: Optional[TreeVisitor] = None,
73
- factory: ReceiverFactory = None):
95
+ Registry: Dict[Type[Any], Callable[[], Receiver]] = OrderedDict()
96
+
97
+ def __init__(
98
+ self,
99
+ receiver: TreeReceiver,
100
+ visitor: Optional[TreeVisitor] = None,
101
+ factory: Optional[ReceiverFactory] = None,
102
+ ):
74
103
  self.receiver = receiver
75
104
  self.visitor = visitor
76
105
  self.factory = factory
77
106
 
78
- def fork(self, visitor: TreeVisitor, factory: ReceiverFactory) -> 'ReceiverContext':
107
+ def fork(self, visitor: TreeVisitor, factory: ReceiverFactory) -> "ReceiverContext":
79
108
  return ReceiverContext(self.receiver, visitor, factory)
80
109
 
81
- def receive_any_tree(self, before) -> T:
82
- return OmniReceiver().receive(before, self)
110
+ def receive_any_tree(self, before: Optional[T]) -> Optional[T]:
111
+ return cast(Optional[T], OmniReceiver().receive(before, self))
83
112
 
84
- def receive_tree(self, before: Optional[Tree], tree_type: Optional[str], ctx: 'ReceiverContext') -> Tree:
113
+ def receive_tree(
114
+ self, before: Optional[Tree], tree_type: Optional[str], ctx: "ReceiverContext"
115
+ ) -> Tree:
85
116
  if before:
86
117
  return before.accept(self.visitor, ctx)
87
118
  else:
88
- return self.factory.create(tree_type, ctx)
119
+ if self.factory is not None:
120
+ return self.factory.create(tree_type, ctx)
121
+ raise ValueError("Factory is not defined")
89
122
 
90
123
  def polymorphic_receive_tree(self, before: Optional[Tree]) -> Optional[Tree]:
91
124
  diff_event = self.receiver.receive_node()
92
125
  if diff_event.event_type in (EventType.Add, EventType.Update):
93
- tree_receiver = self.new_receiver(diff_event.concrete_type or type(before).__name__)
126
+ tree_receiver = self.new_receiver(
127
+ diff_event.concrete_type or type(before).__name__
128
+ )
94
129
  forked = tree_receiver.fork(self)
95
130
  return forked.receive_tree(
96
131
  None if diff_event.event_type == EventType.Add else before,
97
- diff_event.concrete_type, forked
132
+ diff_event.concrete_type,
133
+ forked,
98
134
  )
99
135
  elif diff_event.event_type == EventType.Delete:
100
136
  return None
@@ -108,8 +144,11 @@ class ReceiverContext:
108
144
  return factory()
109
145
  raise ValueError(f"Unsupported receiver type: {type_name}")
110
146
 
111
- def receive_node(self, before: Optional[A],
112
- details: Callable[[Optional[A], Optional[str], 'ReceiverContext'], A]) -> Optional[A]:
147
+ def receive_node(
148
+ self,
149
+ before: Optional[A],
150
+ details: Callable[[Optional[A], Optional[str], "ReceiverContext"], A],
151
+ ) -> Optional[A]:
113
152
  evt = self.receiver.receive_node()
114
153
  if evt.event_type == EventType.Delete:
115
154
  return None
@@ -119,24 +158,34 @@ class ReceiverContext:
119
158
  return details(before, evt.concrete_type, self)
120
159
  return before
121
160
 
122
- def receive_markers(self, before: Optional[Markers], type: Optional[str], ctx: 'ReceiverContext') -> Markers:
123
- id_ = self.receive_value(getattr(before, 'id', None), UUID)
124
- after_markers = self.receive_values(getattr(before, 'markers', None), Marker)
161
+ def receive_markers(
162
+ self, before: Optional[Markers], type: Optional[str], ctx: "ReceiverContext"
163
+ ) -> Markers:
164
+ id_ = self.receive_value(getattr(before, "id", None), UUID)
165
+ after_markers: Optional[List[Marker]] = self.receive_values(
166
+ getattr(before, "markers", None), Marker
167
+ )
125
168
  if before:
126
169
  return before.with_id(id_).with_markers(after_markers)
127
170
  else:
128
171
  return Markers(id_, after_markers)
129
172
 
130
- def receive_nodes(self, before: Optional[List[A]], details: Callable[[Optional[A], Optional[str], 'ReceiverContext'], A]) -> Optional[List[A]]:
173
+ def receive_nodes(
174
+ self,
175
+ before: Optional[List[A]],
176
+ details: Callable[[Optional[A], Optional[str], "ReceiverContext"], A],
177
+ ) -> Optional[List[A]]:
131
178
  return remote_utils.receive_nodes(before, details, self)
132
179
 
133
- def receive_values(self, before: Optional[List[V]], type: Type) -> Optional[List[V]]:
180
+ def receive_values(
181
+ self, before: Optional[List[V]], type: Type[Any]
182
+ ) -> Optional[List[V]]:
134
183
  return remote_utils.receive_values(before, type, self)
135
184
 
136
- def receive_value(self, before: Optional[V], type: Type) -> Optional[V]:
185
+ def receive_value(self, before: Optional[V], type: Type[Any]) -> Optional[V]:
137
186
  return self.receive_value0(before, type)
138
187
 
139
- def receive_value0(self, before: Optional[V], type: Type) -> Optional[V]:
188
+ def receive_value0(self, before: Optional[V], type: Type[Any]) -> Optional[V]:
140
189
  evt = self.receiver.receive_value(type)
141
190
  if evt.event_type in (EventType.Update, EventType.Add):
142
191
  return evt.msg
@@ -145,67 +194,136 @@ class ReceiverContext:
145
194
  return before
146
195
 
147
196
  @staticmethod
148
- def register(type_: Type, receiver_factory: Callable[[], Receiver]):
197
+ def register(type_: Type[Any], receiver_factory: Callable[[], Receiver]) -> None:
149
198
  ReceiverContext.Registry[type_] = receiver_factory
150
199
 
151
200
 
152
201
  class ValueDeserializer(Protocol):
153
- def deserialize(self, type_: Optional[Type], reader: CBORDecoder, context: 'DeserializationContext') -> Optional[Any]:
154
- ...
202
+ def deserialize(
203
+ self,
204
+ type_: Optional[Type[Any]],
205
+ reader: CBORDecoder,
206
+ context: "DeserializationContext",
207
+ ) -> Optional[Any]: ...
155
208
 
156
209
 
157
210
  class DefaultValueDeserializer(ValueDeserializer):
158
- def deserialize(self, expected_type: Optional[Type], reader: CBORDecoder, context: 'DeserializationContext') -> Any:
159
- cbor_map = reader.decode_map()
160
- error_message = "No deserializer found for: " + ", ".join(f"{k}: {v}" for k, v in cbor_map.items())
211
+ def deserialize(
212
+ self,
213
+ expected_type: Optional[Type[Any]],
214
+ reader: CBORDecoder,
215
+ context: "DeserializationContext",
216
+ ) -> Any:
217
+ cbor_map = reader.decode_map(subtype=None)
218
+ error_message = "No deserializer found for: " + ", ".join(
219
+ f"{k}: {v}" for k, v in cbor_map.items()
220
+ )
161
221
  raise NotImplementedError(error_message)
162
222
 
163
223
 
164
224
  class DeserializationContext:
165
225
  DefaultDeserializer = DefaultValueDeserializer()
166
226
 
167
- def __init__(self, remoting_context: 'RemotingContext', value_deserializers: Optional[List[Tuple[Type, ValueDeserializer]]] = None):
227
+ value_deserializers: List[Tuple[Type, ValueDeserializer]]
228
+
229
+ def __init__(
230
+ self,
231
+ remoting_context: "RemotingContext",
232
+ value_deserializers: Optional[List[Tuple[Type, ValueDeserializer]]] = None,
233
+ ):
168
234
  self.remoting_context = remoting_context
169
235
  self.value_deserializers = value_deserializers or []
170
236
 
171
- def deserialize(self, expected_type: Type, decoder: CBORDecoder) -> Any:
172
- value = decoder.decode()
173
- if value is None:
174
- return None
175
-
237
+ def deserialize(self, expected_type: Type[Any], decoder: CBORDecoder) -> Any:
176
238
  if expected_type == UUID:
177
- return UUID(bytes=value)
239
+ return UUID(bytes=decoder.decode()) # type: ignore
178
240
 
179
241
  if expected_type == str:
180
- return value
242
+ return decoder.decode()
181
243
 
182
244
  if expected_type == bool:
183
- return value
245
+ return decoder.decode()
184
246
 
185
247
  if expected_type == list:
186
- return value
248
+ return decoder.decode()
187
249
 
188
250
  if expected_type == int:
189
- return value
251
+ return decoder.decode()
190
252
 
191
253
  if expected_type == Path:
192
- return Path(value)
254
+ return Path(decoder.decode()) # type: ignore
193
255
 
194
256
  if expected_type == float:
195
- return value
257
+ return decoder.decode()
196
258
 
197
259
  if issubclass(expected_type, Enum):
198
- return expected_type(value)
260
+ return expected_type(decoder.decode())
261
+
262
+ initial_byte = decoder.read(1)[0]
263
+ major_type = initial_byte >> 5
264
+ subtype = initial_byte & 31
265
+
266
+ # Object ID for Marker, JavaType, etc.
267
+ if major_type == 0:
268
+ obj_id = major_decoders[major_type](decoder, subtype)
269
+ return self.remoting_context.get_object(obj_id)
270
+
271
+ # arrays
272
+ elif major_type == 4:
273
+ pass
274
+
275
+ # objects
276
+ elif major_type == 5:
277
+ field_name = decoder.decode()
278
+ assert field_name == "@c"
279
+ concrete_type = decoder.decode()
280
+
281
+ if concrete_type == "org.openrewrite.marker.SearchResult":
282
+ field_name = decoder.decode()
283
+ if field_name == "description":
284
+ desc = decoder.decode()
285
+ elif field_name == "id":
286
+ id_ = UUID(bytes=decoder.decode())
287
+ field_name = decoder.decode()
288
+ if field_name == "description":
289
+ desc = decoder.decode()
290
+ elif field_name == "id":
291
+ id_ = UUID(bytes=decoder.decode())
292
+ return SearchResult(id_, desc)
293
+
294
+ for type, value_deserializer in self.value_deserializers:
295
+ pass
296
+
297
+ initial_byte = decoder.read(1)[0]
298
+ major_type = initial_byte >> 5
299
+ subtype = initial_byte & 31
300
+ while initial_byte != 0xFF:
301
+ field_name = decoder.decode_string(subtype)
302
+ field_value = decoder.decode()
303
+ initial_byte = decoder.read(1)[0]
304
+ major_type = initial_byte >> 5
305
+ subtype = initial_byte & 31
306
+ pass
307
+
308
+ else:
309
+ return major_decoders[major_type](decoder, subtype)
199
310
 
200
311
  state = decoder.peek_state()
201
312
 
202
- if state in {cbor2.CborReaderState.BOOLEAN, cbor2.CborReaderState.UNSIGNED_INT,
203
- cbor2.CborReaderState.NEGATIVE_INT}:
313
+ if state in {
314
+ cbor2.CborReaderState.BOOLEAN,
315
+ cbor2.CborReaderState.UNSIGNED_INT,
316
+ cbor2.CborReaderState.NEGATIVE_INT,
317
+ }:
204
318
  result = decoder.read_int()
205
319
  obj = self.remoting_context.get_object(result)
206
320
  return obj if obj is not None else result
207
321
 
208
- if state in {cbor2.CborReaderState.HALF_FLOAT, cbor2.CborReaderState.FLOAT, cbor2.CborReaderState.DOUBLE}:
322
+ if state in {
323
+ cbor2.CborReaderState.HALF_FLOAT,
324
+ cbor2.CborReaderState.FLOAT,
325
+ cbor2.CborReaderState.DOUBLE,
326
+ }:
209
327
  return decoder.read_double()
210
328
 
211
329
  if state == cbor2.CborReaderState.TEXT_STRING:
@@ -213,7 +331,11 @@ class DeserializationContext:
213
331
  if decoder.peek_state() == cbor2.CborReaderState.END_ARRAY:
214
332
  return str_value
215
333
 
216
- concrete_type = str_value if decoder.peek_state() != cbor2.CborReaderState.END_ARRAY else expected_type.__name__
334
+ concrete_type = (
335
+ str_value
336
+ if decoder.peek_state() != cbor2.CborReaderState.END_ARRAY
337
+ else expected_type.__name__
338
+ )
217
339
 
218
340
  if concrete_type == "org.openrewrite.FileAttributes":
219
341
  map = decoder.read_cbor_map()
@@ -226,7 +348,7 @@ class DeserializationContext:
226
348
  if concrete_type == "java.lang.Integer":
227
349
  return decoder.decode()
228
350
  if concrete_type == "java.lang.Character":
229
- return decoder.decode()[0]
351
+ return decoder.decode()[0] # type: ignore
230
352
  if concrete_type == "java.lang.Long":
231
353
  return decoder.decode()
232
354
  if concrete_type == "java.lang.Double":
@@ -238,7 +360,9 @@ class DeserializationContext:
238
360
  if concrete_type == "java.math.BigDecimal":
239
361
  return decoder.decode()
240
362
 
241
- raise NotImplementedError(f"No deserialization implemented for: {concrete_type}")
363
+ raise NotImplementedError(
364
+ f"No deserialization implemented for: {concrete_type}"
365
+ )
242
366
 
243
367
  if state == cbor2.CborReaderState.ARRAY:
244
368
  decoder.read_array_start()
@@ -261,7 +385,10 @@ class DeserializationContext:
261
385
 
262
386
  concrete_type = marker_map["@c"]
263
387
 
264
- if concrete_type in {"org.openrewrite.marker.SearchResult", "Rewrite.Core.Marker.SearchResult"}:
388
+ if concrete_type in {
389
+ "org.openrewrite.marker.SearchResult",
390
+ "Rewrite.Core.Marker.SearchResult",
391
+ }:
265
392
  desc = marker_map.get("description", None)
266
393
  marker = SearchResult(UUID(bytes=marker_map["id"]), desc)
267
394
  else:
@@ -280,7 +407,9 @@ class DeserializationContext:
280
407
  for type_, deserializer in self.value_deserializers:
281
408
  if issubclass(actual_type, type_):
282
409
  return deserializer.deserialize(actual_type, decoder, self)
283
- raise NotImplementedError(f"No deserialization implemented for: {expected_type}")
410
+ raise NotImplementedError(
411
+ f"No deserialization implemented for: {expected_type}"
412
+ )
284
413
 
285
414
 
286
415
  class JsonReceiver(TreeReceiver):
@@ -301,10 +430,19 @@ class JsonReceiver(TreeReceiver):
301
430
  concrete_type = None
302
431
 
303
432
  if event_type in {EventType.Add, EventType.Update}:
304
- if event_type == EventType.Add and isinstance(array[1], str):
433
+ if (
434
+ event_type == EventType.Add
435
+ and len(array) > 1
436
+ and isinstance(array[1], str)
437
+ ):
305
438
  concrete_type = array[1]
306
439
 
307
- elif event_type not in {EventType.Delete, EventType.NoChange, EventType.StartList, EventType.EndList}:
440
+ elif event_type not in {
441
+ EventType.Delete,
442
+ EventType.NoChange,
443
+ EventType.StartList,
444
+ EventType.EndList,
445
+ }:
308
446
  raise NotImplementedError(event_type)
309
447
 
310
448
  if self.DEBUG:
@@ -322,14 +460,18 @@ class JsonReceiver(TreeReceiver):
322
460
  concrete_type = None
323
461
 
324
462
  if event_type in {EventType.Add, EventType.Update}:
325
- if (bool(get_args(expected_type)) and
326
- issubclass(expected_type, (list, Iterable))):
463
+ if bool(expected_type) and issubclass(expected_type, (list, Iterable)):
327
464
  # special case for list events
328
465
  msg = self._decoder.decode()
329
466
  else:
330
467
  msg = self._context.deserialize(expected_type, self._decoder)
331
468
 
332
- elif event_type not in {EventType.Delete, EventType.NoChange, EventType.StartList, EventType.EndList}:
469
+ elif event_type not in {
470
+ EventType.Delete,
471
+ EventType.NoChange,
472
+ EventType.StartList,
473
+ EventType.EndList,
474
+ }:
333
475
  raise NotImplementedError(event_type)
334
476
 
335
477
  if length is None:
@@ -359,19 +501,34 @@ class ParseErrorReceiver(Receiver):
359
501
 
360
502
  def visit_parse_error(self, parse_error, ctx):
361
503
  parse_error = parse_error.with_id(ctx.receive_value(parse_error.id))
362
- parse_error = parse_error.with_markers(ctx.receive_node(parse_error.markers, ctx.receive_markers))
363
- parse_error = parse_error.with_source_path(ctx.receive_value(parse_error.source_path))
364
- parse_error = parse_error.with_file_attributes(ctx.receive_value(parse_error.file_attributes))
365
- parse_error = parse_error.with_charset_name(ctx.receive_value(parse_error.charset_name))
366
- parse_error = parse_error.with_charset_bom_marked(ctx.receive_value(parse_error.charset_bom_marked))
367
- parse_error = parse_error.with_checksum(ctx.receive_value(parse_error.checksum))
504
+ parse_error = parse_error.with_markers(
505
+ ctx.receive_node(parse_error.markers, ctx.receive_markers)
506
+ )
507
+ parse_error = parse_error.with_source_path(
508
+ ctx.receive_value(parse_error.source_path)
509
+ )
510
+ parse_error = parse_error.with_file_attributes(
511
+ ctx.receive_value(parse_error.file_attributes)
512
+ )
513
+ parse_error = parse_error.with_charset_name(
514
+ ctx.receive_value(parse_error.charset_name)
515
+ )
516
+ parse_error = parse_error.with_charset_bom_marked(
517
+ ctx.receive_value(parse_error.charset_bom_marked)
518
+ )
519
+ parse_error = parse_error.with_checksum(
520
+ ctx.receive_value(parse_error.checksum)
521
+ )
368
522
  parse_error = parse_error.with_text(ctx.receive_value(parse_error.text))
369
523
  # parse_error = parse_error.with_erroneous(ctx.receive_tree(parse_error.erroneous))
370
524
  return parse_error
371
525
 
372
526
  class Factory(ReceiverFactory):
373
527
  def create(self, type_, ctx):
374
- if type_ in ["rewrite.parser.ParseError", "org.openrewrite.tree.ParseError"]:
528
+ if type_ in [
529
+ "rewrite.parser.ParseError",
530
+ "org.openrewrite.tree.ParseError",
531
+ ]:
375
532
  return ParseError(
376
533
  ctx.receive_value(None),
377
534
  ctx.receive_node(None, ctx.receive_markers),
@@ -381,6 +538,6 @@ class ParseErrorReceiver(Receiver):
381
538
  ctx.receive_value(None),
382
539
  ctx.receive_value(None),
383
540
  ctx.receive_value(None),
384
- None # ctx.receive_tree(None)
541
+ None, # ctx.receive_tree(None)
385
542
  )
386
543
  raise NotImplementedError("No factory method for type: " + type_)