openrewrite-remote 0.13.2__py3-none-any.whl → 0.13.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,152 @@
1
+ import logging
2
+ import socket
3
+ from io import BytesIO
4
+ from typing import List, TypedDict
5
+ from cbor2 import dumps, CBORDecoder
6
+
7
+ from rewrite_remote.remote_utils import COMMAND_END
8
+ from rewrite_remote.remoting import OK, RemotingContext, RemotingMessageType
9
+ from rewrite_remote.handlers.handler_helpers import respond_with_error
10
+ from rewrite_remote.handlers.pypi_manager import PyPiManager, Source
11
+ from rewrite_remote.handlers.types import PackageSource
12
+
13
+
14
class RecipeInstallArgs(TypedDict):
    """Arguments of an ``install-recipe`` request.

    Built by ``decode_recipe_install_args`` from the CBOR request payload and
    consumed by ``recipe_install_handler``.
    """

    # Name of the package to install.
    package_id: str
    # Requested version of that package.
    package_version: str
    # Flag forwarded to PyPiManager.find_valid_source; presumably whether the
    # default package index is searched in addition to package_sources.
    include_default_repository: bool
    # Extra package sources supplied by the client, each with optional credentials.
    package_sources: List[PackageSource]
19
+
20
+
21
def decode_recipe_install_args(decoder: CBORDecoder) -> RecipeInstallArgs:
    """
    Decode the arguments of an ``install-recipe`` request.

    Values are read positionally, so the order below must match the order in
    which the client encoded them: package id, package version,
    include-default-repository flag, package sources.

    :param decoder: CBOR decoder positioned at the first argument.
    :return: The decoded arguments.
    :raises ValueError: if the package id or version is missing (CBOR null), or
        if the package sources value is not a list.
    """
    raw_package_id = decoder.decode()
    raw_package_version = decoder.decode()
    # A null flag is coerced to False, i.e. "do not include the default repository".
    include_default_repository = bool(decoder.decode())
    package_sources_data = decoder.decode()

    # Reject nulls explicitly: str(None) would otherwise yield the literal
    # package id/version "None" instead of reporting the missing argument.
    if raw_package_id is None:
        raise ValueError("package_id is required")
    if raw_package_version is None:
        raise ValueError("package_version is required")

    if not isinstance(package_sources_data, list):
        raise ValueError("package_sources_data is not a list")

    # Each entry is expected to be a decoded CBOR map with "source" and
    # optional "credential" keys.
    package_sources = [
        PackageSource(source=ps.get("source"), credential=ps.get("credential"))
        for ps in package_sources_data
    ]

    return {
        "package_id": str(raw_package_id),
        "package_version": str(raw_package_version),
        "include_default_repository": include_default_repository,
        "package_sources": package_sources,
    }
