openrewrite-remote 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openrewrite_remote-0.2.0.dist-info/METADATA +16 -0
- openrewrite_remote-0.2.0.dist-info/RECORD +18 -0
- openrewrite_remote-0.2.0.dist-info/WHEEL +4 -0
- rewrite/remote/__init__.py +7 -0
- rewrite/remote/client.py +40 -0
- rewrite/remote/event.py +20 -0
- rewrite/remote/java/extensions.py +176 -0
- rewrite/remote/java/receiver.py +1376 -0
- rewrite/remote/java/sender.py +659 -0
- rewrite/remote/python/extensions.py +179 -0
- rewrite/remote/python/receiver.py +1873 -0
- rewrite/remote/python/sender.py +894 -0
- rewrite/remote/receiver.py +386 -0
- rewrite/remote/remote_utils.py +280 -0
- rewrite/remote/remoting.py +342 -0
- rewrite/remote/sender.py +424 -0
- rewrite/remote/server.py +188 -0
- rewrite/remote/type_utils.py +60 -0
rewrite/remote/sender.py
ADDED
@@ -0,0 +1,424 @@
|
|
1
|
+
from __future__ import absolute_import
|
2
|
+
|
3
|
+
import decimal
|
4
|
+
from abc import ABC, abstractmethod
|
5
|
+
from dataclasses import fields
|
6
|
+
from pathlib import Path
|
7
|
+
from typing import Dict, Type, Any, Callable, ClassVar, TypeVar, Generic, Protocol, List, TYPE_CHECKING
|
8
|
+
from uuid import UUID
|
9
|
+
|
10
|
+
import cbor2
|
11
|
+
from cbor2 import CBOREncoder
|
12
|
+
from rewrite import Tree, Markers, Marker, ParseErrorVisitor, Cursor
|
13
|
+
from rewrite.remote.type_utils import to_java_type_name
|
14
|
+
from rewrite.visitor import TreeVisitor
|
15
|
+
|
16
|
+
from . import remote_utils
|
17
|
+
from .event import *
|
18
|
+
from .remote_utils import Operation
|
19
|
+
|
20
|
+
if TYPE_CHECKING:
|
21
|
+
from .remoting import RemotingContext
|
22
|
+
|
23
|
+
# Generic type parameters used in the sender API signatures below.
A = TypeVar('A')  # the "owner" object handed to an extractor callable
T = TypeVar('T', bound=Tree)  # a tree type
V = TypeVar('V')  # an extracted value or list element
I = TypeVar('I')  # the identity returned by an id function for list elements
|
27
|
+
|
28
|
+
|
29
|
+
class Sender(Protocol):
    """Structural interface for objects that can send a tree.

    Implementations are looked up through ``SenderContext.new_sender`` and
    invoked with the tree to send, an optional earlier version of it and the
    context to send through.
    """

    def send(self, after: T, before: Optional[T], ctx: 'SenderContext') -> None:
        """Send ``after`` (optionally relative to ``before``) using ``ctx``."""
        ...
|
32
|
+
|
33
|
+
|
34
|
+
class OmniSender(Sender):
    """Sender that forwards to whichever sender the context provides for the tree's type."""

    def send(self, after: Tree, before: Optional[Tree], ctx: 'SenderContext'):
        """Resolve a sender for ``type(after)`` via ``ctx.new_sender`` and delegate to it.

        Any error raised by ``ctx.new_sender`` (for example for an
        unregistered type) propagates to the caller.
        """
        concrete_type = type(after)
        delegate = ctx.new_sender(concrete_type)
        delegate.send(after, before, ctx)
|
38
|
+
|
39
|
+
|
40
|
+
class TreeSender(Protocol):
    """Structural interface of the low-level event writer used by ``SenderContext``."""

    def send_node(self, diff_event: DiffEvent, visitor: Callable[['TreeSender'], None]):
        """Emit ``diff_event`` for a node; ``visitor`` is a callback taking a ``TreeSender``."""
        ...

    def send_value(self, diff_event: DiffEvent):
        """Emit ``diff_event`` for a value."""
        ...

    def flush(self):
        """Flush any buffered output."""
        ...
