openrewrite-remote 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,342 @@
1
+ from __future__ import absolute_import
2
+
3
+ import socket
4
+ import threading
5
+ from dataclasses import dataclass
6
+ from io import BytesIO
7
+ from threading import Lock
8
+ from typing import Any, Dict, Optional, Type, List, Tuple, Callable, cast
9
+
10
+ import cbor2
11
+ from cbor2 import dumps, loads, load
12
+ from rewrite.execution import DelegatingExecutionContext
13
+ from rewrite.tree import PrintOutputCapture, P
14
+ from rewrite.visitor import T
15
+
16
+ from rewrite import Recipe, InMemoryExecutionContext, Cursor, PrinterFactory, TreeVisitor, Tree
17
+ from rewrite.remote import ValueSerializer, ValueDeserializer, SenderContext, ReceiverContext, JsonSender, JsonReceiver, \
18
+ SerializationContext, DeserializationContext, remote_utils
19
+
20
+
21
class RemotingContext:
    """Per-session state for the remoting protocol.

    An instance tracks a bidirectional object <-> id mapping used while
    (de)serializing trees and may hold a :class:`RemotingClient` created by
    :meth:`connect`. Recipe factories and value (de)serializers are
    class-wide registries shared by every instance.
    """

    # Holds the context bound to each thread by set_current().
    _remoting_thread_local = threading.local()
    # Class-wide registries, populated via the register_* classmethods.
    _recipe_factories: Dict[str, Callable[[str, Dict[str, Any]], Recipe]] = {}
    _value_serializers: Dict[Type, ValueSerializer] = {}
    _value_deserializers: List[Tuple[Type, ValueDeserializer]] = []

    def __init__(self):
        self._client: Optional['RemotingClient'] = None
        # Per-instance (not class-level) so that copy() yields independent
        # state and reset() only clears this context's mappings.
        self._object_to_id_map: Dict[Any, int] = {}
        self._id_to_object_map: Dict[int, Any] = {}

    @classmethod
    def current(cls) -> 'RemotingContext':
        """Return the context bound to the calling thread.

        :raises ValueError: if no context has been set on this thread.
        """
        result = getattr(cls._remoting_thread_local, 'context', None)
        if result is None:
            raise ValueError("No RemotingContext has been set")
        return result

    def set_current(self) -> None:
        """Bind this context to the calling thread (see :meth:`current`)."""
        cls = self.__class__
        cls._remoting_thread_local.context = self

    @property
    def client(self) -> Optional['RemotingClient']:
        """The client created by :meth:`connect`, or None if not connected."""
        return self._client

    def connect(self, sock: Any) -> 'RemotingContext':
        """Create a :class:`RemotingClient` over ``sock`` and return ``self``."""
        self._client = RemotingClient(self, sock)
        return self

    def close(self) -> None:
        """Close the client's socket; a no-op if :meth:`connect` was never called."""
        if self._client is not None:
            self._client.close()

    def try_get_id(self, key: Any) -> Optional[int]:
        """Return the id registered for ``key``, or None if it has none."""
        return self._object_to_id_map.get(key)

    def add(self, value: Any) -> int:
        """Register ``value`` under a newly allocated id and return that id."""
        # NOTE(review): the id is the current map size, so adding the same
        # value twice, or mixing with add_by_id(), can hand out an id that is
        # already in use. Confirm callers never do either.
        object_id = len(self._object_to_id_map)
        self._object_to_id_map[value] = object_id
        self._id_to_object_map[object_id] = value
        return object_id

    def add_by_id(self, key: int, value: Any) -> None:
        """Register ``value`` under the caller-supplied id ``key``."""
        self._id_to_object_map[key] = value
        self._object_to_id_map[value] = key

    def get_object(self, key: int) -> Optional[Any]:
        """Return the object registered under id ``key``, or None."""
        return self._id_to_object_map.get(key)

    def reset(self) -> None:
        """Forget every object/id mapping held by this context."""
        self._object_to_id_map.clear()
        self._id_to_object_map.clear()

    def new_sender_context(self, output_stream: Any) -> 'SenderContext':
        """Create a SenderContext that writes to ``output_stream`` using the registered serializers."""
        return SenderContext(JsonSender(output_stream, SerializationContext(self, self._value_serializers)))

    def new_receiver_context(self, input_stream: Any) -> 'ReceiverContext':
        """Create a ReceiverContext that reads from ``input_stream`` using the registered deserializers."""
        return ReceiverContext(JsonReceiver(input_stream, DeserializationContext(self, self._value_deserializers)))

    def copy(self) -> 'RemotingContext':
        """Return a new context with no client and empty object/id mappings."""
        return RemotingContext()

    def new_recipe(self, recipe_id: str, recipe_options: Any) -> 'Recipe':
        """Instantiate the recipe registered under ``recipe_id``.

        :raises KeyError: if no factory is registered for ``recipe_id``.
        """
        # NOTE(review): the registry annotation declares factories as taking
        # (recipe_id, options), but they are called with the options only.
        # Confirm which signature registered factories actually have.
        return self._recipe_factories[recipe_id](recipe_options)

    @classmethod
    def register_value_serializer(cls, serializer: ValueSerializer) -> None:
        """Register ``serializer`` for the type annotated on its ``value`` entry."""
        cls._value_serializers[serializer.__annotations__['value']] = serializer

    @classmethod
    def register_value_deserializer(cls, deserializer: ValueDeserializer) -> None:
        """Register ``deserializer``, replacing any existing one for the same ``data`` type."""
        data_type = deserializer.__annotations__['data']
        for index, (type_, _) in enumerate(cls._value_deserializers):
            if type_ == data_type:
                # Replace in place so the registration order is preserved.
                cls._value_deserializers[index] = (type_, deserializer)
                return
        cls._value_deserializers.append((data_type, deserializer))