46
+
47
+
48
# Main command handler with the specified signature
def recipe_install_handler(
    stream: BytesIO, sock: socket.socket, remoting_ctx: RemotingContext
) -> None:
    """
    Handle an ``install-recipe`` request.

    Decodes the request arguments from ``stream``, asks ``PyPiManager`` for a
    package source that can serve the requested package, installs the package
    and sends the recipes it provides back on ``sock``. Every failure along the
    way is reported to the client through ``respond_with_error`` rather than
    raised, so the client always receives a reply.

    :param stream: Request payload containing the CBOR-encoded arguments.
    :param sock: Socket the response (or error) is written to.
    :param remoting_ctx: Remoting context; reset before the request is handled.
    """
    remoting_ctx.reset()

    # 1. Read input from stream
    try:
        decoder = CBORDecoder(BytesIO(stream.read()))
        args = decode_recipe_install_args(decoder)
    except Exception as e:  # pylint: disable=broad-except
        respond_with_error(f"Failed to decode arguments: {e}", sock)
        return

    package_id = args.get("package_id")
    package_version = args.get("package_version")
    include_default_repository = args.get("include_default_repository")
    package_sources = args.get("package_sources")

    if package_id is None:
        respond_with_error("package_id is required", sock)
        return

    if package_version is None:
        respond_with_error("package_version is required", sock)
        return

    if package_sources is None:
        respond_with_error("package_sources is required", sock)
        return

    if include_default_repository is None:
        respond_with_error("include_default_repository is required", sock)
        return

    # 2. Log the request (lazy %-formatting: only rendered if INFO is enabled)
    logging.info(
        "[Server] Handling install-recipe request: {\n"
        "    packageId: %s,\n"
        "    packageVersion: %s,\n"
        "    packageSources: %s,\n"
        "    includeDefaultRepository: %s,\n"
        "}",
        package_id,
        package_version,
        package_sources,
        include_default_repository,
    )

    # 3. Validate sources: convert each requested source (and its optional
    # credential map) into the Source type PyPiManager works with.
    sources: List[Source] = []
    for ps in package_sources:
        credential = ps.credential or {}
        sources.append(
            Source(
                source=ps.source,
                username=credential.get("username"),
                password=credential.get("password"),
                token=credential.get("token"),
            )
        )

    # Guarded like the other steps so a failure here still answers the client.
    try:
        valid_source = PyPiManager.find_valid_source(
            package_id, package_version, sources, include_default_repository
        )
    except Exception as e:  # pylint: disable=broad-except
        respond_with_error(f"Failed to resolve package sources: {e}", sock)
        return

    if not valid_source:
        respond_with_error("No valid sources found", sock)
        return

    # 4. Install the recipe
    try:
        installable_recipes = PyPiManager.install_package(
            package_id, package_version, valid_source
        )
    except Exception as e:  # pylint: disable=broad-except
        respond_with_error(f"Failed to install package: {e}", sock)
        return

    # 5. Log the result
    logging.info(
        "[Server] Found %d recipe(s) for package %s",
        len(installable_recipes.recipes),
        package_id,
    )
    for recipe in installable_recipes.recipes:
        logging.info("  Resolved recipe %s from %s", recipe, valid_source.source)

    # 6. Write response to the socket
    response = {
        "recipes": [
            {
                "name": recipe.name,
                "source": recipe.source,
                "options": [],
            }
            for recipe in installable_recipes.recipes
        ],
        "repository": installable_recipes.source,
        "version": installable_recipes.version,
    }

    # Frame: message type, status, CBOR payload, then the command terminator.
    sock.sendall(
        b"".join(
            (
                dumps(RemotingMessageType.Response),
                dumps(OK),
                dumps(response),
                COMMAND_END,
            )
        )
    )

    logging.info("[Server] Request completed.")
@@ -0,0 +1,135 @@
1
+ import logging
2
+ import socket
3
+ from io import BytesIO
4
+ from typing import List, TypedDict, Any
5
+ from cbor2 import dumps, CBORDecoder
6
+
7
+ from rewrite_remote.handlers.pypi_manager import Option, PyPiManager
8
+ from rewrite_remote.remoting import (
9
+ OK,
10
+ RemotingExecutionContextView,
11
+ RemotingMessageType,
12
+ RemotingMessenger,
13
+ )
14
+
15
+ from rewrite_remote.remote_utils import COMMAND_END
16
+ from rewrite_remote.remoting import OK, RemotingContext, RemotingMessageType
17
+ from rewrite_remote.handlers.handler_helpers import respond_with_error
18
+
19
+ from rewrite import InMemoryExecutionContext
20
+
21
+
22
class RunRecipeLoadAndVisitorArgs(TypedDict):
    """Arguments of a ``run-recipe-load-and-visitor`` request.

    Built by ``decode_run_recipe_load_and_visitor_args`` from the CBOR payload.
    """

    # Name of the recipe to load.
    recipe_name: str
    # Source the recipe is loaded from (forwarded to PyPiManager.load_recipe).
    recipe_source: str
    # Raw decoded options; presumably a list of recipe options — not validated here.
    recipe_options: Any  # List[recipeOption]
