zrb 1.0.0b8__py3-none-any.whl → 1.0.0b10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81) hide show
  1. zrb/__main__.py +3 -0
  2. zrb/builtin/project/add/fastapp/fastapp_task.py +1 -0
  3. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/.coveragerc +11 -0
  4. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/.gitignore +4 -0
  5. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/column/add_column_task.py +4 -4
  6. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/config.py +5 -0
  7. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/add_entity_task.py +108 -1
  8. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/add_entity_util.py +67 -4
  9. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/app_template/module/my_module/service/my_entity/my_entity_service.py +5 -5
  10. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/app_template/schema/my_entity.py +1 -0
  11. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/app_template/test/my_module/my_entity/test_create_my_entity.py +53 -0
  12. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/app_template/test/my_module/my_entity/test_delete_my_entity.py +62 -0
  13. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/app_template/test/my_module/my_entity/test_read_my_entity.py +65 -0
  14. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/app_template/test/my_module/my_entity/test_update_my_entity.py +61 -0
  15. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/gateway_subroute.py +57 -13
  16. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/input.py +8 -0
  17. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/add_module_util.py +2 -2
  18. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/template/app_template/module/gateway/subroute/my_module.py +6 -1
  19. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/template/module_task_definition.py +10 -6
  20. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/task.py +65 -14
  21. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/task_util.py +106 -0
  22. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/util.py +6 -86
  23. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/base_db_repository.py +27 -11
  24. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/base_service.py +140 -51
  25. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/error.py +15 -0
  26. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/util/parser.py +1 -1
  27. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/config.py +22 -4
  28. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/client/auth_client.py +21 -0
  29. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/migration/versions/3093c7336477_add_auth_tables.py +106 -61
  30. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/migration/versions/8ed025bcc845_create_permissions.py +69 -0
  31. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/migration_metadata.py +3 -4
  32. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/route.py +15 -14
  33. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/permission/permission_service.py +4 -4
  34. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/role/repository/role_db_repository.py +24 -5
  35. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/role/role_service.py +14 -12
  36. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/repository/user_db_repository.py +134 -97
  37. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/repository/user_repository.py +28 -11
  38. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/user_service.py +215 -13
  39. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/user_service_factory.py +30 -2
  40. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/gateway/subroute/auth.py +216 -41
  41. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/gateway/util/auth.py +57 -0
  42. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/requirements.txt +7 -1
  43. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/permission.py +2 -0
  44. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/role.py +13 -12
  45. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/user.py +64 -12
  46. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/test/_util/access_token.py +19 -0
  47. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/test/auth/permission/test_create_permission.py +59 -0
  48. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/test/auth/permission/test_delete_permission.py +68 -0
  49. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/test/auth/permission/test_read_permission.py +71 -0
  50. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/test/auth/permission/test_update_permission.py +66 -0
  51. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/test/auth/test_user_session.py +195 -0
  52. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/test/test_health_and_readiness.py +28 -0
  53. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/test/test_homepage.py +17 -0
  54. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/test/test_not_found_error.py +16 -0
  55. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/test.sh +7 -0
  56. zrb/task/base_task.py +10 -10
  57. zrb/task/cmd_task.py +2 -5
  58. zrb/util/cmd/command.py +39 -48
  59. zrb/util/codemod/modification_mode.py +3 -0
  60. zrb/util/codemod/modify_class.py +58 -0
  61. zrb/util/codemod/modify_class_parent.py +68 -0
  62. zrb/util/codemod/modify_class_property.py +128 -0
  63. zrb/util/codemod/modify_dict.py +75 -0
  64. zrb/util/codemod/modify_function.py +65 -0
  65. zrb/util/codemod/modify_function_call.py +68 -0
  66. zrb/util/codemod/modify_method.py +88 -0
  67. zrb/util/codemod/{prepend_code_to_module.py → modify_module.py} +2 -3
  68. zrb/util/file.py +3 -2
  69. {zrb-1.0.0b8.dist-info → zrb-1.0.0b10.dist-info}/METADATA +2 -1
  70. {zrb-1.0.0b8.dist-info → zrb-1.0.0b10.dist-info}/RECORD +72 -55
  71. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/migrate.py +0 -3
  72. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/session.py +0 -48
  73. zrb/util/codemod/append_code_to_class.py +0 -35
  74. zrb/util/codemod/append_code_to_function.py +0 -38
  75. zrb/util/codemod/append_code_to_method.py +0 -55
  76. zrb/util/codemod/append_key_to_dict.py +0 -51
  77. zrb/util/codemod/append_param_to_function_call.py +0 -39
  78. zrb/util/codemod/prepend_parent_to_class.py +0 -38
  79. zrb/util/codemod/prepend_property_to_class.py +0 -55
  80. {zrb-1.0.0b8.dist-info → zrb-1.0.0b10.dist-info}/WHEEL +0 -0
  81. {zrb-1.0.0b8.dist-info → zrb-1.0.0b10.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,16 @@
1
+ import os
2
+
3
+ from fastapi.testclient import TestClient
4
+ from my_app_name.main import app
5
+ from my_app_name.module.gateway.util.view import render_error
6
+
7
+
8
def test_not_found_error():
    """Requesting an unrouted path returns 404 with the rendered "Not found" page."""
    missing_path = "/holly-grail/not-found/philosopher-stone/not-found/hope/inexist"
    http_client = TestClient(app, base_url="http://localhost")
    response = http_client.get(missing_path)
    assert response.status_code == 404
    # The body must be exactly what the gateway's error view renders for a 404.
    expected_page = render_error(error_message="Not found", status_code=404)
    assert response.text == expected_page.body.decode("utf-8")
@@ -0,0 +1,7 @@
1
+ pytest -vv \
2
+ --cov=my_app_name \
3
+ --cov-config=.coveragerc \
4
+ --cov-report=html \
5
+ --cov-report=term \
6
+ --cov-report=term-missing \
7
+ --ignore=_zrb
zrb/task/base_task.py CHANGED
@@ -295,16 +295,16 @@ class BaseTask(AnyTask):
295
295
  async def exec_root_tasks(self, session: AnySession):
296
296
  session.set_main_task(self)
297
297
  session.state_logger.write(session.as_state_log())
298
- log_state = asyncio.create_task(self._log_session_state(session))
299
- root_tasks = [
300
- task
301
- for task in session.get_root_tasks(self)
302
- if session.is_allowed_to_run(task)
303
- ]
304
- root_task_coros = [
305
- run_async(root_task.exec_chain(session)) for root_task in root_tasks
306
- ]
307
298
  try:
299
+ log_state = asyncio.create_task(self._log_session_state(session))
300
+ root_tasks = [
301
+ task
302
+ for task in session.get_root_tasks(self)
303
+ if session.is_allowed_to_run(task)
304
+ ]
305
+ root_task_coros = [
306
+ run_async(root_task.exec_chain(session)) for root_task in root_tasks
307
+ ]
308
308
  await asyncio.gather(*root_task_coros)
309
309
  await session.wait_deferred()
310
310
  session.terminate()
@@ -312,7 +312,7 @@ class BaseTask(AnyTask):
312
312
  return session.final_result
313
313
  except IndexError:
314
314
  return None
315
- except asyncio.CancelledError:
315
+ except (asyncio.CancelledError, KeyboardInterrupt):
316
316
  ctx = self.get_ctx(session)
317
317
  ctx.log_info("Session terminated")
318
318
  finally:
zrb/task/cmd_task.py CHANGED
@@ -141,11 +141,8 @@ class CmdTask(BaseTask):
141
141
  max_error_line=self._max_error_line,
142
142
  )