+ cls._value_deserializers.append((deserializer.__annotations__['data'], deserializer))
98
+
99
+
100
class RemotingExecutionContextView(DelegatingExecutionContext):
    """View over an execution context exposing the RemotingContext stored
    under its ``'remoting'`` message key."""

    def __init__(self, delegate):
        super().__init__(delegate)
        self._delegate = delegate

    @staticmethod
    def view(ctx):
        """Return ``ctx`` unchanged if it is already a view, otherwise wrap it."""
        if isinstance(ctx, RemotingExecutionContextView):
            return ctx
        return RemotingExecutionContextView(ctx)

    @property
    def remoting_context(self) -> RemotingContext:
        """The context stored on the delegate, else the thread's current context.

        :raises ValueError: if the delegate holds none and no thread-local
            context has been set.
        """
        # Look the message up first and fall back lazily. Passing
        # RemotingContext.current() as the default argument evaluated it
        # eagerly, which raised even when the delegate already held a context.
        context = self._delegate.get_message('remoting', None)
        if context is None:
            context = RemotingContext.current()
        return context

    @remoting_context.setter
    def remoting_context(self, value: RemotingContext):
        # Makes ``value`` current on this thread as well as storing it.
        value.set_current()
        self._delegate.put_message('remoting', value)
119
+
120
+
121
# Status codes that follow a Response marker. The handlers in this module
# write OK; ERROR is the failure counterpart.
OK, ERROR = 0, 1
123
+
124
+
125
class RemotingMessenger:
    """Serves and sends requests of the remoting protocol over a socket.

    Server side: :meth:`process_request` reads one command and dispatches it
    to a ``handle_*`` method or to a user-supplied additional handler.
    Client side: the ``send_*_request`` methods write a request and read the
    peer's response.
    """

    def __init__(self, context: RemotingContext,
                 additional_handlers: Optional[Dict[str, Callable[[BytesIO, socket.socket, RemotingContext], Any]]] = None):
        """
        :param context: context used to (de)serialize trees and create
            recipes; replaced by a fresh copy on a ``reset`` command.
        :param additional_handlers: extra command handlers keyed by command
            name, consulted after the built-in commands.
        """
        self._context = context
        self._additional_handlers = additional_handlers or {}
        self._recipes = []
        # Tree produced by the last run-recipe-visitor command; handed to the
        # receiver context on the next one.
        self._state = None

    def process_request(self, sock: socket.socket) -> bool:
        """Read one command from ``sock``, dispatch it, and return True.

        :raises NotImplementedError: if the command is neither built-in nor
            present in ``additional_handlers``.
        """
        stream = remote_utils.read_to_command_end(sock)
        command = cbor2.load(stream)

        builtin_handlers = {
            "hello": self.handle_hello_command,
            "reset": self.handle_reset_command,
            "load-recipe": self.handle_load_recipe_command,
            "run-recipe-visitor": self.handle_run_recipe_visitor_command,
            "print": self.handle_print_command,
        }
        handler = builtin_handlers.get(command)
        if handler is not None:
            handler(stream, sock)
        elif command in self._additional_handlers:
            self._additional_handlers[command](stream, sock, self._context)
        else:
            raise NotImplementedError(f"Unsupported command: {command}")

        return True

    @staticmethod
    def _new_ok_response() -> BytesIO:
        """Return a buffer already holding the Response marker and the OK status."""
        response_stream = BytesIO()
        cbor2.dump(RemotingMessageType.Response, response_stream)
        cbor2.dump(OK, response_stream)
        return response_stream

    def handle_hello_command(self, stream: BytesIO, sock: socket.socket):
        """Consume the hello argument and acknowledge with OK."""
        cbor2.load(stream)
        sock.sendall(self._new_ok_response().getvalue())

    def handle_reset_command(self, stream: BytesIO, sock: socket.socket):
        """Drop all session state, replace the context with a fresh copy, and reply OK."""
        self._state = None
        self._context = self._context.copy()
        # Connect the new context over the socket this request arrived on.
        # Creating a new, never-connected socket.socket() here leaked a file
        # descriptor and left the client unable to reach the peer.
        self._context.connect(sock)
        self._recipes.clear()
        sock.sendall(self._new_ok_response().getvalue())

    def handle_load_recipe_command(self, stream: BytesIO, sock: socket.socket):
        """Instantiate a recipe and reply OK followed by its index in the recipe list."""
        recipe_id = cbor2.load(stream)
        recipe_options = cbor2.load(stream)
        self._recipes.append(self._context.new_recipe(recipe_id, recipe_options))

        response_stream = self._new_ok_response()
        # The index is what a later run-recipe-visitor command refers to.
        cbor2.dump(len(self._recipes) - 1, response_stream)
        sock.sendall(response_stream.getvalue())

    def handle_run_recipe_visitor_command(self, stream: BytesIO, sock: socket.socket):
        """Receive a tree, run the indexed recipe's visitor on it, and send back the result."""
        recipe_index = cbor2.load(stream)
        recipe = self._recipes[recipe_index]

        output_stream = BytesIO()

        receiver_context = self._context.new_receiver_context(stream)
        received = receiver_context.receive_tree(self._state)
        ctx = InMemoryExecutionContext()
        RemotingExecutionContextView.view(ctx).remoting_context = self._context
        self._state = recipe.get_visitor().visit(received, ctx)
        sender_context = self._context.new_sender_context(output_stream)
        sender_context.send_any_tree(self._state, received)

        response_stream = self._new_ok_response()
        # The serialized tree is appended as raw bytes, not CBOR-wrapped.
        response_stream.write(output_stream.getvalue())
        sock.sendall(response_stream.getvalue())

    def handle_print_command(self, stream: BytesIO, sock: socket.socket):
        """Receive a tree, print it locally, and reply OK followed by the printed text."""
        receiver_context = self._context.new_receiver_context(stream)
        received = receiver_context.receive_any_tree(None)
        root_cursor = Cursor(None, Cursor.ROOT_VALUE)
        ctx = InMemoryExecutionContext()
        # ctx is not passed to print(); the assignment is kept for its side
        # effect of making self._context current on this thread.
        RemotingExecutionContextView.view(ctx).remoting_context = self._context
        print_output = received.print(Cursor(root_cursor, received), PrintOutputCapture(0))

        response_stream = self._new_ok_response()
        cbor2.dump(print_output, response_stream)
        sock.sendall(response_stream.getvalue())

    def send_request(self, sock: socket.socket, command: str, *args):
        """Send ``command`` and each of ``args`` CBOR-encoded, then the end marker."""
        sock.sendall(dumps(RemotingMessageType.Request))
        sock.sendall(dumps(command))
        for arg in args:
            # FIXME serialize properly
            sock.sendall(dumps(arg))
        self.send_end_message(sock)

    def send_request_stream(self, sock: socket.socket, command: str, *args):
        """Send ``command``, then call each of ``args`` with ``sock`` to write its payload, then the end marker."""
        sock.sendall(dumps(RemotingMessageType.Request))
        sock.sendall(dumps(command))
        for arg in args:
            arg(sock)
        self.send_end_message(sock)

    @staticmethod
    def send_end_message(sock):
        """Write the command-end marker: the CBOR encoding of the array [23]."""
        sock.sendall(b'\x81\x17')

    def send_print_request(self, sock: socket.socket, cursor: Cursor):
        """Ask the peer to print the tree at ``cursor`` and return the printed text.

        :raises ValueError: on an unexpected message type or a non-OK status.
        """
        self.send_request_stream(sock, "print",
                                 lambda s: self.send_tree(s, cast(Tree, cursor.value)))
        if self.recv_byte(sock) != RemotingMessageType.Response:
            raise ValueError("Unexpected message type.")
        if self.recv_byte(sock) != OK:
            raise ValueError(f"Remote print failed: {loads(self.recv_all(sock))}")
        data = remote_utils.read_to_command_end(sock)
        # The end marker is not read from the returned stream.
        return load(data)

    def send_tree(self, sock: socket.socket, tree: Tree):
        """Serialize ``tree`` into a buffer and write it to ``sock`` in one call."""
        b = BytesIO()
        self._context.new_sender_context(b).send_any_tree(tree, None)
        sock.sendall(b.getvalue())

    def send_run_recipe_request(self, sock: socket.socket, recipe, options: dict, source_files: list):
        """Ask the peer to run ``recipe`` over ``source_files`` and return the updated trees.

        Requests the peer sends while working are served with
        :meth:`process_request` until its Response arrives.

        :raises ValueError: on an unexpected message type or a non-OK status.
        """
        def write_arguments(s: socket.socket) -> None:
            s.sendall(dumps(recipe))
            s.sendall(dumps(options))
            s.sendall(dumps(len(source_files)))
            for source_file in source_files:
                self._context.new_sender_context(s).send_any_tree(source_file, None)

        self.send_request_stream(sock, "run-recipe", write_arguments)

        # Keep the message-type byte that ends the loop and check it below.
        # Reading a fresh byte after the loop would consume the status byte
        # instead, so every successful reply was rejected.
        message_type = self.recv_byte(sock)
        while message_type == RemotingMessageType.Request:
            self.process_request(sock)
            message_type = self.recv_byte(sock)
        if message_type != RemotingMessageType.Response:
            raise ValueError("Unexpected message type.")
        if self.recv_byte(sock) != OK:
            raise ValueError(f"Remote recipe run failed: {loads(self.recv_all(sock))}")
        input_stream = remote_utils.read_to_command_end(sock)
        receiver_context = self._context.new_receiver_context(input_stream)
        # As in send_print_request, the end marker is not read from the stream.
        # The former loads(input_stream) passed a stream to cbor2.loads, which
        # only accepts bytes, and always raised TypeError.
        return [receiver_context.receive_any_tree(sf) for sf in source_files]

    def send_reset_request(self, sock: socket.socket):
        """Ask the peer to reset its session state.

        :raises ValueError: on an unexpected message type or a non-OK status.
        """
        self.send_request(sock, "reset")
        if self.recv_byte(sock) != RemotingMessageType.Response:
            raise ValueError("Unexpected message type.")
        if self.recv_byte(sock) != OK:
            raise ValueError(f"Remote reset failed: {loads(self.recv_all(sock))}")
        # NOTE(review): handle_reset_command in this module replies with only
        # Response + OK, so against a Python peer this read may wait for data
        # that never arrives. Confirm what the remote side sends.
        loads(self.recv_all(sock))  # command end

    def recv_byte(self, sock):
        """Read one byte from ``sock`` and return it as an int.

        :raises ConnectionError: if the peer closed the connection.
        """
        data = sock.recv(1)
        if not data:
            raise ConnectionError("Connection closed by remote peer")
        return data[0]

    def recv_all(self, sock, buffer_size=4096):
        """Read from ``sock`` until a read returns fewer than ``buffer_size`` bytes.

        This is a heuristic: if the pending data is an exact multiple of
        ``buffer_size``, the final recv may block waiting for more.
        """
        chunks = []
        while True:
            part = sock.recv(buffer_size)
            chunks.append(part)
            if len(part) < buffer_size:
                break
        return b''.join(chunks)
