mindstudio-probe 8.2.0__py3-none-any.whl → 8.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/METADATA +2 -2
  2. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/RECORD +90 -79
  3. msprobe/README.md +7 -5
  4. msprobe/core/common/const.py +6 -0
  5. msprobe/core/common/db_manager.py +35 -4
  6. msprobe/core/common/file_utils.py +105 -27
  7. msprobe/core/common/framework_adapter.py +7 -6
  8. msprobe/core/common/megatron_utils.py +59 -0
  9. msprobe/core/common/utils.py +14 -3
  10. msprobe/core/compare/find_first/analyzer.py +8 -7
  11. msprobe/core/compare/find_first/graph.py +11 -3
  12. msprobe/core/compare/find_first/utils.py +2 -1
  13. msprobe/core/compare/highlight.py +13 -6
  14. msprobe/core/compare/multiprocessing_compute.py +17 -10
  15. msprobe/core/compare/utils.py +14 -5
  16. msprobe/core/data_dump/data_collector.py +18 -21
  17. msprobe/core/data_dump/data_processor/pytorch_processor.py +43 -20
  18. msprobe/core/data_dump/json_writer.py +18 -8
  19. msprobe/core/data_dump/scope.py +4 -6
  20. msprobe/core/hook_manager.py +37 -3
  21. msprobe/core/service.py +18 -5
  22. msprobe/core/single_save/single_comparator.py +16 -3
  23. msprobe/docs/01.installation.md +7 -5
  24. msprobe/docs/02.config_introduction.md +14 -1
  25. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  26. msprobe/docs/06.data_dump_MindSpore.md +1 -1
  27. msprobe/docs/08.accuracy_checker_online_PyTorch.md +295 -0
  28. msprobe/docs/10.accuracy_compare_PyTorch.md +46 -5
  29. msprobe/docs/14.data_parse_PyTorch.md +1 -1
  30. msprobe/docs/15.free_benchmarking_PyTorch.md +1 -1
  31. msprobe/docs/19.monitor.md +2 -0
  32. msprobe/docs/21.visualization_PyTorch.md +15 -80
  33. msprobe/docs/22.visualization_MindSpore.md +20 -104
  34. msprobe/docs/23.generate_operator_PyTorch.md +1 -1
  35. msprobe/docs/25.tool_function_introduction.md +1 -0
  36. msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
  37. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  38. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  39. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  40. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  41. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  42. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  43. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  44. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
  45. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
  46. msprobe/mindspore/cell_processor.py +33 -5
  47. msprobe/mindspore/compare/common_dir_compare.py +22 -26
  48. msprobe/mindspore/compare/utils.py +1 -2
  49. msprobe/mindspore/debugger/precision_debugger.py +1 -1
  50. msprobe/mindspore/dump/cell_dump_process.py +73 -62
  51. msprobe/mindspore/dump/graph_mode_cell_dump.py +21 -10
  52. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +2 -0
  53. msprobe/msprobe.py +6 -4
  54. msprobe/pytorch/api_accuracy_checker/common/config.py +36 -3
  55. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +24 -0
  56. msprobe/pytorch/api_accuracy_checker/compare/compare.py +12 -2
  57. msprobe/pytorch/api_accuracy_checker/config.yaml +6 -1
  58. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
  59. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +132 -12
  60. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  61. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +205 -0
  62. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +378 -0
  63. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +239 -0
  64. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  65. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +250 -0
  66. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  67. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +198 -0
  68. msprobe/pytorch/attl_manager.py +65 -0
  69. msprobe/pytorch/common/utils.py +22 -2
  70. msprobe/pytorch/compare/utils.py +3 -3
  71. msprobe/pytorch/debugger/debugger_config.py +10 -0
  72. msprobe/pytorch/dump/module_dump/hook_wrapper.py +34 -7
  73. msprobe/pytorch/dump/module_dump/module_processer.py +23 -10
  74. msprobe/pytorch/hook_module/api_register.py +6 -1
  75. msprobe/pytorch/monitor/module_hook.py +28 -9
  76. msprobe/pytorch/online_dispatch/dispatch.py +42 -24
  77. msprobe/pytorch/pt_config.py +57 -2
  78. msprobe/pytorch/pytorch_service.py +11 -2
  79. msprobe/visualization/builder/graph_builder.py +170 -64
  80. msprobe/visualization/builder/graph_merger.py +0 -1
  81. msprobe/visualization/builder/msprobe_adapter.py +1 -1
  82. msprobe/visualization/db_utils.py +25 -2
  83. msprobe/visualization/graph/base_node.py +0 -24
  84. msprobe/visualization/graph/graph.py +5 -14
  85. msprobe/visualization/graph_service.py +29 -53
  86. msprobe/visualization/utils.py +11 -1
  87. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/LICENSE +0 -0
  88. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/WHEEL +0 -0
  89. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/entry_points.txt +0 -0
  90. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/top_level.txt +0 -0