143
143
  # Check for errors
144
- if return_code != 0:
145
- ctx.log_error(f"Exit status: {return_code}")
146
- raise Exception(
147
- f"Process {self._name} exited ({return_code}): {cmd_result.error}"
148
- )
144
+ if return_code > 0:
145
+ raise Exception(f"Process {self._name} exited ({return_code})")
149
146
  ctx.log_info(f"Exit status: {return_code}")
150
147
  return cmd_result
151
148
 
zrb/util/cmd/command.py CHANGED
@@ -1,26 +1,12 @@
1
1
  import asyncio
2
2
  import os
3
3
  import re
4
- import signal
5
4
  import sys
6
5
  from collections.abc import Callable
7
6
 
8
- from zrb.cmd.cmd_result import CmdResult
9
-
10
- _RUNNING_PROCESSES = []
11
-
12
-
13
- def _cleanup_processes():
14
- """Terminate all running subprocesses on exit."""
15
- for process in _RUNNING_PROCESSES:
16
- if process.returncode is None:
17
- process.terminate()
18
- process.kill()
19
-
7
+ import psutil
20
8
 
21
- # Register a single cleanup function once
22
- signal.signal(signal.SIGINT, lambda sig, frame: _cleanup_processes())
23
- signal.signal(signal.SIGTERM, lambda sig, frame: _cleanup_processes())
9
+ from zrb.cmd.cmd_result import CmdResult
24
10
 
25
11
 
26
12
  def check_unrecommended_commands(cmd_script: str) -> dict[str, str]:
@@ -81,35 +67,40 @@ async def run_command(
81
67
 
82
68
  actual_print_method = print_method if print_method is not None else print
83
69
  cmd_process = None
84
- try:
85
- if cwd is None:
86
- cwd = os.getcwd()
87
- if env_map is None:
88
- env_map = os.environ
89
- cmd_process = await asyncio.create_subprocess_exec(
90
- *cmd,
91
- cwd=cwd,
92
- stdin=sys.stdin if sys.stdin.isatty() else None,
93
- stdout=asyncio.subprocess.PIPE,
94
- stderr=asyncio.subprocess.PIPE,
95
- env=env_map,
96
- bufsize=0,
97
- )
98
- if register_pid_method is not None:
99
- register_pid_method(cmd_process.pid)
100
- stdout_task = asyncio.create_task(
101
- __read_stream(cmd_process.stdout, actual_print_method, max_output_line)
102
- )
103
- stderr_task = asyncio.create_task(
104
- __read_stream(cmd_process.stderr, actual_print_method, max_error_line)
105
- )
106
- # Wait for process to complete and gather stdout/stderr
107
- return_code = await cmd_process.wait()
108
- stdout = await stdout_task
109
- stderr = await stderr_task
110
- return CmdResult(stdout, stderr), return_code
111
- finally:
112
- if cmd_process in _RUNNING_PROCESSES:
113
- _RUNNING_PROCESSES.remove(cmd_process)
114
- if cmd_process is not None and cmd_process.returncode is None:
115
- cmd_process.terminate()
70
+ if cwd is None:
71
+ cwd = os.getcwd()
72
+ if env_map is None:
73
+ env_map = os.environ
74
+ cmd_process = await asyncio.create_subprocess_exec(
75
+ *cmd,
76
+ cwd=cwd,
77
+ stdin=sys.stdin if sys.stdin.isatty() else None,
78
+ stdout=asyncio.subprocess.PIPE,
79
+ stderr=asyncio.subprocess.PIPE,
80
+ env=env_map,
81
+ bufsize=0,
82
+ )
83
+ if register_pid_method is not None:
84
+ register_pid_method(cmd_process.pid)
85
+ stdout_task = asyncio.create_task(
86
+ __read_stream(cmd_process.stdout, actual_print_method, max_output_line)
87
+ )
88
+ stderr_task = asyncio.create_task(
89
+ __read_stream(cmd_process.stderr, actual_print_method, max_error_line)
90
+ )
91
+ # Wait for process to complete and gather stdout/stderr
92
+ return_code = await cmd_process.wait()
93
+ stdout = await stdout_task
94
+ stderr = await stderr_task
95
+ return CmdResult(stdout, stderr), return_code
96
+
97
+
98
def kill_pid(pid: int, print_method: Callable[..., None] | None = None):
    """Request termination of process ``pid`` and all of its descendants.

    Descendants are terminated first, then the process itself, each through
    psutil's ``Process.terminate()``.

    Args:
        pid: Id of the root process to terminate.
        print_method: Callable used to report progress; defaults to ``print``.

    Processes that have already exited are skipped: a process that is gone
    needs no further action, and a vanished child must not prevent the
    remaining processes from being terminated.
    """
    actual_print_method = print_method if print_method is not None else print
    try:
        parent = psutil.Process(pid)
        children = parent.children(recursive=True)
    except psutil.NoSuchProcess:
        # The target exited before we could inspect it; nothing left to kill.
        return
    for child in children:
        actual_print_method(f"Killing child process {child.pid}")
        try:
            child.terminate()
        except psutil.NoSuchProcess:
            # Child exited between enumeration and termination.
            continue
    actual_print_method(f"Killing process {pid}")
    try:
        parent.terminate()
    except psutil.NoSuchProcess:
        # Parent exited on its own (e.g. after its children were terminated).
        pass
@@ -0,0 +1,3 @@
1
# Modes understood by the codemod helpers in zrb.util.codemod.
# Insert the new content before the existing content.
PREPEND = 0
# Insert the new content after the existing content.
APPEND = 1
# Discard the existing content and use the new content instead.
REPLACE = 2
@@ -0,0 +1,58 @@
1
+ import libcst as cst
2
+
3
+ from zrb.util.codemod.modification_mode import APPEND, PREPEND, REPLACE
4
+
5
+
6
def replace_class_code(original_code: str, class_name: str, new_code: str) -> str:
    """Return ``original_code`` with the body of ``class_name`` replaced by ``new_code``.

    Raises:
        ValueError: If no class named ``class_name`` exists in ``original_code``.
    """
    return _modify_code(
        original_code=original_code,
        class_name=class_name,
        new_code=new_code,
        mode=REPLACE,
    )
8
+
9
+
10
def prepend_code_to_class(original_code: str, class_name: str, new_code: str) -> str:
    """Return ``original_code`` with ``new_code`` inserted at the top of ``class_name``'s body.

    Raises:
        ValueError: If no class named ``class_name`` exists in ``original_code``.
    """
    return _modify_code(
        original_code=original_code,
        class_name=class_name,
        new_code=new_code,
        mode=PREPEND,
    )
12
+
13
+
14
def append_code_to_class(original_code: str, class_name: str, new_code: str) -> str:
    """Return ``original_code`` with ``new_code`` added at the end of ``class_name``'s body.

    Raises:
        ValueError: If no class named ``class_name`` exists in ``original_code``.
    """
    return _modify_code(
        original_code=original_code,
        class_name=class_name,
        new_code=new_code,
        mode=APPEND,
    )
16
+
17
+
18
def _modify_code(original_code: str, class_name: str, new_code: str, mode: int) -> str:
    """Rewrite the body of ``class_name`` inside ``original_code``.

    Args:
        original_code: Python source containing the class.
        class_name: Name of the class whose body is changed.
        new_code: Source of the statements to insert.
        mode: One of PREPEND, APPEND or REPLACE.

    Returns:
        The modified source code.

    Raises:
        ValueError: If the class is not present in ``original_code``.
    """
    source_tree = cst.parse_module(original_code)
    modifier = _ClassCodeModifier(class_name, new_code, mode)
    rewritten_tree = source_tree.visit(modifier)
    if modifier.class_found:
        return rewritten_tree.code
    raise ValueError(f"Class {class_name} not found in the provided code.")
30
+
31
+
32
class _ClassCodeModifier(cst.CSTTransformer):
    """Transformer that prepends, appends or replaces statements in a class body.

    ``class_found`` is set to True once a class named ``class_name`` has been
    visited, regardless of ``mode``.
    """

    def __init__(self, class_name: str, new_code: str, mode: int):
        self.class_name = class_name
        # Parsed as a module so that new_code may contain several statements.
        self.new_code = cst.parse_module(new_code).body
        self.class_found = False
        self.mode = mode

    def leave_ClassDef(
        self, original_node: cst.ClassDef, updated_node: cst.ClassDef
    ) -> cst.ClassDef:
        # Classes other than the target pass through untouched.
        if original_node.name.value != self.class_name:
            return updated_node
        self.class_found = True
        injected = tuple(self.new_code)
        current = updated_node.body.body
        if self.mode == REPLACE:
            statements = injected
        elif self.mode == PREPEND:
            statements = injected + current
        elif self.mode == APPEND:
            statements = current + injected
        else:
            # Unknown mode: leave the class as it is.
            return updated_node
        rebuilt_body = updated_node.body.with_changes(body=statements)
        return updated_node.with_changes(body=rebuilt_body)
@@ -0,0 +1,68 @@
1
+ import libcst as cst
2
+
3
+ from zrb.util.codemod.modification_mode import APPEND, PREPEND, REPLACE
4
+
5
+
6
def replace_parent_class(
    original_code: str, class_name: str, parent_class_name: str
) -> str:
    """Return ``original_code`` with ``parent_class_name`` as the only base of ``class_name``.

    Raises:
        ValueError: If no class named ``class_name`` exists in ``original_code``.
    """
    return _modify_parent_class(
        original_code=original_code,
        class_name=class_name,
        parent_class_name=parent_class_name,
        mode=REPLACE,
    )
10
+
11
+
12
def append_parent_class(
    original_code: str, class_name: str, parent_class_name: str
) -> str:
    """Return ``original_code`` with ``parent_class_name`` added as the last base of ``class_name``.

    Raises:
        ValueError: If no class named ``class_name`` exists in ``original_code``.
    """
    return _modify_parent_class(
        original_code=original_code,
        class_name=class_name,
        parent_class_name=parent_class_name,
        mode=APPEND,
    )
16
+
17
+
18
def prepend_parent_class(
    original_code: str, class_name: str, parent_class_name: str
) -> str:
    """Return ``original_code`` with ``parent_class_name`` added as the first base of ``class_name``.

    Raises:
        ValueError: If no class named ``class_name`` exists in ``original_code``.
    """
    return _modify_parent_class(
        original_code=original_code,
        class_name=class_name,
        parent_class_name=parent_class_name,
        mode=PREPEND,
    )
22
+
23
+
24
def _modify_parent_class(
    original_code: str, class_name: str, parent_class_name: str, mode: int
) -> str:
    """Change the base-class list of ``class_name`` inside ``original_code``.

    Args:
        original_code: Python source containing the class.
        class_name: Name of the class whose bases are changed.
        parent_class_name: Name of the base class to add.
        mode: One of PREPEND, APPEND or REPLACE.

    Returns:
        The modified source code.

    Raises:
        ValueError: If the class is not present in ``original_code``.
    """
    source_tree = cst.parse_module(original_code)
    adder = _ParentClassAdder(class_name, parent_class_name, mode)
    rewritten_tree = source_tree.visit(adder)
    if adder.class_found:
        return rewritten_tree.code
    raise ValueError(f"Class {class_name} not found in the provided code.")
38
+
39
+
40
class _ParentClassAdder(cst.CSTTransformer):
    """Transformer that prepends, appends or replaces a class's base classes.

    ``class_found`` is set to True once a class named ``class_name`` has been
    visited, regardless of ``mode``.
    """

    def __init__(self, class_name: str, parent_class_name: str, mode: int):
        self.class_name = class_name
        self.parent_class_name = parent_class_name
        self.class_found = False
        self.mode = mode

    def leave_ClassDef(
        self, original_node: cst.ClassDef, updated_node: cst.ClassDef
    ) -> cst.ClassDef:
        # Classes other than the target pass through untouched.
        if original_node.name.value != self.class_name:
            return updated_node
        self.class_found = True
        if self.mode not in (REPLACE, PREPEND, APPEND):
            # Unknown mode: leave the base list as it is.
            return updated_node
        parent_arg = cst.Arg(value=cst.Name(self.parent_class_name))
        if self.mode == REPLACE:
            bases = (parent_arg,)
        elif self.mode == PREPEND:
            bases = (parent_arg, *updated_node.bases)
        else:
            bases = (*updated_node.bases, parent_arg)
        return updated_node.with_changes(bases=bases)
@@ -0,0 +1,128 @@
1
+ import libcst as cst
2
+
3
+ from zrb.util.codemod.modification_mode import APPEND, PREPEND
4
+
5
+
6
def append_property_to_class(
    original_code: str,
    class_name: str,
    property_name: str,
    annotation: str,
    default_value: str,
) -> str:
    """Add ``property_name: annotation = default_value`` to ``class_name`` in APPEND mode.

    ``annotation`` and ``default_value`` are Python expressions given as source text.

    Raises:
        ValueError: If no class named ``class_name`` exists in ``original_code``.
    """
    return _modify_class_property(
        original_code=original_code,
        class_name=class_name,
        property_name=property_name,
        annotation=annotation,
        default_value=default_value,
        mode=APPEND,
    )
16
+
17
+
18
def prepend_property_to_class(
    original_code: str,
    class_name: str,
    property_name: str,
    annotation: str,
    default_value: str,
) -> str:
    """Add ``property_name: annotation = default_value`` to ``class_name`` in PREPEND mode.

    ``annotation`` and ``default_value`` are Python expressions given as source text.

    Raises:
        ValueError: If no class named ``class_name`` exists in ``original_code``.
    """
    return _modify_class_property(
        original_code=original_code,
        class_name=class_name,
        property_name=property_name,
        annotation=annotation,
        default_value=default_value,
        mode=PREPEND,
    )
28
+
29
+
30
def _modify_class_property(
    original_code: str,
    class_name: str,
    property_name: str,
    annotation: str,
    default_value: str,
    mode: int,
) -> str:
    """Insert an annotated attribute with a default value into ``class_name``.

    Args:
        original_code: Python source containing the class.
        class_name: Name of the class receiving the attribute.
        property_name: Name of the new attribute.
        annotation: Source text of the attribute's type annotation.
        default_value: Source text of the attribute's default value.
        mode: PREPEND or APPEND.

    Returns:
        The modified source code.

    Raises:
        ValueError: If the class is not present in ``original_code``.
    """
    source_tree = cst.parse_module(original_code)
    modifier = _ClassPropertyModifier(
        class_name, property_name, annotation, default_value, mode
    )
    rewritten_tree = source_tree.visit(modifier)
    if modifier.class_found:
        return rewritten_tree.code
    raise ValueError(f"Class {class_name} not found in the provided code.")
51
+
52
+
53
class _ClassPropertyModifier(cst.CSTTransformer):
    """Transformer that inserts ``name: annotation = default`` into a class body.

    In PREPEND mode the attribute becomes the first statement of the class.
    In APPEND mode it is placed after the last existing attribute assignment;
    if there is none, before the first method; otherwise at the end of the body.

    ``class_found`` is set to True once a class named ``class_name`` has been
    visited, regardless of ``mode``.
    """

    def __init__(
        self,
        class_name: str,
        property_name: str,
        annotation: str,
        default_value: str,
        mode: int,
    ):
        self.class_name = class_name
        self.property_name = property_name
        # Both annotation and default value arrive as source text.
        self.annotation = cst.Annotation(cst.parse_expression(annotation))
        self.default_value = cst.parse_expression(default_value)
        self.class_found = False
        self.mode = mode

    def leave_ClassDef(
        self, original_node: cst.ClassDef, updated_node: cst.ClassDef
    ) -> cst.ClassDef:
        # Classes other than the target pass through untouched.
        if original_node.name.value != self.class_name:
            return updated_node
        self.class_found = True
        # Create the annotated property with a default value
        new_property = cst.SimpleStatementLine(
            body=[
                cst.AnnAssign(
                    target=cst.Name(self.property_name),
                    annotation=self.annotation,
                    value=self.default_value,
                )
            ]
        )
        statements = tuple(updated_node.body.body)
        if self.mode == PREPEND:
            insert_at = 0
        elif self.mode == APPEND:
            insert_at = self._append_position(statements)
        else:
            # Unknown mode: leave the class as it is.
            return updated_node
        new_statements = (
            statements[:insert_at] + (new_property,) + statements[insert_at:]
        )
        new_body = self._with_statements(updated_node.body, new_statements)
        return updated_node.with_changes(body=new_body)

    @staticmethod
    def _append_position(statements: tuple) -> int:
        """Return the index at which an appended attribute should be inserted."""
        last_property = None
        first_method = None
        for position, stmt in enumerate(statements):
            if isinstance(stmt, cst.SimpleStatementLine) and isinstance(
                stmt.body[0], (cst.AnnAssign, cst.Assign)
            ):
                last_property = position
            elif first_method is None and isinstance(stmt, cst.FunctionDef):
                first_method = position
        if last_property is not None:
            # Class has properties: go right after the last one.
            return last_property + 1
        if first_method is not None:
            # Class doesn't have properties but has methods: go before them.
            return first_method
        # Class has neither properties nor methods: add at the bottom.
        return len(statements)

    @staticmethod
    def _with_statements(body: cst.BaseSuite, statements: tuple) -> cst.BaseSuite:
        """Return a class body holding ``statements``.

        An existing indented block is updated in place so that its header
        comment, footer and indentation survive the rewrite; any other kind of
        suite is replaced by a fresh indented block.
        """
        if isinstance(body, cst.IndentedBlock):
            return body.with_changes(body=statements)
        return cst.IndentedBlock(body=statements)
@@ -0,0 +1,75 @@
1
+ import libcst as cst
2
+
3
+ from zrb.util.codemod.modification_mode import APPEND, PREPEND
4
+
5
+
6
def prepend_key_to_dict(
    original_code: str, dictionary_name: str, new_key: str, new_value: str
) -> str:
    """Return ``original_code`` with ``new_key: new_value`` added as the first entry of a dict.

    Key and value are both emitted as string literals.

    Raises:
        ValueError: If ``dictionary_name`` is not assigned a dict literal in ``original_code``.
    """
    return _modify_dict(
        original_code=original_code,
        dictionary_name=dictionary_name,
        new_key=new_key,
        new_value=new_value,
        mode=PREPEND,
    )
10
+
11
+
12
def append_key_to_dict(
    original_code: str, dictionary_name: str, new_key: str, new_value: str
) -> str:
    """Return ``original_code`` with ``new_key: new_value`` added as the last entry of a dict.

    Key and value are both emitted as string literals.

    Raises:
        ValueError: If ``dictionary_name`` is not assigned a dict literal in ``original_code``.
    """
    return _modify_dict(
        original_code=original_code,
        dictionary_name=dictionary_name,
        new_key=new_key,
        new_value=new_value,
        mode=APPEND,
    )
16
+
17
+
18
def _modify_dict(
    original_code: str, dictionary_name: str, new_key: str, new_value: str, mode: int
) -> str:
    """Add a string key/value pair to a dict literal assigned to ``dictionary_name``.

    Args:
        original_code: Python source containing the assignment.
        dictionary_name: Name the dict literal is assigned to.
        new_key: Text of the key to add.
        new_value: Text of the value to add.
        mode: PREPEND or APPEND.

    Returns:
        The modified source code.

    Raises:
        ValueError: If no matching dict assignment is found.
    """
    source_tree = cst.parse_module(original_code)
    modifier = _DictionaryModifier(dictionary_name, new_key, new_value, mode)
    rewritten_tree = source_tree.visit(modifier)
    if modifier.found:
        return rewritten_tree.code
    raise ValueError(f"Dictionary {dictionary_name} not found in the provided code.")
34
+
35
+
36
class _DictionaryModifier(cst.CSTTransformer):
    """Transformer that adds a string key/value pair to a named dict literal.

    Only plain assignments whose first target is the name ``dictionary_name``
    and whose value is a dict literal are modified. ``found`` is set to True
    once such an assignment has been visited, regardless of ``mode``.
    """

    def __init__(self, dictionary_name: str, new_key: str, new_value: str, mode: int):
        self.dictionary_name = dictionary_name
        self.new_key = new_key
        self.new_value = new_value
        self.found = False
        self.mode = mode

    @staticmethod
    def _string_literal(text: str) -> str:
        """Return ``text`` as source for a double-quoted Python string literal.

        Quotes, backslashes and control characters are escaped so the
        generated code stays syntactically valid and evaluates to ``text``.
        JSON string escapes are a subset of Python's, so the JSON encoding can
        be used directly as a Python literal.
        """
        import json

        return json.dumps(str(text), ensure_ascii=False)

    def leave_Assign(
        self, original_node: cst.Assign, updated_node: cst.Assign
    ) -> cst.Assign:
        # Only the first assignment target is inspected.
        target = updated_node.targets[0]
        names_dictionary = (
            isinstance(target.target, cst.Name)
            and target.target.value == self.dictionary_name
        )
        # Require a dict literal on the right-hand side (e.g. my_dict = {...}).
        if not names_dictionary or not isinstance(updated_node.value, cst.Dict):
            return updated_node
        self.found = True
        if self.mode not in (PREPEND, APPEND):
            # Unknown mode: leave the dictionary as it is.
            return updated_node
        new_entry = cst.DictElement(
            key=cst.SimpleString(self._string_literal(self.new_key)),
            value=cst.SimpleString(self._string_literal(self.new_value)),
        )
        current_entries = tuple(updated_node.value.elements)
        if self.mode == PREPEND:
            new_entries = (new_entry,) + current_entries
        else:
            new_entries = current_entries + (new_entry,)
        new_dict = updated_node.value.with_changes(elements=new_entries)
        return updated_node.with_changes(value=new_dict)
@@ -0,0 +1,65 @@
1
+ import libcst as cst
2
+
3
+ from zrb.util.codemod.modification_mode import APPEND, PREPEND, REPLACE
4
+
5
+
6
def replace_function_code(original_code: str, function_name: str, new_code: str) -> str:
    """Return ``original_code`` with the body of ``function_name`` replaced by ``new_code``.

    Raises:
        ValueError: If no function named ``function_name`` exists in ``original_code``.
    """
    return _modify_function(
        original_code=original_code,
        function_name=function_name,
        new_code=new_code,
        mode=REPLACE,
    )
8
+
9
+
10
def prepend_code_to_function(
    original_code: str, function_name: str, new_code: str
) -> str:
    """Return ``original_code`` with ``new_code`` inserted at the top of ``function_name``'s body.

    Raises:
        ValueError: If no function named ``function_name`` exists in ``original_code``.
    """
    return _modify_function(
        original_code=original_code,
        function_name=function_name,
        new_code=new_code,
        mode=PREPEND,
    )
14
+
15
+
16
def append_code_to_function(
    original_code: str, function_name: str, new_code: str
) -> str:
    """Return ``original_code`` with ``new_code`` added at the end of ``function_name``'s body.

    Raises:
        ValueError: If no function named ``function_name`` exists in ``original_code``.
    """
    return _modify_function(
        original_code=original_code,
        function_name=function_name,
        new_code=new_code,
        mode=APPEND,
    )
20
+
21
+
22
def _modify_function(
    original_code: str, function_name: str, new_code: str, mode: int
) -> str:
    """Rewrite the body of ``function_name`` inside ``original_code``.

    Args:
        original_code: Python source containing the function.
        function_name: Name of the function whose body is changed.
        new_code: Source of the statements to insert.
        mode: One of PREPEND, APPEND or REPLACE.

    Returns:
        The modified source code.

    Raises:
        ValueError: If the function is not present in ``original_code``.
    """
    source_tree = cst.parse_module(original_code)
    modifier = _FunctionCodeModifier(function_name, new_code, mode)
    rewritten_tree = source_tree.visit(modifier)
    if modifier.function_found:
        return rewritten_tree.code
    raise ValueError(f"Function {function_name} not found.")
36
+
37
+
38
class _FunctionCodeModifier(cst.CSTTransformer):
    """Transformer that prepends, appends or replaces statements in a function body.

    Every function definition named ``function_name`` that the visitor
    reaches is rewritten; the check is by name only. ``function_found`` is set
    to True once such a definition has been visited, regardless of ``mode``.
    """

    def __init__(self, function_name: str, new_code: str, mode: int):
        self.function_name = function_name
        # Use parse_module to handle multiple statements
        self.new_code = cst.parse_module(new_code).body
        self.function_found = False
        self.mode = mode

    def leave_FunctionDef(
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.FunctionDef:
        # Functions other than the target pass through untouched.
        if original_node.name.value != self.function_name:
            return updated_node
        self.function_found = True
        injected = tuple(self.new_code)
        current = updated_node.body.body
        if self.mode == REPLACE:
            statements = injected
        elif self.mode == PREPEND:
            statements = injected + current
        elif self.mode == APPEND:
            statements = current + injected
        else:
            # Unknown mode: leave the function as it is.
            return updated_node
        rebuilt_body = updated_node.body.with_changes(body=statements)
        return updated_node.with_changes(body=rebuilt_body)
@@ -0,0 +1,68 @@
1
+ import libcst as cst
2
+
3
+ from zrb.util.codemod.modification_mode import APPEND, PREPEND, REPLACE
4
+
5
+
6
def replace_function_call_param(
    original_code: str, function_name: str, new_param: str
) -> str:
    """Return ``original_code`` with the arguments of calls to ``function_name`` replaced by ``new_param``.

    ``new_param`` is a Python expression given as source text.

    Raises:
        ValueError: If no call to ``function_name`` exists in ``original_code``.
    """
    return _modify_function_call(
        original_code=original_code,
        function_name=function_name,
        new_param=new_param,
        mode=REPLACE,
    )
10
+
11
+
12
def prepend_param_to_function_call(
    original_code: str, function_name: str, new_param: str
) -> str:
    """Return ``original_code`` with ``new_param`` added as the first argument of calls to ``function_name``.

    ``new_param`` is a Python expression given as source text.

    Raises:
        ValueError: If no call to ``function_name`` exists in ``original_code``.
    """
    return _modify_function_call(
        original_code=original_code,
        function_name=function_name,
        new_param=new_param,
        mode=PREPEND,
    )
16
+
17
+
18
def append_param_to_function_call(
    original_code: str, function_name: str, new_param: str
) -> str:
    """Return ``original_code`` with ``new_param`` added as the last argument of calls to ``function_name``.

    ``new_param`` is a Python expression given as source text.

    Raises:
        ValueError: If no call to ``function_name`` exists in ``original_code``.
    """
    return _modify_function_call(
        original_code=original_code,
        function_name=function_name,
        new_param=new_param,
        mode=APPEND,
    )
22
+
23
+
24
def _modify_function_call(
    original_code: str, function_name: str, new_param: str, mode: int
) -> str:
    """Change the argument list of calls to ``function_name`` inside ``original_code``.

    Args:
        original_code: Python source containing the call.
        function_name: Bare name of the called function.
        new_param: Source text of the argument expression to add.
        mode: One of PREPEND, APPEND or REPLACE.

    Returns:
        The modified source code.

    Raises:
        ValueError: If no call was modified.
    """
    source_tree = cst.parse_module(original_code)
    modifier = _FunctionCallParamModifier(function_name, new_param, mode)
    rewritten_tree = source_tree.visit(modifier)
    if modifier.param_added:
        return rewritten_tree.code
    raise ValueError(
        f"Function call to {function_name} not found in the provided code."
    )
40
+
41
+
42
class _FunctionCallParamModifier(cst.CSTTransformer):
    """Transformer that prepends, appends or replaces arguments of a function call.

    Only calls whose callee is the bare name ``func_name`` are changed.
    ``param_added`` is set to True once a call has actually been modified.
    """

    def __init__(self, func_name: str, new_param: str, mode: int):
        self.func_name = func_name
        # Parsing up front guarantees the new argument is a valid expression.
        self.new_param = cst.parse_expression(new_param)
        self.param_added = False
        self.mode = mode

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
        callee = original_node.func
        # Calls to anything other than the target name pass through untouched.
        if not (isinstance(callee, cst.Name) and callee.value == self.func_name):
            return updated_node
        if self.mode not in (REPLACE, PREPEND, APPEND):
            # Unknown mode: leave the call as it is.
            return updated_node
        extra_arg = (cst.Arg(value=self.new_param),)
        if self.mode == REPLACE:
            call_args = extra_arg
        elif self.mode == PREPEND:
            call_args = extra_arg + updated_node.args
        else:
            call_args = updated_node.args + extra_arg
        self.param_added = True
        return updated_node.with_changes(args=call_args)