|
49
|
+
|
50
|
+
|
51
|
+
class SenderContext(Generic[T]):
    """Helper used while sending a tree as a sequence of ``DiffEvent`` objects.

    A context bundles the low-level ``TreeSender`` that writes events, the
    visitor used to walk a tree and the ``before`` object the current object
    is compared against (``None`` when there is nothing to compare with).
    """

    # Shared by all instances: maps a type to a factory creating its sender.
    Registry: Dict[Type, Callable[[], Sender]] = {}

    def __init__(self, sender: 'TreeSender', visitor: Optional[TreeVisitor] = None, before: Optional[Any] = None):
        self.sender = sender
        self.visitor = visitor
        self.before = before

    def new_sender(self, type_: Type) -> Sender:
        """Create a sender for ``type_`` from the first matching ``Registry`` entry.

        An entry matches when its key is a direct base of ``type_`` or
        ``issubclass(type_, key)`` holds.

        Raises:
            ValueError: if no registered entry matches ``type_``.
        """
        for entry_type, factory in self.Registry.items():
            # FIXME find better solution
            try:
                matches = entry_type in type_.__bases__ or issubclass(type_, entry_type)
            except (TypeError, AttributeError):
                # The type check itself is not supported for this pair; treat it
                # as "no match". Only the check is guarded so that errors raised
                # by the factory are no longer silently swallowed.
                continue
            if matches:
                return factory()
        raise ValueError(f"Unsupported sender type: {type_}")

    def fork(self, visitor: TreeVisitor, before: Optional[Any] = None) -> 'SenderContext':
        """Return a new context sharing this context's sender, with its own visitor and ``before``."""
        return SenderContext(self.sender, visitor, before)

    def visit(self, consumer: Callable[[V, 'SenderContext'], None], after: V, before: Optional[V] = None):
        """Call ``consumer(after, self)`` with ``self.before`` temporarily set to ``before``.

        The previous ``self.before`` is restored afterwards, also when the
        consumer raises, so a failure cannot leave the context pointing at the
        wrong comparison object.
        """
        save_before = self.before
        self.before = before
        try:
            consumer(after, self)
        finally:
            self.before = save_before

    def send_tree(self, after: T, ctx: 'SenderContext'):
        """Let ``after`` accept this context's visitor, passing ``ctx`` along."""
        after.accept(self.visitor, ctx)

    def send_any_tree(self, after: T, before: Optional[T]):
        """Send ``after`` through an ``OmniSender`` using this context."""
        OmniSender().send(after, before, self)

    def send_node(self, owner: A, extractor: Callable[[A], Optional[V]], details: Callable[[V, 'SenderContext'], None]):
        """Send the node extracted from ``owner``, compared with the one extracted from ``self.before``."""
        self.send_node_internal(extractor(owner), extractor(self.before) if self.before is not None else None, details)

    def send_value(self, owner, value_extractor: Callable[[T], Optional[V]]):
        """Send the value extracted from ``owner``, compared with the one extracted from ``self.before``."""
        after_value = value_extractor(owner)
        before_value = value_extractor(self.before) if self.before is not None else None
        self.send_value_internal(after_value, before_value)

    def send_typed_value(self, owner: A, value_extractor: Callable[[A], V]):
        """Like ``send_value`` but uses ``send_typed_value_internal`` for the event."""
        after_value = value_extractor(owner)
        before_value = value_extractor(self.before) if self.before is not None else None
        self.send_typed_value_internal(after_value, before_value)

    def send_list_event(self, after: Optional[List[V]], before: Optional[List[V]]) -> bool:
        """Send the event announcing a list and report whether its elements must follow.

        Returns:
            ``True`` for ``Add`` and ``Update`` events, ``False`` for
            ``NoChange`` and ``Delete``.
        """
        if after == before:
            evt = DiffEvent(EventType.NoChange, None, None)
        elif before is None:
            evt = DiffEvent(EventType.Add, None, len(after) if after is not None else 0)
        elif after is None:
            evt = DiffEvent(EventType.Delete, None, None)
        else:
            evt = DiffEvent(EventType.Update, None, len(after))

        self.sender.send_value(evt)
        return evt.event_type != EventType.NoChange and evt.event_type != EventType.Delete

    def _send_list_diff(self, after_list: Optional[List[V]], before_list: Optional[List[V]],
                        id_function: Callable[[V], I],
                        send_element: Callable[[Optional[V], Optional[V]], None]):
        """Send the list event and, when required, one event per diffed element.

        ``send_element`` is called as ``send_element(after_value, before_value)``
        for every ``Delete``, ``NoChange``, ``Add`` and ``Update`` operation
        reported by ``remote_utils.calculate_list_diff``. The elements are
        wrapped in ``StartList``/``EndList`` events only when a before-list
        exists.

        Raises:
            NotImplementedError: for any other operation (e.g. ``Move``).
        """
        if not self.send_list_event(after_list, before_list):
            return

        framed = before_list is not None
        if framed:
            self.sender.send_value(DiffEvent(EventType.StartList, None, None))

        supported = (Operation.Delete, Operation.NoChange, Operation.Add, Operation.Update)

        def on_operation(op, _1, _2, bv, av):
            if op not in supported:
                # The original code only constructed this exception; it has to be raised.
                raise NotImplementedError("Unexpected operation: " + str(op))
            send_element(av, bv)

        remote_utils.calculate_list_diff(before_list or [], after_list, id_function, on_operation)

        if framed:
            self.sender.send_value(DiffEvent(EventType.EndList, None, None))

    def send_nodes(self, owner: A, element_extractor: Callable[[A], List[V]],
                   details: Callable[[V, 'SenderContext'], None], id_function: Callable[[V], I]):
        """Send a list of nodes extracted from ``owner``, diffed against the list from ``self.before``."""
        after_list = element_extractor(owner)
        before_list = element_extractor(self.before) if self.before is not None else None
        self._send_list_diff(after_list, before_list, id_function,
                             lambda av, bv: self.send_node_internal(av, bv, details))

    def send_values(self, owner, value_extractor: Callable[[T], List[V]], id_function: Callable[[V], I]):
        """Send a list of values extracted from ``owner``, diffed against the list from ``self.before``."""
        after_list = value_extractor(owner)
        before_list = value_extractor(self.before) if self.before is not None else None
        self._send_list_diff(after_list, before_list, id_function, self.send_value_internal)

    def send_typed_values(self, owner: T, value_extractor: Callable[[T], List[V]], id_function: Callable[[V], I]):
        """Like ``send_values`` but each element is sent with ``send_typed_value_internal``."""
        after_list = value_extractor(owner)
        before_list = value_extractor(self.before) if self.before is not None else None
        self._send_list_diff(after_list, before_list, id_function, self.send_typed_value_internal)

    def send_markers(self, markers: Markers, ignore):
        """Send the id of ``markers`` followed by its list of markers.

        ``ignore`` is unused; the second parameter lets this method be passed
        as a ``details`` callback, which is called with ``(value, context)``.
        """
        self.send_value(markers, lambda ms: ms.id)
        self.send_values(markers, lambda ms: ms.markers, lambda ms: ms.id)

    def send_tree_visitor(self, after: T, ctx: 'SenderContext'):
        """Let ``after`` accept this context's visitor, passing ``ctx`` along (same as ``send_tree``)."""
        after.accept(self.visitor, ctx)

    @staticmethod
    def register(type_: Type, sender_factory: Callable[[], Sender]):
        """Register ``sender_factory`` for ``type_`` in the shared ``Registry``."""
        SenderContext.Registry[type_] = sender_factory

    @staticmethod
    def are_equal(after: Optional[V], before: Optional[V]) -> bool:
        """Decide whether a ``NoChange`` event can be sent for the pair.

        When either side is ``None`` the two are compared with ``==``.
        Otherwise the result is ``True`` exactly when at least one side is a
        ``Tree`` or ``Markers`` instance.
        """
        if after is None or before is None:
            return after == before

        # NOTE(review): for two non-None objects this never compares them; it only
        # checks their types. Confirm against the receiving side that this is intended.
        return isinstance(after, (Tree, Markers)) or isinstance(before, (Tree, Markers))

    def send_node_internal(self, after: Optional[V], before: Optional[V],
                           details: Callable[[V, 'SenderContext'], None]):
        """Send a node event for ``after`` vs ``before``.

        The event type is chosen from the pair (``NoChange``, ``Add`` with the
        Java type name of ``after``, ``Delete`` or ``Update``) and handed to
        the underlying sender together with a callback that runs ``details``
        through ``visit``.
        """
        if self.are_equal(after, before):
            evt = DiffEvent(EventType.NoChange, None, None)
        elif before is None:
            concrete_type = to_java_type_name(type(after)) if after is not None else None
            evt = DiffEvent(EventType.Add, concrete_type, None)
        elif after is None:
            evt = DiffEvent(EventType.Delete, None, None)
        else:
            evt = DiffEvent(EventType.Update, None, None)

        self.sender.send_node(evt, lambda _: self.visit(details, after, before))

    def send_value_internal(self, after: V, before: Optional[V]):
        """Send a value event for ``after`` vs ``before``.

        Without a context-level ``before`` every value is sent as ``Add``. A
        concrete type name is attached to ``Add`` events only for ``Marker``
        instances.
        """
        if self.before is not None and self.are_equal(after, before):
            evt = DiffEvent(EventType.NoChange, None, None)
        elif self.before is None or before is None:
            concrete_type = to_java_type_name(type(after)) if isinstance(after, Marker) else None
            evt = DiffEvent(EventType.Add, concrete_type, after)
        elif after is None:
            evt = DiffEvent(EventType.Delete, None, None)
        else:
            evt = DiffEvent(EventType.Update, None, after)

        self.sender.send_value(evt)

    def send_typed_value_internal(self, after: V, before: Optional[V]):
        """Like ``send_value_internal`` but ``Add`` events carry the type name of any non-None value."""
        if self.before is not None and self.are_equal(after, before):
            evt = DiffEvent(EventType.NoChange, None, None)
        elif self.before is None or before is None:
            concrete_type = to_java_type_name(type(after)) if after is not None else None
            evt = DiffEvent(EventType.Add, concrete_type, after)
        elif after is None:
            evt = DiffEvent(EventType.Delete, None, None)
        else:
            evt = DiffEvent(EventType.Update, None, after)

        self.sender.send_value(evt)