@@ -51,6 +51,8 @@ from msprobe.pytorch.pt_config import parse_json_config
51
51
  from msprobe.core.common.const import Const, FileCheckConst, CompareConst
52
52
  from msprobe.core.common.utils import safe_get_value, CompareException, is_int, check_op_str_pattern_valid
53
53
  from msprobe.pytorch.common.utils import seed_all
54
+ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, move2device_exec
55
+ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher
54
56
  from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params, generate_device_params, \
55
57
  ExecParams
56
58
 
@@ -88,22 +90,27 @@ seed_all()
88
90
 
89
91
  def run_ut(config):
90
92
  logger.info("start UT test")
91
-
92
- logger.info(f"UT task result will be saved in {config.result_csv_path}")
93
- logger.info(f"UT task details will be saved in {config.details_csv_path}")
93
+ if config.online_config.is_online:
94
+ logger.info(f"UT task result will be saved in {config.result_csv_path}".replace(".csv", "_rank*.csv"))
95
+ logger.info(f"UT task details will be saved in {config.details_csv_path}".replace(".csv", "_rank*.csv"))
96
+ else:
97
+ logger.info(f"UT task result will be saved in {config.result_csv_path}")
98
+ logger.info(f"UT task details will be saved in {config.details_csv_path}")
94
99
 
95
100
  if config.save_error_data:
96
101
  logger.info(f"UT task error_data will be saved in {config.error_data_path}")
97
102
  compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut, config=config)
98
103
 
99
-
100
- csv_df = read_csv(config.result_csv_path)
101
- try:
102
- api_name_set = {row[0] for row in csv_df.itertuples(index=False, name=None)}
103
- except IndexError:
104
- logger.error(f"Read {config.result_csv_path} error, api_name_set is empty.")
105
- api_name_set = set()
106
- run_api_offline(config, compare, api_name_set)
104
+ if config.online_config.is_online:
105
+ run_api_online(config, compare)
106
+ else:
107
+ csv_df = read_csv(config.result_csv_path)
108
+ try:
109
+ api_name_set = {row[0] for row in csv_df.itertuples(index=False, name=None)}
110
+ except IndexError:
111
+ logger.error(f"Read {config.result_csv_path} error, api_name_set is empty.")
112
+ api_name_set = set()
113
+ run_api_offline(config, compare, api_name_set)
107
114
  for result_csv_path, details_csv_path in zip(compare.save_path_list, compare.detail_save_path_list):
108
115
  change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
