pytest-dsl 0.8.0__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,593 @@
1
+ import xmlrpc.server
2
+ from functools import partial
3
+ import inspect
4
+ import json
5
+ import sys
6
+ import traceback
7
+ import signal
8
+ import atexit
9
+ import threading
10
+ import time
11
+ from typing import Dict, Any, Callable, List
12
+
13
+ from pytest_dsl.core.keyword_manager import keyword_manager
14
+ from pytest_dsl.remote.hook_manager import hook_manager, HookType
15
+ # 导入变量桥接模块,确保hook被注册
16
+ from pytest_dsl.remote import variable_bridge
17
+
18
+ class RemoteKeywordServer:
19
+ """远程关键字服务器,提供关键字的远程调用能力"""
20
+
21
+ def __init__(self, host='localhost', port=8270, api_key=None):
22
+ self.host = host
23
+ self.port = port
24
+ self.server = None
25
+ self.api_key = api_key
26
+
27
+ # 变量存储
28
+ self.shared_variables = {} # 存储共享变量
29
+
30
+ # 注册内置关键字
31
+ self._register_builtin_keywords()
32
+
33
+ # 注册关闭信号处理
34
+ self._register_shutdown_handlers()
35
+
36
+ def _register_builtin_keywords(self):
37
+ """注册所有内置关键字,复用本地模式的加载逻辑"""
38
+ from pytest_dsl.core.plugin_discovery import load_all_plugins, scan_local_keywords
39
+
40
+ # 0. 首先加载内置关键字模块(确保内置关键字被注册)
41
+ print("正在加载内置关键字...")
42
+ try:
43
+ import pytest_dsl.keywords
44
+ print("内置关键字模块加载完成")
45
+ except ImportError as e:
46
+ print(f"加载内置关键字模块失败: {e}")
47
+
48
+ # 1. 加载所有已安装的关键字插件(与本地模式一致)
49
+ print("正在加载第三方关键字插件...")
50
+ load_all_plugins()
51
+
52
+ # 2. 扫描本地keywords目录中的关键字(与本地模式一致)
53
+ print("正在扫描本地关键字...")
54
+ scan_local_keywords()
55
+
56
+ print(f"关键字加载完成,可用关键字数量: {len(keyword_manager._keywords)}")
57
+
58
+ def _register_shutdown_handlers(self):
59
+ """注册关闭信号处理器"""
60
+ def shutdown_handler(signum, frame):
61
+ if hasattr(self, '_shutdown_called') and self._shutdown_called:
62
+ return # 避免重复处理信号
63
+ print(f"接收到信号 {signum},正在关闭服务器...")
64
+
65
+ # 在新线程中执行关闭逻辑,避免阻塞信号处理器
66
+ shutdown_thread = threading.Thread(target=self._shutdown_in_thread, daemon=True)
67
+ shutdown_thread.start()
68
+
69
+ # 保存信号处理器引用
70
+ self._shutdown_handler = shutdown_handler
71
+
72
+ # 只在主线程中注册信号处理器
73
+ try:
74
+ signal.signal(signal.SIGINT, shutdown_handler)
75
+ signal.signal(signal.SIGTERM, shutdown_handler)
76
+ except ValueError:
77
+ # 如果不在主线程中,跳过信号处理器注册
78
+ print("警告: 无法在非主线程中注册信号处理器")
79
+
80
+ # 注册atexit处理器
81
+ atexit.register(self.shutdown)
82
+
83
+ def _shutdown_in_thread(self):
84
+ """在独立线程中执行关闭逻辑"""
85
+ if hasattr(self, '_shutdown_called') and self._shutdown_called:
86
+ return # 避免重复调用
87
+ self._shutdown_called = True
88
+
89
+ print("正在执行服务器关闭流程...")
90
+
91
+ # 执行关闭hook
92
+ try:
93
+ hook_manager.execute_hooks(
94
+ HookType.SERVER_SHUTDOWN,
95
+ server=self,
96
+ shared_variables=self.shared_variables
97
+ )
98
+ except Exception as e:
99
+ print(f"执行关闭hook时出错: {e}")
100
+
101
+ # 关闭XML-RPC服务器
102
+ if self.server:
103
+ try:
104
+ self.server.shutdown()
105
+ self.server.server_close()
106
+ print("服务器已关闭")
107
+ except Exception as e:
108
+ print(f"关闭服务器时出错: {e}")
109
+
110
+ print("服务器关闭完成")
111
+
112
+ # 给主线程一点时间完成清理
113
+ time.sleep(0.1)
114
+
115
+ # 强制退出
116
+ import os
117
+ os._exit(0)
118
+
119
+ def start(self):
120
+ """启动远程关键字服务器"""
121
+ try:
122
+ self.server = xmlrpc.server.SimpleXMLRPCServer((self.host, self.port), allow_none=True)
123
+ except OSError as e:
124
+ if "Address already in use" in str(e):
125
+ print(f"端口 {self.port} 已被占用,请使用其他端口或关闭占用该端口的进程")
126
+ return
127
+ else:
128
+ raise
129
+
130
+ # 执行启动前的hook
131
+ hook_manager.execute_hooks(
132
+ HookType.SERVER_STARTUP,
133
+ server=self,
134
+ shared_variables=self.shared_variables,
135
+ host=self.host,
136
+ port=self.port
137
+ )
138
+ self.server.register_introspection_functions()
139
+
140
+ # 注册核心方法
141
+ self.server.register_function(self.get_keyword_names)
142
+ self.server.register_function(self.run_keyword)
143
+ self.server.register_function(self.get_keyword_arguments)
144
+ self.server.register_function(self.get_keyword_documentation)
145
+ self.server.register_function(self.authenticate)
146
+
147
+ # 注册变量同步方法
148
+ self.server.register_function(self.sync_variables_from_client)
149
+ self.server.register_function(self.get_variables_for_client)
150
+ self.server.register_function(self.set_shared_variable)
151
+ self.server.register_function(self.get_shared_variable)
152
+ self.server.register_function(self.list_shared_variables)
153
+
154
+ print(f"远程关键字服务器已启动,监听地址: {self.host}:{self.port}")
155
+
156
+ try:
157
+ self.server.serve_forever()
158
+ except KeyboardInterrupt:
159
+ print("接收到中断信号,正在关闭服务器...")
160
+ finally:
161
+ self.shutdown()
162
+
163
+ def shutdown(self):
164
+ """关闭服务器(用于atexit处理器)"""
165
+ if hasattr(self, '_shutdown_called') and self._shutdown_called:
166
+ return # 避免重复调用
167
+
168
+ # 调用线程化的关闭逻辑
169
+ self._shutdown_in_thread()
170
+
171
+ def authenticate(self, api_key):
172
+ """验证API密钥"""
173
+ if not self.api_key:
174
+ return True
175
+ return api_key == self.api_key
176
+
177
+ def get_keyword_names(self):
178
+ """获取所有可用的关键字名称"""
179
+ return list(keyword_manager._keywords.keys())
180
+
181
+ def run_keyword(self, name, args_dict, api_key=None):
182
+ """执行关键字并返回结果
183
+
184
+ Args:
185
+ name: 关键字名称
186
+ args_dict: 关键字参数字典
187
+ api_key: API密钥(可选)
188
+
189
+ Returns:
190
+ dict: 包含执行结果的字典,格式为:
191
+ {
192
+ 'status': 'PASS' 或 'FAIL',
193
+ 'return': 返回值 (如果成功),
194
+ 'error': 错误信息 (如果失败),
195
+ 'traceback': 错误堆栈 (如果失败)
196
+ }
197
+ """
198
+ # 验证API密钥
199
+ if self.api_key and not self.authenticate(api_key):
200
+ return {
201
+ 'status': 'FAIL',
202
+ 'error': '认证失败:无效的API密钥',
203
+ 'traceback': []
204
+ }
205
+
206
+ try:
207
+ # 确保参数是字典格式
208
+ if not isinstance(args_dict, dict):
209
+ args_dict = json.loads(args_dict) if isinstance(args_dict, str) else {}
210
+
211
+ # 获取关键字信息
212
+ keyword_info = keyword_manager.get_keyword_info(name)
213
+ if not keyword_info:
214
+ raise Exception(f"未注册的关键字: {name}")
215
+
216
+ # 获取参数映射
217
+ mapping = keyword_info.get('mapping', {})
218
+
219
+ # 准备执行参数
220
+ exec_kwargs = {}
221
+
222
+ # 添加默认的步骤名称
223
+ exec_kwargs['step_name'] = name
224
+
225
+ # 创建测试上下文(所有关键字都需要)
226
+ from pytest_dsl.core.context import TestContext
227
+ test_context = TestContext()
228
+ exec_kwargs['context'] = test_context
229
+
230
+ # 映射参数(通用逻辑)
231
+ for param_name, param_value in args_dict.items():
232
+ if param_name in mapping:
233
+ exec_kwargs[mapping[param_name]] = param_value
234
+ else:
235
+ exec_kwargs[param_name] = param_value
236
+
237
+ # 执行关键字执行前的hook
238
+ before_context = hook_manager.execute_hooks(
239
+ HookType.BEFORE_KEYWORD_EXECUTION,
240
+ server=self,
241
+ shared_variables=self.shared_variables,
242
+ keyword_name=name,
243
+ keyword_args=exec_kwargs,
244
+ test_context=test_context
245
+ )
246
+
247
+ # 从hook上下文中更新执行参数(hook可能修改了参数)
248
+ if 'keyword_args' in before_context.data:
249
+ exec_kwargs.update(before_context.data['keyword_args'])
250
+
251
+ # 执行关键字
252
+ result = keyword_manager.execute(name, **exec_kwargs)
253
+
254
+ # 执行关键字执行后的hook
255
+ after_context = hook_manager.execute_hooks(
256
+ HookType.AFTER_KEYWORD_EXECUTION,
257
+ server=self,
258
+ shared_variables=self.shared_variables,
259
+ keyword_name=name,
260
+ keyword_args=exec_kwargs,
261
+ keyword_result=result,
262
+ test_context=test_context
263
+ )
264
+
265
+ # 从hook上下文中获取可能修改的结果
266
+ if 'keyword_result' in after_context.data:
267
+ result = after_context.data['keyword_result']
268
+
269
+ # 处理返回结果
270
+ return_data = self._process_keyword_result(result, test_context)
271
+
272
+ return {
273
+ 'status': 'PASS',
274
+ 'return': return_data
275
+ }
276
+ except Exception as e:
277
+ exc_type, exc_value, exc_tb = sys.exc_info()
278
+ return {
279
+ 'status': 'FAIL',
280
+ 'error': str(e),
281
+ 'traceback': traceback.format_exception(exc_type, exc_value, exc_tb)
282
+ }
283
+
284
+ def get_keyword_arguments(self, name):
285
+ """获取关键字的参数信息"""
286
+ keyword_info = keyword_manager.get_keyword_info(name)
287
+ if not keyword_info:
288
+ return []
289
+
290
+ return [param.name for param in keyword_info['parameters']]
291
+
292
+ def get_keyword_documentation(self, name):
293
+ """获取关键字的文档信息"""
294
+ keyword_info = keyword_manager.get_keyword_info(name)
295
+ if not keyword_info:
296
+ return ""
297
+
298
+ func = keyword_info['func']
299
+ return inspect.getdoc(func) or ""
300
+
301
+ def _process_keyword_result(self, result, test_context):
302
+ """处理关键字执行结果,确保可序列化并提取上下文变量
303
+
304
+ Args:
305
+ result: 关键字执行结果
306
+ test_context: 测试上下文
307
+
308
+ Returns:
309
+ 处理后的结果
310
+ """
311
+ # 如果结果已经是新格式(包含captures等),直接返回
312
+ if isinstance(result, dict) and ('captures' in result or 'session_state' in result):
313
+ # 确保结果可序列化
314
+ return self._ensure_serializable(result)
315
+
316
+ # 对于传统格式的结果,包装成新格式
317
+ processed_result = {
318
+ "result": result,
319
+ "captures": {},
320
+ "session_state": {},
321
+ "metadata": {}
322
+ }
323
+
324
+ # 从上下文中提取可能的变量(这是为了向后兼容)
325
+ # 注意:这只是一个备用方案,新的关键字应该主动返回所需数据
326
+ if hasattr(test_context, '_variables'):
327
+ # 只提取在执行过程中新增的变量
328
+ processed_result["captures"] = dict(test_context._variables)
329
+
330
+ return self._ensure_serializable(processed_result)
331
+
332
+ def _ensure_serializable(self, obj):
333
+ """确保对象可以被序列化为JSON"""
334
+ if self._is_serializable(obj):
335
+ return obj
336
+
337
+ # 如果不能序列化,尝试转换
338
+ if isinstance(obj, dict):
339
+ serializable_dict = {}
340
+ for key, value in obj.items():
341
+ serializable_dict[key] = self._ensure_serializable(value)
342
+ return serializable_dict
343
+ elif isinstance(obj, (list, tuple)):
344
+ return [self._ensure_serializable(item) for item in obj]
345
+ elif hasattr(obj, '__dict__'):
346
+ return self._ensure_serializable(obj.__dict__)
347
+ else:
348
+ return str(obj)
349
+
350
+ def _is_serializable(self, obj):
351
+ """检查对象是否可以被序列化为JSON"""
352
+ try:
353
+ json.dumps(obj)
354
+ return True
355
+ except (TypeError, OverflowError):
356
+ return False
357
+
358
+ def sync_variables_from_client(self, variables, api_key=None):
359
+ """接收客户端同步的变量
360
+
361
+ Args:
362
+ variables: 客户端发送的变量字典
363
+ api_key: API密钥(可选)
364
+
365
+ Returns:
366
+ dict: 同步结果
367
+ """
368
+ # 验证API密钥
369
+ if self.api_key and not self.authenticate(api_key):
370
+ return {
371
+ 'status': 'error',
372
+ 'error': '认证失败:无效的API密钥'
373
+ }
374
+
375
+ try:
376
+ # 更新共享变量
377
+ for name, value in variables.items():
378
+ self.shared_variables[name] = value
379
+ print(f"接收到客户端变量: {name}")
380
+
381
+ return {
382
+ 'status': 'success',
383
+ 'message': f'成功同步 {len(variables)} 个变量'
384
+ }
385
+ except Exception as e:
386
+ return {
387
+ 'status': 'error',
388
+ 'error': f'同步变量失败: {str(e)}'
389
+ }
390
+
391
+ def get_variables_for_client(self, api_key=None):
392
+ """获取要发送给客户端的变量
393
+
394
+ Args:
395
+ api_key: API密钥(可选)
396
+
397
+ Returns:
398
+ dict: 变量数据
399
+ """
400
+ # 验证API密钥
401
+ if self.api_key and not self.authenticate(api_key):
402
+ return {
403
+ 'status': 'error',
404
+ 'error': '认证失败:无效的API密钥'
405
+ }
406
+
407
+ try:
408
+ return {
409
+ 'status': 'success',
410
+ 'variables': self.shared_variables.copy()
411
+ }
412
+ except Exception as e:
413
+ return {
414
+ 'status': 'error',
415
+ 'error': f'获取变量失败: {str(e)}'
416
+ }
417
+
418
+ def set_shared_variable(self, name, value, api_key=None):
419
+ """设置共享变量
420
+
421
+ Args:
422
+ name: 变量名
423
+ value: 变量值
424
+ api_key: API密钥(可选)
425
+
426
+ Returns:
427
+ dict: 设置结果
428
+ """
429
+ # 验证API密钥
430
+ if self.api_key and not self.authenticate(api_key):
431
+ return {
432
+ 'status': 'error',
433
+ 'error': '认证失败:无效的API密钥'
434
+ }
435
+
436
+ try:
437
+ self.shared_variables[name] = value
438
+ print(f"设置共享变量: {name} = {value}")
439
+ return {
440
+ 'status': 'success',
441
+ 'message': f'成功设置变量 {name}'
442
+ }
443
+ except Exception as e:
444
+ return {
445
+ 'status': 'error',
446
+ 'error': f'设置变量失败: {str(e)}'
447
+ }
448
+
449
+ def get_shared_variable(self, name, api_key=None):
450
+ """获取共享变量
451
+
452
+ Args:
453
+ name: 变量名
454
+ api_key: API密钥(可选)
455
+
456
+ Returns:
457
+ dict: 变量值或错误信息
458
+ """
459
+ # 验证API密钥
460
+ if self.api_key and not self.authenticate(api_key):
461
+ return {
462
+ 'status': 'error',
463
+ 'error': '认证失败:无效的API密钥'
464
+ }
465
+
466
+ try:
467
+ if name in self.shared_variables:
468
+ return {
469
+ 'status': 'success',
470
+ 'value': self.shared_variables[name]
471
+ }
472
+ else:
473
+ return {
474
+ 'status': 'error',
475
+ 'error': f'变量 {name} 不存在'
476
+ }
477
+ except Exception as e:
478
+ return {
479
+ 'status': 'error',
480
+ 'error': f'获取变量失败: {str(e)}'
481
+ }
482
+
483
+ def list_shared_variables(self, api_key=None):
484
+ """列出所有共享变量
485
+
486
+ Args:
487
+ api_key: API密钥(可选)
488
+
489
+ Returns:
490
+ dict: 变量列表
491
+ """
492
+ # 验证API密钥
493
+ if self.api_key and not self.authenticate(api_key):
494
+ return {
495
+ 'status': 'error',
496
+ 'error': '认证失败:无效的API密钥'
497
+ }
498
+
499
+ try:
500
+ return {
501
+ 'status': 'success',
502
+ 'variables': list(self.shared_variables.keys()),
503
+ 'count': len(self.shared_variables)
504
+ }
505
+ except Exception as e:
506
+ return {
507
+ 'status': 'error',
508
+ 'error': f'列出变量失败: {str(e)}'
509
+ }
510
+
511
+ def main():
512
+ """启动远程关键字服务器的主函数"""
513
+ import argparse
514
+
515
+ parser = argparse.ArgumentParser(description='启动pytest-dsl远程关键字服务器')
516
+ parser.add_argument('--host', default='localhost', help='服务器主机名')
517
+ parser.add_argument('--port', type=int, default=8270, help='服务器端口')
518
+ parser.add_argument('--api-key', help='API密钥,用于认证')
519
+ parser.add_argument('--extensions', help='扩展模块路径,多个路径用逗号分隔')
520
+
521
+ args = parser.parse_args()
522
+
523
+ # 在创建服务器之前加载额外的扩展模块(如果指定)
524
+ if args.extensions:
525
+ print("正在加载额外的扩展模块...")
526
+ _load_extensions(args.extensions)
527
+
528
+ # 自动加载当前目录下的扩展
529
+ print("正在自动加载当前目录下的扩展...")
530
+ _auto_load_extensions()
531
+
532
+ # 创建并启动服务器(服务器初始化时会自动加载标准关键字)
533
+ server = RemoteKeywordServer(host=args.host, port=args.port, api_key=args.api_key)
534
+ server.start()
535
+
536
+
537
+ def _load_extensions(extensions_arg):
538
+ """加载指定的扩展模块"""
539
+ import importlib.util
540
+ import os
541
+
542
+ extension_paths = [path.strip() for path in extensions_arg.split(',')]
543
+
544
+ for ext_path in extension_paths:
545
+ if not ext_path:
546
+ continue
547
+
548
+ try:
549
+ if os.path.isfile(ext_path) and ext_path.endswith('.py'):
550
+ # 加载单个Python文件
551
+ module_name = os.path.splitext(os.path.basename(ext_path))[0]
552
+ spec = importlib.util.spec_from_file_location(module_name, ext_path)
553
+ module = importlib.util.module_from_spec(spec)
554
+ spec.loader.exec_module(module)
555
+ print(f"已加载扩展模块: {ext_path}")
556
+ elif os.path.isdir(ext_path):
557
+ # 加载目录下的所有Python文件
558
+ for filename in os.listdir(ext_path):
559
+ if filename.endswith('.py') and not filename.startswith('_'):
560
+ file_path = os.path.join(ext_path, filename)
561
+ module_name = os.path.splitext(filename)[0]
562
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
563
+ module = importlib.util.module_from_spec(spec)
564
+ spec.loader.exec_module(module)
565
+ print(f"已加载扩展模块: {file_path}")
566
+ else:
567
+ # 尝试作为模块名导入
568
+ importlib.import_module(ext_path)
569
+ print(f"已导入扩展模块: {ext_path}")
570
+
571
+ except Exception as e:
572
+ print(f"加载扩展模块失败 {ext_path}: {str(e)}")
573
+
574
+
575
+ def _auto_load_extensions():
576
+ """自动加载当前目录下的扩展"""
577
+ import os
578
+ import importlib.util
579
+
580
+ # 查找当前目录下的extensions目录
581
+ extensions_dir = os.path.join(os.getcwd(), 'extensions')
582
+ if os.path.isdir(extensions_dir):
583
+ print(f"发现扩展目录: {extensions_dir}")
584
+ _load_extensions(extensions_dir)
585
+
586
+ # 查找当前目录下的remote_extensions.py文件
587
+ remote_ext_file = os.path.join(os.getcwd(), 'remote_extensions.py')
588
+ if os.path.isfile(remote_ext_file):
589
+ print(f"发现扩展文件: {remote_ext_file}")
590
+ _load_extensions(remote_ext_file)
591
+
592
+ if __name__ == '__main__':
593
+ main()