|
240
|
+
|
241
|
+
|
242
|
+
class SerializationContext:
    """Chooses the ``ValueSerializer`` used to write a value to a CBOR encoder."""

    def __init__(self, remoting_context: 'RemotingContext',
                 value_serializers: Optional[Dict[Type, 'ValueSerializer']] = None):
        self.remoting_context = remoting_context
        # A missing (or empty) mapping means "no custom serializers".
        self.value_serializers = value_serializers or {}

    def serialize(self, value: Any, type_name: Optional[str], encoder: CBOREncoder):
        """Write ``value`` to ``encoder``.

        ``None`` is encoded directly. Otherwise the first registered
        serializer whose key type matches ``isinstance(value, key)`` is used,
        falling back to a fresh ``DefaultValueSerializer``.
        """
        if value is None:
            encoder.encode_none(None)
            return

        for candidate_type, candidate in self.value_serializers.items():
            if isinstance(value, candidate_type):
                chosen = candidate
                break
        else:
            chosen = DefaultValueSerializer()

        chosen.serialize(value, type_name, encoder, self)
|
259
|
+
|
260
|
+
|
261
|
+
class JsonSender(TreeSender):
    """``TreeSender`` that writes diff events to a stream with a ``cbor2.CBOREncoder``."""

    # When true, every sent event is also printed.
    Debug = False

    def __init__(self, stream, context: SerializationContext):
        self._stream = stream
        self._context = context
        self._encoder = cbor2.CBOREncoder(self._stream)

    def _trace(self, diff_event):
        """Print ``diff_event`` when ``Debug`` is enabled."""
        if self.Debug:
            print(f"SEND: {diff_event}")

    def send_node(self, diff_event, visitor):
        """Write a node event; for ``Add``/``Update`` also invoke ``visitor(self)`` afterwards.

        Raises:
            NotImplementedError: for event types other than ``Add``,
                ``Update``, ``Delete`` and ``NoChange``.
        """
        kind = diff_event.event_type
        if kind in (EventType.Add, EventType.Update):
            # The header is [event] or, when a concrete type is known, [event, type].
            header = [kind.value]
            if diff_event.concrete_type is not None:
                header.append(diff_event.concrete_type)
            self._encoder.encode(header)
            self._trace(diff_event)
            visitor(self)
        elif kind in (EventType.Delete, EventType.NoChange):
            self._encoder.encode([kind.value])
            self._trace(diff_event)
        else:
            raise NotImplementedError()

    def send_value(self, diff_event):
        """Write a value event; ``Add``/``Update`` events carry the serialized ``msg``.

        Raises:
            NotImplementedError: for unsupported event types.
        """
        kind = diff_event.event_type
        if kind in (EventType.Add, EventType.Update):
            typed = diff_event.concrete_type is not None
            # Array header (major type 4) announcing 2 or 3 items, written piecewise.
            self._encoder.encode_length(4, 3 if typed else 2)
            self._encoder.encode_int(kind.value)
            if typed:
                self._encoder.encode(diff_event.concrete_type)
            self._context.serialize(diff_event.msg, diff_event.concrete_type, self._encoder)
        elif kind in (EventType.Delete, EventType.NoChange, EventType.StartList, EventType.EndList):
            self._encoder.encode([kind.value])
        else:
            raise NotImplementedError()

        self._trace(diff_event)

    def flush(self):
        """Flush the underlying stream."""
        self._stream.flush()