26
+
27
+
28
def decode_run_recipe_load_and_visitor_args(
    decoder: CBORDecoder,
) -> RunRecipeLoadAndVisitorArgs:
    """
    Decode the arguments of a ``run-recipe-load-and-visitor`` request.

    Values are read positionally, so the order below must match the order in
    which the client encoded them: recipe name, recipe source, recipe options.
    All three values are always consumed, so the decoder ends up positioned
    just after the arguments.

    :param decoder: CBOR decoder positioned at the first argument.
    :return: The decoded arguments; ``recipe_options`` is passed through as decoded.
    :raises ValueError: if the recipe name or source is missing (CBOR null).
    """
    raw_recipe_name = decoder.decode()
    raw_recipe_source = decoder.decode()
    recipe_options = decoder.decode()

    # Reject nulls explicitly: str(None) would otherwise yield the literal
    # recipe name/source "None" instead of reporting the missing argument.
    if raw_recipe_name is None:
        raise ValueError("recipe_name is required")
    if raw_recipe_source is None:
        raise ValueError("recipe_source is required")

    return {
        "recipe_name": str(raw_recipe_name),
        "recipe_source": str(raw_recipe_source),
        "recipe_options": recipe_options,
    }
43
+
44
+
45
def run_recipe_load_and_visitor_handler(
    stream: BytesIO, sock: socket.socket, remoting_ctx: RemotingContext
) -> None:
    """
    Handle a ``run-recipe-load-and-visitor`` request.

    ``stream`` carries the CBOR-encoded recipe arguments followed by the tree to
    visit. The handler loads the named recipe, applies its visitor to the
    received tree and replies on ``sock`` with the tree serialized by
    ``RemotingMessenger.send_tree``. Failures are reported to the client
    through ``respond_with_error`` rather than raised.

    :param stream: Request payload: the arguments, then the serialized tree.
    :param sock: Socket the response (or error) is written to.
    :param remoting_ctx: Remoting context; reset before the request is handled
        and attached to the execution context the visitor runs with.
    """
    remoting_ctx.reset()

    # 1. Decode the arguments directly from ``stream`` so it is left positioned
    # at the tree that follows them; draining the stream into a separate buffer
    # here would leave nothing for ``receive_tree`` below.
    # NOTE(review): relies on CBORDecoder not reading ahead past the last decoded
    # item — confirm for the cbor2 version in use.
    try:
        args = decode_run_recipe_load_and_visitor_args(CBORDecoder(stream))
    except Exception as e:  # pylint: disable=broad-except
        respond_with_error(f"Failed to decode arguments: {e}", sock)
        return

    recipe_name = args.get("recipe_name")
    recipe_source = args.get("recipe_source")
    recipe_options: List[Option] = args.get("recipe_options") or []

    if recipe_name is None:
        respond_with_error("recipe_name is required", sock)
        return

    if recipe_source is None:
        respond_with_error("recipe_source is required", sock)
        return

    # 2. Log the request (lazy %-formatting: only rendered if INFO is enabled)
    logging.info(
        "[Server] Handling run-recipe-load-and-visitor request: {\n"
        "    recipe_name: %s,\n"
        "    recipe_source: %s,\n"
        "    recipe_options: %s,\n"
        "}",
        recipe_name,
        recipe_source,
        recipe_options,
    )

    # 3. Receive the tree. Messenger state is stored as attributes on the
    # RemotingMessenger class itself and is created on first use.
    # NOTE(review): ``_context`` is only set by the first request handled; later
    # requests keep that first context — confirm this is intended.
    if not hasattr(RemotingMessenger, "_state"):
        RemotingMessenger._state = None

    if not hasattr(RemotingMessenger, "_context"):
        RemotingMessenger._context = remoting_ctx

    try:
        received = RemotingMessenger.receive_tree(
            RemotingMessenger, stream, RemotingMessenger._state
        )
    except Exception as e:  # pylint: disable=broad-except
        respond_with_error(f"Failed to receive tree: {e}", sock)
        return

    # 4. Load the recipe; the visitor runs with an execution context bound to
    # this request's remoting context.
    ctx = InMemoryExecutionContext()
    RemotingExecutionContextView.view(ctx).remoting_context = remoting_ctx

    try:
        recipe_instance = PyPiManager.load_recipe(
            recipe_name, recipe_source, recipe_options
        )
    except Exception as e:  # pylint: disable=broad-except
        respond_with_error(f"Failed to load recipe: {e}", sock)
        return

    # 5. Run the recipe's visitor over the received tree
    try:
        tree_visitor = recipe_instance.get_visitor()

        if not hasattr(tree_visitor, "visit") or not callable(tree_visitor.visit):
            raise ValueError("Visitor does not have a visit method")

        RemotingMessenger._state = tree_visitor.visit(received, ctx)

        if RemotingMessenger._state is None:
            raise ValueError("RemotingMessenger.state cannot be None")
    except Exception as e:  # pylint: disable=broad-except
        respond_with_error(f"Failed to process input data with recipe: {e}", sock)
        return

    # 6. Serialize the resulting tree; guarded so a failure still answers the client.
    tree_buffer = BytesIO()
    try:
        RemotingMessenger.send_tree(
            remoting_ctx, tree_buffer, RemotingMessenger._state, received
        )
    except Exception as e:  # pylint: disable=broad-except
        respond_with_error(f"Failed to send tree: {e}", sock)
        return

    # Frame: message type, status, the serialized tree, then the terminator.
    # The tree bytes must be included, otherwise the visitor's result is lost.
    sock.sendall(
        b"".join(
            (
                dumps(RemotingMessageType.Response),
                dumps(OK),
                tree_buffer.getvalue(),
                COMMAND_END,
            )
        )
    )

    logging.info("[Server] Request completed.")