293
+
294
+
295
class RemotingMessageType:
    """Leading message-type codes of the remoting protocol, sent as CBOR integers."""

    Request, Response = 0, 1
+ Response = 1
298
+
299
+
300
class RemotingClient:
    """Issues remoting requests over a single socket.

    Each request method (hello, print, reset, run_recipe) holds an internal
    lock for its whole exchange, so concurrent callers do not interleave
    messages on the socket.
    """

    def __init__(self, context, sock: socket.socket):
        self._messenger = RemotingMessenger(context)
        self._socket = sock
        self._lock = Lock()

    def _exchange(self, send, *args):
        """Call ``send(socket, *args)`` while holding the lock and return its result."""
        with self._lock:
            return send(self._socket, *args)

    def close(self) -> None:
        """Close the underlying socket."""
        self._socket.close()

    def hello(self):
        """Send a hello request."""
        # NOTE(review): the peer's reply to hello is not read here; confirm
        # the next exchange is not confused by it.
        self._exchange(self._messenger.send_request, "hello")

    def print(self, cursor):
        """Return the peer's printed form of the tree at ``cursor``."""
        return self._exchange(self._messenger.send_print_request, cursor)

    def reset(self):
        """Ask the peer to reset its session state."""
        self._exchange(self._messenger.send_reset_request)

    def run_recipe(self, recipe, options: dict, source_files: list):
        """Run ``recipe`` remotely over ``source_files`` and return the result."""
        return self._exchange(self._messenger.send_run_recipe_request, recipe, options, source_files)
324
+
325
+
326
@dataclass
class RemotePrinter(TreeVisitor[Any, PrintOutputCapture[P]]):
    """Visitor that prints a tree by asking the remote peer through a RemotingClient."""

    _client: RemotingClient

    def visit(self, tree: Optional[Tree], p: PrintOutputCapture[P], parent: Optional[Cursor] = None) -> Optional[T]:
        """Append the remote rendering of ``tree`` to ``p`` and return ``tree`` unchanged."""
        self.cursor = Cursor(parent, tree)
        try:
            p.append(self._client.print(self.cursor))
        finally:
            # Restore the cursor even if the remote print raises, so the
            # visitor is not left pointing at ``tree``.
            self.cursor = self.cursor.parent
        return tree
335
+
336
+
337
@dataclass
class RemotePrinterFactory(PrinterFactory):
    """PrinterFactory whose printers render trees through a RemotingClient."""

    _client: RemotingClient

    def create_printer(self, cursor: Cursor) -> TreeVisitor[Any, PrintOutputCapture[P]]:
        """Return a new RemotePrinter bound to this factory's client; ``cursor`` is unused."""
        printer = RemotePrinter(_client=self._client)
        return printer