109
116
  change_mode(details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
@@ -157,6 +164,60 @@ def run_api_offline(config, compare, api_name_set):
157
164
  gc.collect()
158
165
 
159
166
 
def run_api_online(config, compare):
    """Consume ApiData produced by the training process and replay each API
    on the benchmark device.

    The transport is chosen from ``config.online_config``: when ``nfs_path``
    is set, messages are exchanged through a shared directory; otherwise a
    TCP server (inside ``attl``) receives them. Replay work is handed to a
    ``ConsumerDispatcher`` which runs ``run_torch_api_online`` and writes
    results into ``compare``.
    """
    attl = init_attl(config.online_config)
    dispatcher = ConsumerDispatcher(compare=compare)
    dispatcher.start(handle_func=run_torch_api_online, config=config)

    def tcp_communication_flow():
        # Poll the TCP receive queue until shutdown is confirmed.
        # attl.recv() returns "STOP_" for control ticks (including the first
        # KILL_ notification) and "KILL_" once shutdown is confirmed.
        while True:
            api_data = attl.recv()
            if api_data == 'STOP_':
                continue
            if api_data == 'KILL_':
                # shutdown confirmed: stop workers and the server
                time.sleep(1)
                logger.info("==========接收到STOP信号==========")
                dispatcher.stop()
                attl.stop_serve()
                time.sleep(1)
                break
            if not isinstance(api_data, ApiData):
                continue
            api_full_name = api_data.name
            _, api_name = extract_basic_api_segments(api_full_name)
            if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
                continue
            if api_data.rank in config.online_config.rank_list:
                dispatcher.update_consume_queue(api_data)

    def shared_storage_communication_flow():
        # flag_num tracks active producers: -1 means none seen yet; each
        # "start" marker increments, each "end" decrements; reaching 0 means
        # every producer that started has finished.
        flag_num = -1
        while True:
            api_data = attl.download()
            if api_data == "start":
                if flag_num == -1:
                    # first producer observed: leave the "not started" state
                    flag_num += 1
                flag_num += 1
            if api_data == "end":
                flag_num -= 1
            if flag_num == 0:
                dispatcher.stop()
                break
            if not isinstance(api_data, ApiData):
                continue
            api_full_name = api_data.name
            _, api_name = extract_basic_api_segments(api_full_name)
            if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
                continue
            if api_data.rank in config.online_config.rank_list:
                dispatcher.update_consume_queue(api_data)

    if config.online_config.nfs_path:
        shared_storage_communication_flow()
    else:
        tcp_communication_flow()
160
221
  def blacklist_and_whitelist_filter(api_name, black_list, white_list):
161
222
  """
162
223
  run api(api_name) if api_name not in black_list and in white_list.
@@ -254,6 +315,21 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
254
315
  return UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad, in_fwd_data_list, backward_message)
255
316
 
256
317
 
def run_torch_api_online(api_full_name, api_data, backward_content):
    """Re-execute one online-received API call on the benchmark device.

    Args:
        api_full_name: full dotted API name (e.g. "Torch.matmul.0.forward").
        api_data: ApiData carrying args, kwargs, the original device result,
            and the producing rank.
        backward_content: unused here; kept for dispatcher signature parity.

    Returns:
        UtDataInfo whose `bench_out` slot holds the received result and whose
        `device_out` slot holds the benchmark re-execution moved back to CPU.
    """
    in_fwd_data_list = []
    api_type, api_name = extract_basic_api_segments(api_full_name)
    args, kwargs, out = api_data.args, api_data.kwargs, api_data.result
    in_fwd_data_list.append(args)
    in_fwd_data_list.append(kwargs)
    # Drop any explicit "device" kwarg so exec_api targets the benchmark
    # device. pop() with a default also removes falsy values, which the
    # previous `if kwargs.get("device"): del ...` pattern missed.
    kwargs.pop("device", None)

    device_exec_params = ExecParams(api_type, api_name, current_device, args, kwargs, False, None)
    device_out = exec_api(device_exec_params)
    device_out = move2device_exec(device_out, "cpu")
    return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank)
257
333
  def check_need_grad(api_info_dict):
258
334
  need_grad = True
259
335
  if api_info_dict.get(Const.INPUT_KWARGS) and "out" in api_info_dict.get(Const.INPUT_KWARGS):
@@ -313,6 +389,16 @@ def initialize_save_error_data(error_data_path):
313
389
  return error_data_path
314
390
 
315
391
 
def init_attl(config):
    """Build the benchmark-side ATTL from an OnlineConfig.

    config: OnlineConfig
    """
    transport_config = ATTLConfig(is_benchmark_device=True,
                                  connect_ip=config.host,
                                  connect_port=config.port,
                                  nfs_path=config.nfs_path,
                                  tls_path=config.tls_path)
    return ATTL('gpu', transport_config)
400
+
401
+
316
402
  def _run_ut_parser(parser):
317
403
  parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", default="", type=str,
318
404
  help="<Optional> The api param tool result file: generate from api param tool, "
@@ -395,6 +481,38 @@ def _run_ut(parser=None):
395
481
  _run_ut_parser(parser)
396
482
  args = parser.parse_args(sys.argv[1:])
397
483
  run_ut_command(args)
484
+
485
+
def checked_online_config(online_config):
    """Validate an OnlineConfig before starting online run_ut.

    Args:
        online_config: OnlineConfig instance to validate.

    Raises:
        ValueError: when any field is malformed. (Previously host/port errors
            raised bare Exception; ValueError is a subclass of Exception, so
            existing broad handlers still catch it.)

    Notes:
        Returns immediately when online mode is disabled. In shared-storage
        (nfs) mode only nfs_path is validated — tls/host/port are only used
        by the TCP transport, so they are skipped.
    """
    if not online_config.is_online:
        return
    if not isinstance(online_config.is_online, bool):
        raise ValueError("is_online must be bool type")
    # rank_list
    if not isinstance(online_config.rank_list, list):
        raise ValueError("rank_list must be a list")
    if online_config.rank_list and not all(isinstance(rank, int) for rank in online_config.rank_list):
        raise ValueError("All elements in rank_list must be integers")

    # nfs_path: shared-storage transport; TCP settings are not required
    if online_config.nfs_path:
        check_file_or_directory_path(online_config.nfs_path, isdir=True)
        return
    # tls_path: optional TLS material for the TCP transport
    if online_config.tls_path:
        check_file_or_directory_path(online_config.tls_path, isdir=True)
        check_file_or_directory_path(os.path.join(online_config.tls_path, "server.key"))
        check_file_or_directory_path(os.path.join(online_config.tls_path, "server.crt"))
        check_file_or_directory_path(os.path.join(online_config.tls_path, "ca.crt"))
        crl_path = os.path.join(online_config.tls_path, "crl.pem")
        if os.path.exists(crl_path):
            check_file_or_directory_path(crl_path)

    # host and port
    if not isinstance(online_config.host, str) or not re.match(Const.ipv4_pattern, online_config.host):
        raise ValueError(f"host: {online_config.host} is invalid.")
    if not isinstance(online_config.port, int) or not (0 < online_config.port <= 65535):
        raise ValueError(f"port: {online_config.port} is invalid, port range 1-65535.")
398
516
 
399
517
 
400
518
  def run_ut_command(args):
@@ -407,7 +525,7 @@ def run_ut_command(args):
407
525
  else:
408
526
  checker_config = CheckerConfig()
409
527
 
410
- if not args.api_info_file:
528
+ if not checker_config.is_online and not args.api_info_file:
411
529
  logger.error("Please provide api_info_file for offline run ut.")
412
530
  raise Exception("Please provide api_info_file for offline run ut.")
413
531
 
@@ -470,6 +588,8 @@ def run_ut_command(args):
470
588
  global UT_ERROR_DATA_DIR
471
589
  UT_ERROR_DATA_DIR = 'ut_error_data' + time_info
472
590
  error_data_path = initialize_save_error_data(error_data_path)
591
+ online_config = checker_config.get_online_config()
592
+ checked_online_config(online_config)
473
593
  config_params = {
474
594
  'forward_content': forward_content,
475
595
  'backward_content': backward_content,
@@ -0,0 +1,205 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import glob
17
+ import os.path
18
+ import time
19
+ from multiprocessing import Queue
20
+ from typing import Optional, Union, Dict, Any
21
+ from dataclasses import dataclass
22
+
23
+ import torch
24
+
25
+ from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
26
+ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.client import TCPClient
27
+ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server import TCPServer
28
+ from msprobe.core.common.file_utils import remove_path
29
+ from msprobe.pytorch.common.utils import logger, save_api_data, load_api_data, save_pkl, load_pkl
30
+ from msprobe.core.common.decorator import recursion_depth_decorator
31
+
32
# Payload kinds the transport layer moves: real ApiData, dict-shaped control
# messages, or plain string control markers (e.g. "start"/"end").
BufferType = Union[ApiData, Dict[str, Any], str]
33
+
34
+
35
@dataclass
class ATTLConfig:
    """Transport settings shared by both ends of the accuracy-checker link."""
    is_benchmark_device: bool  # True on the benchmark (receiving/server) side
    connect_ip: str
    connect_port: int
    # storage_config
    nfs_path: Optional[str] = None  # when set, use shared-storage transport instead of TCP
    tls_path: Optional[str] = None  # directory with TLS material; None -> plain TCP
    check_sum: bool = True
    queue_size: int = 50
45
+
46
+
class ATTL:
    """Accuracy-Tool Transport Layer.

    Moves ApiData payloads and string control markers between the training
    process and the benchmark process over one of two transports:

      * shared storage (``nfs_path`` set): ``upload()`` / ``download()``;
      * TCP: a ``TCPServer`` feeding ``data_queue`` on the benchmark side
        (``is_benchmark_device``), or a ``TCPClient`` used by ``send()``.
    """

    def __init__(self, session_id: str, session_config: ATTLConfig, need_dump=True) -> None:
        self.session_id = session_id
        self.session_config = session_config
        self.logger = logger
        self.socket_manager = None
        # NOTE(review): maxsize is hard-coded to 50; session_config.queue_size
        # is not consulted here — confirm whether that is intentional.
        self.data_queue = Queue(maxsize=50)
        self.dequeue_list = []
        self.message_end = False  # set once a KILL_ control token is seen in recv()
        self.kill_progress = False
        self.nfs_path = None
        if self.session_config.nfs_path:
            # shared-storage transport: no socket is created
            self.nfs_path = self.session_config.nfs_path
        elif self.session_config.is_benchmark_device:

            self.socket_manager = TCPServer(self.session_config.connect_port,
                                            self.data_queue,
                                            self.session_config.check_sum,
                                            self.session_config.tls_path)
            self.socket_manager.start()
        elif need_dump:
            self.socket_manager = TCPClient(self.session_config.connect_ip,
                                            self.session_config.connect_port,
                                            self.session_config.check_sum,
                                            self.session_config.tls_path)
            self.socket_manager.start()

    def stop_serve(self):
        """Stop the TCP server; no-op for client / shared-storage instances."""
        if isinstance(self.socket_manager, TCPServer):
            self.socket_manager.stop()

    def send(self, buffer: BufferType) -> None:
        """
        npu major in 'send' (client): serialize `buffer` and hand it to the
        TCP client's sending queue. ApiData payloads are moved to CPU first.
        """

        # if tcp connection lost,
        if self.socket_manager.signal_exit:
            raise ConnectionError(f"Failed to connect to {self.session_config.connect_ip}.")

        # know receiver receive and go next
        if isinstance(buffer, ApiData):
            buffer = move2target_device(buffer, torch.device('cpu'))

            # drop any explicit device kwarg; the receiver picks its own device
            if 'device' in buffer.kwargs:
                buffer.kwargs.pop('device')
        rank = buffer.rank if hasattr(buffer, "rank") and buffer.rank is not None else 0
        step = buffer.step if hasattr(buffer, "step") else 0
        try:
            io_buff = save_api_data(buffer)
        except Exception as e:
            self.logger.info(f"{buffer.name} can not be saved, skip: {e}")
            return
        data = io_buff.getvalue()
        self.socket_manager.add_to_sending_queue(data, rank=rank, step=step)

    def recv(self, timeout_ms=0) -> Optional[BufferType]:
        """Pop the next message from the server-side queue.

        Returns a deserialized ApiData, or a string control token:
        "STOP_" (control tick, including the first KILL_ notification) or
        "KILL_" (shutdown confirmed). timeout_ms <= 0 blocks until data
        arrives; otherwise a timeout yields an empty string.
        """
        buffer = ''
        while not buffer:
            if timeout_ms > 0:
                time.sleep(timeout_ms / 1000.0)
            if not buffer and not self.data_queue.empty():
                buffer = self.data_queue.get()
                break
            if not buffer and timeout_ms > 0:  # timeout is the only case we give up and return None
                break
            if self.message_end and self.data_queue.empty():
                # KILL_ seen earlier and queue fully drained: confirm shutdown
                buffer = b"KILL_CONFIRM"
                self.kill_progress = True
                break
            time.sleep(0.1)  # waiting outside the lock before next attempt
        if not buffer:
            # this is a result of a timeout
            self.logger.info(f"RECEIVE API DATA TIMED OUT")
        else:
            if buffer == b"STOP_":
                return "STOP_"
            if buffer == b"KILL_":
                # remember shutdown was requested; keep draining the queue
                self.message_end = True
                return "STOP_"
            if buffer == b"KILL_CONFIRM":
                self.kill_progress = True
                return "KILL_"
            try:
                buffer = load_api_data(buffer)
            except Exception as e:
                self.logger.warning("there is something error. please check it. %s", e)
                if isinstance(buffer, bytes):
                    return ''
                if isinstance(buffer, str):
                    return buffer

        return buffer

    def upload(self, buffer: BufferType):
        """Shared-storage counterpart of send(): persist `buffer` into nfs_path.

        ApiData payloads are first moved to CPU and saved as "<name>.pt";
        string control markers are saved as "<marker>_<timestamp>" files.
        """
        if isinstance(buffer, ApiData):
            buffer = move2target_device(buffer, torch.device('cpu'))
            file_path = os.path.join(self.session_config.nfs_path, buffer.name + ".pt")
        else:
            file_path = os.path.join(self.session_config.nfs_path, buffer + f"_{int(time.time())}")

        try:
            save_pkl(buffer, file_path)
        except Exception as e:
            self.logger.warning("there is something error in save_pt. please check it. %s", e)

    def download(self):
        """Shared-storage counterpart of recv(): read and delete the next file.

        "start*" markers are consumed before payloads ("*.pt"), and "end*"
        markers last. Returns the unpickled content, or None when the
        directory is empty (or the file failed to load).
        """
        buffer = None
        cur_file = None
        for file_type in ("start*", "*.pt", "end*"):
            pattern = os.path.join(self.nfs_path, file_type)
            files = glob.glob(pattern)
            if len(files) > 0:
                cur_file = files[0]
                break

        if cur_file is not None:
            try:
                buffer = load_pkl(cur_file)
            except Exception as e:
                self.logger.warning("there is something error. please check it. %s", e)
            remove_path(cur_file)
        return buffer
170
+
171
+
@recursion_depth_decorator("move2device_exec")
def move2device_exec(obj, device):
    """Recursively relocate every tensor inside `obj` onto `device`.

    Containers (list/tuple/dict and torch "return_types" namedtuples) are
    rebuilt with their elements converted; tensors are detached before being
    moved; torch device objects are rewritten to the target device; any other
    leaf is returned untouched.
    """
    if isinstance(obj, (tuple, list)):
        converted = [move2device_exec(item, device) for item in obj]
        return tuple(converted) if isinstance(obj, tuple) else converted
    if isinstance(obj, dict):
        return {name: move2device_exec(item, device) for name, item in obj.items()}
    if isinstance(obj, torch.Tensor):
        detached = obj.detach()
        if detached.device.type != device:
            detached = detached.to(device)
        return detached
    if "return_types" in str(type(obj)):
        return move2device_exec(tuple(obj), device)
    if isinstance(obj, torch._C.device):
        return torch.device(device)
    return obj
190
+
191
+
def move2target_device(buffer: ApiData, target_device):
    """Return a new ApiData with args/kwargs moved to `target_device`.

    The recorded result is converted only when the target is CPU; for any
    other target the original result object is kept as-is.
    """
    converted_args = move2device_exec(buffer.args, target_device)
    converted_kwargs = move2device_exec(buffer.kwargs, target_device)
    converted_result = move2device_exec(buffer.result, target_device)

    going_to_cpu = target_device == torch.device('cpu') or target_device == "cpu"
    final_result = converted_result if going_to_cpu else buffer.result
    return ApiData(buffer.name, tuple(converted_args), converted_kwargs,
                   final_result, buffer.step, buffer.rank)