@@ -0,0 +1,30 @@
1
+ from typing import List, Optional, Dict, Any
2
+ from dataclasses import dataclass, field
3
+
4
+
5
@dataclass
class PackageSource:
    """A package repository that recipe packages may be resolved from."""

    # Repository location (presumably an index URL — confirm); empty by default.
    source: str = ""
    # Optional credential map. The install-recipe handler reads its "username",
    # "password" and "token" keys. None means no credentials.
    credential: Optional[Dict[str, str]] = None
9
+
10
+
11
@dataclass
class RecipeOption:
    """One configurable option of a recipe."""

    # Option name.
    name: str
    # Name of the option's type, stored as a string.
    type: str
    # Whether a value must be supplied for this option.
    required: bool
    # The option's value; may be None.
    value: Optional[Any]
17
+
18
+
19
@dataclass
class Recipe:
    """A recipe made available by an installed package."""

    # Recipe name.
    name: str
    # Where the recipe comes from.
    source: str
    # Options the recipe accepts; each instance gets its own fresh empty list by default.
    options: List[RecipeOption] = field(default_factory=list)
24
+
25
+
26
@dataclass
class RecipeInstallResponse:
    """Result of installing a recipe package.

    The fields correspond to the ``recipes`` / ``repository`` / ``version`` keys
    of the install-recipe response payload.
    """

    # Recipes provided by the installed package.
    recipes: List[Recipe]
    # Repository the package was installed from.
    repository: str
    # Installed package version.
    version: str
@@ -15,13 +15,13 @@ from typing import (
15
15
  cast,
16
16
  Iterable,
17
17
  Any,
18
- Tuple,
19
- get_args,
20
18
  TYPE_CHECKING,
19
+ get_args,
20
+ get_origin,
21
21
  )
22
22
  from uuid import UUID
23
23
 
24
- import cbor2
24
+ from _cbor2 import break_marker
25
25
  from cbor2 import CBORDecoder
26
26
  from cbor2._decoder import major_decoders