|
305
|
+
|
306
|
+
|
307
|
+
# Raw marker bytes written directly to the encoder's stream by the serializers
# below to open an indefinite-length array / map and to terminate one.
INDEFINITE_ARRAY_START = b'\x9f'
INDEFINITE_MAP_START = b'\xbf'
BREAK_MARKER = b'\xff'
|
310
|
+
|
311
|
+
|
312
|
+
class ValueSerializer(ABC):
    """Base class for objects that write a value to a ``CBOREncoder``."""

    @abstractmethod
    def serialize(self, value: Any, type_name: Optional[str], writer: CBOREncoder, context: SerializationContext):
        """Write ``value`` to ``writer``; ``context`` serializes nested values."""
        pass

    @staticmethod
    def write_object_using_reflection(value: Any, type_name: Optional[str], with_id: bool,
                                      encoder: CBOREncoder, context: SerializationContext):
        """Write a dataclass instance as a map built from its fields.

        The map starts with an ``"@c"`` entry holding ``type_name`` (or the
        Java type name derived from ``type(value)``). When ``with_id`` is set
        an ``"@ref"`` entry with the id assigned by the remoting context
        follows; an object that already has an id is written as just that id.
        Every dataclass field whose name starts with ``_`` and that is not a
        ``ClassVar`` is then written under its name without the leading
        underscore. Other fields are skipped.

        Args:
            value: the object to write; must be a dataclass instance unless it
                is a ``JavaType.Primitive``.
            type_name: explicit type name for the ``"@c"`` entry, or ``None``.
            with_id: whether to track the object in the remoting context.
            encoder: the encoder to write to.
            context: used for nested values and for id bookkeeping.
        """
        if type(value).__qualname__ == 'JavaType.Primitive':
            # FIXME implement type attribution support
            encoder.encode(['org.openrewrite.java.tree.JavaType$Primitive', 0])
            return

        # NOTE(review): a falsy id (e.g. 0) is treated like "no id yet" here;
        # confirm that the remoting context never hands out such an id.
        if with_id and (ref_id := context.remoting_context.try_get_id(value)):
            encoder.encode_int(ref_id)
            return

        encoder.write(INDEFINITE_MAP_START)
        encoder.encode_string("@c")
        encoder.encode_string(type_name or to_java_type_name(type(value)))
        if with_id:
            encoder.encode_string("@ref")
            ref_id = context.remoting_context.add(value)
            encoder.encode_int(ref_id)

        for field in fields(value):
            # The original condition `a and not b or c` parsed as `(a and not b) or c`:
            # it raised AttributeError for field types without `__origin__` and cut
            # the first character off names that do not start with an underscore.
            is_class_var = getattr(field.type, '__origin__', None) is ClassVar
            if field.name.startswith('_') and not is_class_var:
                encoder.encode_string(field.name[1:])
                context.serialize(getattr(value, field.name), None, encoder)
        encoder.write(BREAK_MARKER)