27
27
  from rewrite import (
@@ -34,8 +34,8 @@ from rewrite import (
34
34
  )
35
35
  from rewrite import Tree, TreeVisitor, Cursor, FileAttributes
36
36
 
37
- from .event import DiffEvent, EventType
38
37
  from . import remote_utils, type_utils
38
+ from .event import DiffEvent, EventType
39
39
 
40
40
  if TYPE_CHECKING:
41
41
  from .remoting import RemotingContext
@@ -198,17 +198,11 @@ class ReceiverContext:
198
198
  ReceiverContext.Registry[type_] = receiver_factory
199
199
 
200
200
 
201
- class ValueDeserializer(Protocol):
202
- def deserialize(
203
- self,
204
- type_: Optional[Type[Any]],
205
- reader: CBORDecoder,
206
- context: "DeserializationContext",
207
- ) -> Optional[Any]: ...
201
+ ValueDeserializer = Callable[[Type[Any], CBORDecoder, "DeserializationContext"], Optional[Any]]
208
202
 
209
203
 
210
204
  class DefaultValueDeserializer(ValueDeserializer):
211
- def deserialize(
205
+ def __call__(
212
206
  self,
213
207
  expected_type: Optional[Type[Any]],
214
208
  reader: CBORDecoder,
@@ -224,15 +218,15 @@ class DefaultValueDeserializer(ValueDeserializer):
224
218
  class DeserializationContext:
225
219
  DefaultDeserializer = DefaultValueDeserializer()
226
220
 
227
- value_deserializers: List[Tuple[Type, ValueDeserializer]]
221
+ value_deserializers: Dict[str, ValueDeserializer]
228
222
 
229
223
  def __init__(
230
224
  self,
231
225
  remoting_context: "RemotingContext",
232
- value_deserializers: Optional[List[Tuple[Type, ValueDeserializer]]] = None,
226
+ value_deserializers: Optional[Dict[Type, ValueDeserializer]] = None,
233
227
  ):
234
228
  self.remoting_context = remoting_context
235
- self.value_deserializers = value_deserializers or []
229
+ self.value_deserializers = value_deserializers or {}
236
230
 
237
231
  def deserialize(self, expected_type: Type[Any], decoder: CBORDecoder) -> Any:
238
232
  if expected_type == UUID:
@@ -256,21 +250,36 @@ class DeserializationContext:
256
250
  if expected_type == float:
257
251
  return decoder.decode()
258
252
 
259
- if issubclass(expected_type, Enum):
253
+ if isinstance(expected_type, type) and issubclass(expected_type, Enum):
260
254
  return expected_type(decoder.decode())
261
255
 
262
256
  initial_byte = decoder.read(1)[0]
263
257
  major_type = initial_byte >> 5
264
258
  subtype = initial_byte & 31
259
+ concrete_type = None
265
260
 
266
261
  # Object ID for Marker, JavaType, etc.
267
262
  if major_type == 0:
268
- obj_id = major_decoders[major_type](decoder, subtype)
263
+ obj_id = decoder.decode_uint(subtype)
269
264
  return self.remoting_context.get_object(obj_id)
270
265
 
266
+ elif major_type == 1:
267
+ return decoder.decode_negint(subtype)
268
+
271
269
  # arrays
272
270
  elif major_type == 4:
273
- pass
271
+ if get_origin(expected_type) in (List, list):
272
+ expected_elem_type = get_args(expected_type)[0]
273
+ array = []
274
+ if subtype != 31:
275
+ for _ in range(subtype):
276
+ array.append(self.deserialize(expected_elem_type, decoder))
277
+ else:
278
+ while not (value := self.deserialize(expected_elem_type, decoder)) == break_marker:
279
+ array.append(value)
280
+ return array
281
+ else:
282
+ concrete_type = decoder.decode()
274
283
 
275
284
  # objects
276
285
  elif major_type == 5:
@@ -278,6 +287,7 @@ class DeserializationContext:
278
287
  assert field_name == "@c"
279
288
  concrete_type = decoder.decode()
280
289
 
290
+ if concrete_type:
281
291
  if concrete_type == "org.openrewrite.marker.SearchResult":
282
292
  field_name = decoder.decode()
283
293
  if field_name == "description":
@@ -291,8 +301,12 @@ class DeserializationContext:
291
301
  id_ = UUID(bytes=decoder.decode())
292
302
  return SearchResult(id_, desc)
293
303
 
294
- for type, value_deserializer in self.value_deserializers:
295
- pass
304
+ if (deser := self.value_deserializers.get(concrete_type)):
305
+ return deser(concrete_type, decoder, self)
306
+
307
+ for type_, value_deserializer in self.value_deserializers.items():
308
+ if type_ == concrete_type:
309
+ return value_deserializer(concrete_type, decoder, self)
296
310
 
297
311
  initial_byte = decoder.read(1)[0]
298
312
  major_type = initial_byte >> 5
@@ -305,6 +319,10 @@ class DeserializationContext:
305
319
  subtype = initial_byte & 31
306
320
  pass
307
321
 
322
+ elif major_type == 3:
323
+ return decoder.decode_string(subtype)
324
+ elif major_type == 2:
325
+ return decoder.decode_bytestring(subtype)
308
326
  else:
309
327
  return major_decoders[major_type](decoder, subtype)
310
328
 
@@ -368,7 +386,7 @@ class DeserializationContext:
368
386
  decoder.read_array_start()
369
387
  concrete_type = decoder.read_text_string()
370
388
  actual_type = type_utils.get_type(concrete_type)
371
- for type_, deserializer in self.value_deserializers:
389
+ for type_, deserializer in self.value_deserializers.items():
372
390
  if issubclass(actual_type, type_):
373
391
  return deserializer.deserialize(actual_type, decoder, self)
374
392
 
@@ -404,7 +422,7 @@ class DeserializationContext:
404
422
  raise NotImplementedError("Expected @c key")
405
423
  concrete_type = decoder.read_text_string()
406
424
  actual_type = type_utils.get_type(concrete_type)
407
- for type_, deserializer in self.value_deserializers:
425
+ for type_, deserializer in self.value_deserializers.items():
408
426
  if issubclass(actual_type, type_):
409
427
  return deserializer.deserialize(actual_type, decoder, self)
410
428
  raise NotImplementedError(
@@ -418,7 +436,7 @@ class JsonReceiver(TreeReceiver):
418
436
  def __init__(self, stream, context: DeserializationContext):
419
437
  super().__init__()
420
438
  self._stream = stream
421
- self._decoder = cbor2.CBORDecoder(self._stream)
439
+ self._decoder = CBORDecoder(self._stream)
422
440
  self._context = context
423
441
  self._count = 0
424
442
 
@@ -56,7 +56,7 @@ class RemotingContext:
56
56
  _remoting_thread_local = threading.local()
57
57
  _recipe_factories: Dict[str, Callable[[str, Dict[str, Any]], Recipe]] = {}
58
58
  _value_serializers: Dict[Type, ValueSerializer] = {}
59
- _value_deserializers: List[Tuple[Type, ValueDeserializer]] = []
59
+ _value_deserializers: Dict[str, ValueDeserializer] = {}
60
60
  _object_to_id_map: Dict[int, int] = {}
61
61
  _id_to_object_map: Dict[int, Any] = {}
62
62
 
@@ -126,19 +126,12 @@ class RemotingContext:
126
126
  return self._recipe_factories[recipe_id](recipe_options)
127
127
 
128
128
  @classmethod
129
- def register_value_serializer(cls, serializer: ValueSerializer) -> None:
130
- cls._value_serializers[serializer.__annotations__["value"]] = serializer
129
+ def register_value_serializer(cls, type_: Type, serializer: ValueSerializer) -> None:
130
+ cls._value_serializers[type_] = serializer
131
131
 
132
132
  @classmethod
133
- def register_value_deserializer(cls, deserializer: ValueDeserializer) -> None:
134
- for i in range(len(cls._value_deserializers)):
135
- type_ = cls._value_deserializers[i][0]
136
- if type_ == deserializer.__annotations__["data"]:
137
- cls._value_deserializers[i] = (type_, deserializer)
138
- return
139
- cls._value_deserializers.append(
140
- (deserializer.__annotations__["data"], deserializer)
141
- )
133
+ def register_value_deserializer(cls, type_name: str, deserializer: ValueDeserializer) -> None:
134
+ cls._value_deserializers[type_name] = deserializer
142
135
 
143
136
 
144
137
  class RemotingExecutionContextView(DelegatingExecutionContext):