|
342
|
+
|
343
|
+
|
344
|
+
class DefaultValueSerializer(ValueSerializer):
    """Fallback serializer covering primitives, common stdlib types, markers and dataclasses."""

    def serialize(self, value: Any, type_name: Optional[str], encoder: CBOREncoder, context: SerializationContext):
        """Write ``value`` to ``encoder`` according to its runtime type.

        * ``int``/``float``/``str``/``bool``/``Decimal``: encoded as is.
        * ``None``: encoded as null.
        * ``UUID``: its 16 raw bytes; ``Enum``: its ``value``; ``Path``: its string form.
        * ``list``: an array whose items are serialized through ``context``.
        * ``Markers`` / ``Marker``: the known id if the remoting context
          already tracks the object, otherwise a map carrying a new
          ``"@ref"`` id and the object's content.
        * anything else: ``ValueSerializer.write_object_using_reflection``
          without id tracking.
        """
        if isinstance(value, (int, float, str, bool, decimal.Decimal)):
            encoder.encode(value)
        elif value is None:
            encoder.encode_none(None)
        elif isinstance(value, UUID):
            encoder.encode(value.bytes)
        elif isinstance(value, Enum):
            encoder.encode(value.value)
        elif isinstance(value, Path):
            encoder.encode(str(value))
        elif isinstance(value, list):
            encoder.encode_length(4, len(value))
            for item in value:
                context.serialize(item, None, encoder)
        elif isinstance(value, Markers):
            if (ref_id := context.remoting_context.try_get_id(value)):
                encoder.encode_int(ref_id)
            else:
                # Register the object to obtain its id, as the Marker branch below
                # does. The original encoded the falsy result of try_get_id() here.
                ref_id = context.remoting_context.add(value)
                encoder.encode_length(5, 3)  # map with three entries
                encoder.encode_string('@ref')
                encoder.encode_int(ref_id)
                encoder.encode_string('id')
                # NOTE(review): uses encode_uuid here while the plain UUID branch
                # above writes raw bytes; confirm the receiver expects both forms.
                encoder.encode_uuid(value.id)
                encoder.encode_string('markers')
                encoder.encode_length(4, len(value.markers))
                for marker in value.markers:
                    context.serialize(marker, None, encoder)
        elif isinstance(value, Marker):
            if (ref_id := context.remoting_context.try_get_id(value)):
                encoder.encode_int(ref_id)
            else:
                ref_id = context.remoting_context.add(value)
                encoder.write(INDEFINITE_MAP_START)
                encoder.encode_string("@c")
                encoder.encode_string(to_java_type_name(type(value)))
                encoder.encode_string('@ref')
                encoder.encode_int(ref_id)
                for field in fields(value):
                    # Only underscore-prefixed, non-ClassVar fields are written, under
                    # their name without the underscore. The original condition
                    # `a and not b or c` parsed as `(a and not b) or c`, which raised
                    # AttributeError or truncated names of other fields.
                    is_class_var = getattr(field.type, '__origin__', None) is ClassVar
                    if field.name.startswith('_') and not is_class_var:
                        encoder.encode_string(field.name[1:])
                        context.serialize(getattr(value, field.name), None, encoder)
                encoder.write(BREAK_MARKER)
        else:
            ValueSerializer.write_object_using_reflection(value, type_name, False, encoder, context)
|
391
|
+
|
392
|
+
|
393
|
+
def delegate_based_serializer(delegate: Callable[[Any, Optional[str], CBOREncoder, SerializationContext], None]):
    """Build a ``ValueSerializer`` subclass whose ``serialize`` forwards to ``delegate``.

    Args:
        delegate: callable invoked as ``delegate(value, type_name, encoder, context)``.

    Returns:
        The generated serializer class.
    """
    # NOTE(review): the class itself is returned, not an instance, so callers
    # have to instantiate it before use — confirm this is what call sites expect.

    class DelegateBasedSerializer(ValueSerializer):
        """Serializer that hands every call to the captured ``delegate``."""

        def serialize(self, value: Any, type_name: Optional[str], encoder: CBOREncoder, context: SerializationContext):
            delegate(value, type_name, encoder, context)

    serializer_class = DelegateBasedSerializer
    return serializer_class
|
399
|
+
|
400
|
+
|
401
|
+
class ParseErrorSender(Sender):
    """Sender for parse-error trees, implemented with a nested ``ParseErrorVisitor``."""

    def send(self, after, before, ctx):
        """Visit ``after`` with a fresh ``Visitor`` inside a context forked from ``ctx``."""
        walker = self.Visitor()
        forked_ctx = ctx.fork(walker, before)
        walker.visit(after, forked_ctx)

    class Visitor(ParseErrorVisitor):
        """Visitor that emits the events for a parse error."""

        def visit(self, tree, ctx, parent: Optional[Cursor] = None):
            """Send ``tree`` as a node while a cursor for it is pushed; returns ``tree``."""
            # Push a cursor for the tree being sent and pop it again afterwards.
            self.cursor = Cursor(self.cursor, tree)
            ctx.send_node(tree, lambda node: node, ctx.send_tree)
            self.cursor = self.cursor.parent

            return tree

        def visit_parse_error(self, parse_error, ctx):
            """Send the attributes of ``parse_error`` in a fixed order; returns ``parse_error``."""
            ctx.send_value(parse_error, lambda err: err.id)
            ctx.send_node(parse_error, lambda err: err.markers, ctx.send_markers)
            ctx.send_value(parse_error, lambda err: err.source_path)
            ctx.send_typed_value(parse_error, lambda err: err.file_attributes)
            ctx.send_value(parse_error, lambda err: err.charset_name)
            ctx.send_value(parse_error, lambda err: err.charset_bom_marked)
            ctx.send_typed_value(parse_error, lambda err: err.checksum)
            ctx.send_value(parse_error, lambda err: err.text)
            # The `erroneous` attribute is currently not sent:
            # ctx.send_node(parse_error, lambda v: v.erroneous, ctx.send_tree)
            return parse_error
|
rewrite/remote/server.py
ADDED
@@ -0,0 +1,188 @@
|
|
1
|
+
import importlib
|
2
|
+
import importlib.resources
|
3
|
+
import os
|
4
|
+
import socket
|
5
|
+
import sys
|
6
|
+
import time
|
7
|
+
import traceback
|
8
|
+
import zipfile
|
9
|
+
from io import BytesIO, StringIO
|
10
|
+
from pathlib import Path
|
11
|
+
from typing import Optional
|
12
|
+
|
13
|
+
import cbor2
|
14
|
+
import select
|
15
|
+
from cbor2 import dumps
|
16
|
+
from rewrite import ParserInput, InMemoryExecutionContext, ExecutionContext, ParseError
|
17
|
+
from rewrite.python import Py
|
18
|
+
from rewrite.python.parser import PythonParserBuilder
|
19
|
+
from rewrite.remote import SenderContext, ReceiverContext, RemotingContext, RemotingMessenger, RemotingMessageType, \
|
20
|
+
RemotePrinterFactory, OK, ParseErrorSender
|
21
|
+
from rewrite.remote.python.receiver import PythonReceiver
|
22
|
+
from rewrite.remote.python.sender import PythonSender
|
23
|
+
|
24
|
+
# Default number of seconds the TCP server waits for a new connection before
# it shuts itself down (300 s = 5 minutes).
INACTIVITY_TIMEOUT = 300  # 5 minutes
# Status codes; _ERROR is sent in the response written when request handling fails.
_OK: int = 0
_ERROR: int = 1
|
27
|
+
|
28
|
+
|
29
|
+
def register_remoting_factories():
    """Register the sender and receiver factories used by this server.

    ``ParseError`` and ``Py`` each get a sender factory; ``Py`` additionally
    gets a receiver factory.
    """
    sender_factories = (
        (ParseError, lambda: ParseErrorSender()),
        (Py, lambda: PythonSender()),
    )
    for tree_type, make_sender in sender_factories:
        SenderContext.register(tree_type, make_sender)

    ReceiverContext.register(Py, lambda: PythonReceiver())
|
33
|
+
|
34
|
+
|
35
|
+
def find_free_port():
    """Return a port number picked by the operating system.

    A temporary TCP socket is bound to port 0, which makes the system choose
    a port; the socket is closed again before the number is returned.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(('', 0))
        probe.listen(1)
        address = probe.getsockname()
        return address[1]
    finally:
        probe.close()
|
41
|
+
|
42
|
+
|
43
|
+
class Server:
    """Socket server that answers remoting requests, including Python parse requests.

    The server listens either on a Unix domain socket (when ``path`` is
    given) or on a TCP port on ``localhost``. Connections are handled one at
    a time. In TCP mode the server stops once no connection has been accepted
    for ``timeout`` seconds.
    """

    _port: Optional[int]
    _path: Optional[str]
    _remoting_context: RemotingContext
    _messenger: RemotingMessenger

    def __init__(self, port: Optional[int] = None, path: Optional[str] = None, timeout: int = INACTIVITY_TIMEOUT):
        """Create a server.

        Args:
            port: TCP port to listen on; used when ``path`` is not given.
            path: file system path of a Unix domain socket; takes precedence.
            timeout: seconds without a new connection after which the TCP
                server shuts down.
        """
        self._port = port
        self._path = path
        self.timeout = timeout
        self._remoting_context = RemotingContext()
        self._messenger = RemotingMessenger(
            self._remoting_context,
            {
                'parse-python-source': self.parse_python_source,
                'parse-python-file': self.parse_python_file,
            }
        )

    def start(self):
        """Start the server and listen for connections on the configured socket."""
        if self._path:
            self._serve_unix_socket()
        else:
            self._serve_tcp()

    def _serve_unix_socket(self):
        """Accept and handle connections on the Unix domain socket forever."""
        # Remove a stale socket file left over from an earlier run.
        try:
            os.remove(self._path)
        except FileNotFoundError:
            pass
        with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s:
            s.bind(self._path)
            s.listen()
            print(f"Server listening on Unix domain socket: {self._path}")
            while True:
                conn, _ = s.accept()
                with conn:
                    self.handle_client(conn)

    def _serve_tcp(self):
        """Accept and handle TCP connections until the inactivity timeout expires."""
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            s.bind(('localhost', self._port))
            s.listen(5)
            print(f"Server listening on port {self._port}")
            # accept() wakes up every 5 seconds so the inactivity check can run.
            s.settimeout(5)
            last_activity_time = time.time()
            while True:
                try:
                    conn, addr = s.accept()
                except socket.timeout:
                    if time.time() - last_activity_time >= self.timeout:
                        # Report the configured timeout instead of a fixed "5 minutes".
                        print(f"No new connections for {self.timeout} seconds, shutting down server.")
                        break
                    continue

                with conn:
                    conn.settimeout(None)
                    self.handle_client(conn)
                # Measure inactivity from the end of the last connection so that a
                # long-running client does not trigger an immediate shutdown.
                last_activity_time = time.time()

    def handle_client(self, sock: socket.socket) -> None:
        """Serve requests from one client connection; the socket is closed on return.

        Each request starts with one byte that must decode to
        ``RemotingMessageType.Request``. After a request has been processed
        the connection is kept only if more data arrives within 10 ms.
        Socket errors end the connection silently. Any other error is
        printed and, if the socket is still open, reported to the client as
        an error response containing the traceback.
        """
        try:
            self._remoting_context.connect(sock)
            RemotePrinterFactory(self._remoting_context.client).set_current()

            sock.setblocking(True)
            while True:
                message_type = sock.recv(1)
                if not message_type:
                    return
                # Explicit check instead of `assert`, which is removed under `python -O`.
                decoded_type = cbor2.load(BytesIO(message_type))
                if decoded_type != RemotingMessageType.Request:
                    raise ValueError(f"Unexpected message type: {decoded_type!r}")
                self._messenger.process_request(sock)

                readable, _, _ = select.select([sock], [], [], 0.01)
                if sock not in readable:
                    return

        except (OSError, IOError):
            # the socket was closed
            return
        except Exception as e:
            print(f"An error occurred while handling client: {e}")
            traceback.print_exc()
            if sock.fileno() != -1:
                try:
                    # Build the whole response first and write it with sendall();
                    # send() may transmit only part of the data.
                    error_response = (dumps(RemotingMessageType.Response)
                                      + dumps(_ERROR)
                                      + dumps(traceback.format_exc()))
                    sock.sendall(error_response)
                except (OSError, IOError):
                    # the socket was closed
                    return
                except Exception as inner_exception:
                    print(inner_exception)
        finally:
            sock.close()

    def parse_python_source(self, stream: BytesIO, sock: socket.socket, remoting_ctx: RemotingContext):
        """Parse Python source text read from ``stream`` and send the result to ``sock``."""
        source = cbor2.load(stream)
        parser_input = ParserInput(Path('source.py'), None, True, lambda: StringIO(source))
        self._parse_and_send(parser_input, sock, remoting_ctx)

    def parse_python_file(self, stream: BytesIO, sock: socket.socket, remoting_ctx: RemotingContext):
        """Parse the Python file whose path is read from ``stream`` and send the result to ``sock``."""
        path = cbor2.load(stream)
        parser_input = ParserInput(Path(path), None, True, lambda: read_file_contents(path))
        self._parse_and_send(parser_input, sock, remoting_ctx)

    @staticmethod
    def _parse_and_send(parser_input: ParserInput, sock: socket.socket, remoting_ctx: RemotingContext):
        """Parse ``parser_input`` and send one OK response per parse result to ``sock``."""
        ctx = InMemoryExecutionContext()
        ctx.put_message(ExecutionContext.REQUIRE_PRINT_EQUALS_INPUT, False)
        for cu in PythonParserBuilder().build().parse_inputs([parser_input], None, ctx):
            response_stream = BytesIO()
            cbor2.dump(RemotingMessageType.Response, response_stream)
            cbor2.dump(OK, response_stream)
            remoting_ctx.new_sender_context(response_stream).send_any_tree(cu, None)
            sock.sendall(response_stream.getvalue())
|
158
|
+
|
159
|
+
|
160
|
+
def read_file_contents(path):
    """Read the text file at ``path`` and return its content wrapped in a ``StringIO``.

    The file is decoded as UTF-8 so the result does not depend on the
    platform's default encoding.

    Raises:
        OSError: if the file cannot be opened.
        UnicodeDecodeError: if the content is not valid UTF-8.
    """
    with open(path, 'r', encoding='utf-8') as file:
        return StringIO(file.read())
|
163
|
+
|
164
|
+
|
165
|
+
def read_data_from_zip():
    """Print the contents of the bundled ``rewrite-remote-java.zip`` resource.

    The archive is opened from the ``resources`` package; its member names
    are printed, followed by the UTF-8 decoded text of the
    ``bin/rewrite-remote-java`` member.
    """
    archive_name = 'rewrite-remote-java.zip'
    launcher_member = 'rewrite-remote-java-0.2.0-SNAPSHOT/bin/rewrite-remote-java'

    with importlib.resources.open_binary('resources', archive_name) as f, \
            zipfile.ZipFile(f) as zip_file:
        # List the members of the archive.
        print(zip_file.namelist())
        # Read one known member and print it as text.
        with zip_file.open(launcher_member) as inner_file:
            data = inner_file.read()
        print(data.decode('utf-8'))
|
177
|
+
|
178
|
+
|
179
|
+
# Example usage
|
180
|
+
# read_data_from_zip()
|
181
|
+
|
182
|
+
if __name__ == "__main__":
    # Usage: server.py [port] [timeout-seconds]
    cli_args = sys.argv[1:]
    port = int(cli_args[0]) if cli_args else 54322
    timeout = int(cli_args[1]) if len(cli_args) > 1 else INACTIVITY_TIMEOUT
    register_remoting_factories()
    Server(port=port, timeout=timeout).start()
    # Alternative ways of starting the server:
    # Server(port=find_free_port()).start()
    # Server(path=tempfile.gettempdir() + '/rewrite-csharp.sock').start()
|
@@ -0,0 +1,60 @@
|
|
1
|
+
import typing
|
2
|
+
from functools import lru_cache
|
3
|
+
from typing import Type
|
4
|
+
|
5
|
+
|
6
|
+
# Python builtins with a fixed Java counterpart.
_BUILTIN_JAVA_NAMES = {
    bool: 'java.lang.Boolean',
    int: 'java.lang.Integer',
    str: 'java.lang.String',
    float: 'java.lang.Double',
    type(None): 'null',
}

# Support types that map to a class of the same name in the Java tree packages.
_JAVA_SUPPORT_TYPE_NAMES = frozenset(
    ('Space', 'Comment', 'TextComment', 'JLeftPadded', 'JRightPadded', 'JContainer'))
_PYTHON_SUPPORT_TYPE_NAMES = frozenset(
    ('PyComment', 'PyLeftPadded', 'PyRightPadded', 'PyContainer'))


@lru_cache(maxsize=None)
def to_java_type_name(t: typing.Type) -> str:
    """Return the Java type name that corresponds to the Python type ``t``.

    The name is derived from the type's module and (qualified) name; nested
    classes are joined with ``$``. Types that match none of the known
    modules fall back to ``<module>.<qualname>``. Results are cached per type.
    """
    builtin_name = _BUILTIN_JAVA_NAMES.get(t)
    if builtin_name is not None:
        return builtin_name

    module = t.__module__

    if module.startswith('rewrite.java.support_types') and t.__name__ in _JAVA_SUPPORT_TYPE_NAMES:
        return 'org.openrewrite.java.tree.' + t.__name__
    if module.startswith('rewrite.java.markers'):
        return 'org.openrewrite.java.marker.' + t.__qualname__
    if module.startswith('rewrite.java.tree'):
        return 'org.openrewrite.java.tree.J$' + t.__qualname__.replace('.', '$')

    if module.startswith('rewrite.python.support_types') and t.__name__ in _PYTHON_SUPPORT_TYPE_NAMES:
        return 'org.openrewrite.python.tree.' + t.__name__
    if module.startswith('rewrite.python.tree'):
        return 'org.openrewrite.python.tree.Py$' + t.__qualname__.replace('.', '$')

    if module.startswith('rewrite.marker'):
        if t.__name__ == 'ParseExceptionResult':
            return 'org.openrewrite.ParseExceptionResult'
        return 'org.openrewrite.marker.' + t.__qualname__.replace('.', '$')

    if module == 'rewrite.parser' and t.__name__ == 'ParseError':
        return 'org.openrewrite.tree.ParseError'

    if module.startswith('rewrite.') and module.endswith('.tree'):
        # Generic rule for other languages: rewrite.<model>.tree -> <Model>$<Class>.
        model = module.split('.')[1]
        nested_name = t.__qualname__.replace('.', '$')
        return 'org.openrewrite.' + model + '.tree.' + model.capitalize() + '$' + nested_name

    # Unknown types are not rejected; they keep their dotted Python name.
    return module + '.' + t.__qualname__
|
57
|
+
|
58
|
+
|
59
|
+
def get_type(type_name: str) -> Type:
    """Resolve a type from its name; not implemented yet.

    Raises:
        NotImplementedError: always, with a message naming ``type_name``.
    """
    message = "get_type for: " + type_name
    raise NotImplementedError(message)
|