mindspore 2.3.0__cp39-cp39-win_amd64.whl → 2.4.1__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (287)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +3 -1
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +50 -9
  7. mindspore/_extends/parse/compile_config.py +41 -0
  8. mindspore/_extends/parse/parser.py +9 -7
  9. mindspore/_extends/parse/standard_method.py +52 -14
  10. mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
  11. mindspore/amp.py +24 -10
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/common/__init__.py +6 -4
  18. mindspore/common/_pijit_context.py +190 -0
  19. mindspore/common/_register_for_tensor.py +2 -1
  20. mindspore/common/_tensor_overload.py +139 -0
  21. mindspore/common/api.py +102 -87
  22. mindspore/common/dump.py +5 -6
  23. mindspore/common/generator.py +1 -7
  24. mindspore/common/hook_handle.py +14 -26
  25. mindspore/common/initializer.py +51 -15
  26. mindspore/common/mindir_util.py +2 -2
  27. mindspore/common/parameter.py +62 -15
  28. mindspore/common/recompute.py +39 -9
  29. mindspore/common/sparse_tensor.py +7 -3
  30. mindspore/common/tensor.py +183 -37
  31. mindspore/communication/__init__.py +1 -1
  32. mindspore/communication/_comm_helper.py +38 -3
  33. mindspore/communication/comm_func.py +315 -60
  34. mindspore/communication/management.py +14 -14
  35. mindspore/context.py +132 -22
  36. mindspore/dataset/__init__.py +1 -1
  37. mindspore/dataset/audio/__init__.py +1 -1
  38. mindspore/dataset/core/config.py +7 -0
  39. mindspore/dataset/core/validator_helpers.py +7 -0
  40. mindspore/dataset/engine/cache_client.py +1 -1
  41. mindspore/dataset/engine/datasets.py +72 -44
  42. mindspore/dataset/engine/datasets_audio.py +7 -7
  43. mindspore/dataset/engine/datasets_standard_format.py +53 -3
  44. mindspore/dataset/engine/datasets_text.py +20 -20
  45. mindspore/dataset/engine/datasets_user_defined.py +174 -104
  46. mindspore/dataset/engine/datasets_vision.py +33 -33
  47. mindspore/dataset/engine/iterators.py +29 -0
  48. mindspore/dataset/engine/obs/util.py +7 -0
  49. mindspore/dataset/engine/queue.py +114 -60
  50. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  51. mindspore/dataset/engine/validators.py +34 -14
  52. mindspore/dataset/text/__init__.py +1 -4
  53. mindspore/dataset/transforms/__init__.py +0 -3
  54. mindspore/dataset/utils/line_reader.py +2 -0
  55. mindspore/dataset/vision/__init__.py +1 -4
  56. mindspore/dataset/vision/utils.py +1 -1
  57. mindspore/dataset/vision/validators.py +2 -1
  58. mindspore/dnnl.dll +0 -0
  59. mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
  60. mindspore/experimental/es/embedding_service.py +883 -0
  61. mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
  62. mindspore/experimental/llm_boost/__init__.py +21 -0
  63. mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
  64. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  65. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  66. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  67. mindspore/experimental/llm_boost/register.py +129 -0
  68. mindspore/experimental/llm_boost/utils.py +31 -0
  69. mindspore/experimental/optim/adamw.py +85 -0
  70. mindspore/experimental/optim/optimizer.py +3 -0
  71. mindspore/hal/__init__.py +3 -3
  72. mindspore/hal/contiguous_tensors_handle.py +175 -0
  73. mindspore/hal/stream.py +18 -0
  74. mindspore/include/api/model_group.h +13 -1
  75. mindspore/include/api/types.h +10 -10
  76. mindspore/include/dataset/config.h +2 -2
  77. mindspore/include/dataset/constants.h +2 -2
  78. mindspore/include/dataset/execute.h +2 -2
  79. mindspore/include/dataset/vision.h +4 -0
  80. mindspore/jpeg62.dll +0 -0
  81. mindspore/log.py +1 -1
  82. mindspore/mindrecord/filewriter.py +68 -51
  83. mindspore/mindspore_backend.dll +0 -0
  84. mindspore/mindspore_common.dll +0 -0
  85. mindspore/mindspore_core.dll +0 -0
  86. mindspore/mindspore_glog.dll +0 -0
  87. mindspore/mindspore_np_dtype.dll +0 -0
  88. mindspore/mindspore_ops.dll +0 -0
  89. mindspore/mint/__init__.py +983 -46
  90. mindspore/mint/distributed/__init__.py +31 -0
  91. mindspore/mint/distributed/distributed.py +254 -0
  92. mindspore/mint/nn/__init__.py +268 -23
  93. mindspore/mint/nn/functional.py +125 -19
  94. mindspore/mint/nn/layer/__init__.py +39 -0
  95. mindspore/mint/nn/layer/activation.py +133 -0
  96. mindspore/mint/nn/layer/normalization.py +477 -0
  97. mindspore/mint/nn/layer/pooling.py +110 -0
  98. mindspore/mint/optim/adamw.py +26 -13
  99. mindspore/mint/special/__init__.py +63 -0
  100. mindspore/multiprocessing/__init__.py +2 -1
  101. mindspore/nn/__init__.py +0 -1
  102. mindspore/nn/cell.py +276 -96
  103. mindspore/nn/layer/activation.py +211 -44
  104. mindspore/nn/layer/basic.py +137 -10
  105. mindspore/nn/layer/embedding.py +137 -2
  106. mindspore/nn/layer/normalization.py +101 -5
  107. mindspore/nn/layer/padding.py +34 -48
  108. mindspore/nn/layer/pooling.py +161 -7
  109. mindspore/nn/layer/transformer.py +3 -3
  110. mindspore/nn/loss/__init__.py +2 -2
  111. mindspore/nn/loss/loss.py +84 -6
  112. mindspore/nn/optim/__init__.py +2 -1
  113. mindspore/nn/optim/adadelta.py +1 -1
  114. mindspore/nn/optim/adam.py +1 -1
  115. mindspore/nn/optim/lamb.py +1 -1
  116. mindspore/nn/optim/tft_wrapper.py +124 -0
  117. mindspore/nn/wrap/cell_wrapper.py +12 -23
  118. mindspore/nn/wrap/grad_reducer.py +5 -5
  119. mindspore/nn/wrap/loss_scale.py +17 -3
  120. mindspore/numpy/__init__.py +1 -1
  121. mindspore/numpy/array_creations.py +65 -68
  122. mindspore/numpy/array_ops.py +64 -60
  123. mindspore/numpy/fft.py +610 -75
  124. mindspore/numpy/logic_ops.py +11 -10
  125. mindspore/numpy/math_ops.py +85 -84
  126. mindspore/numpy/utils_const.py +4 -4
  127. mindspore/opencv_core452.dll +0 -0
  128. mindspore/opencv_imgcodecs452.dll +0 -0
  129. mindspore/opencv_imgproc452.dll +0 -0
  130. mindspore/ops/__init__.py +6 -4
  131. mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
  132. mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
  133. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
  134. mindspore/ops/_vmap/vmap_array_ops.py +2 -4
  135. mindspore/ops/_vmap/vmap_math_ops.py +17 -1
  136. mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
  137. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
  138. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
  139. mindspore/ops/auto_generate/gen_extend_func.py +767 -13
  140. mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
  141. mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
  142. mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
  143. mindspore/ops/composite/base.py +85 -48
  144. mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
  145. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
  146. mindspore/ops/function/__init__.py +22 -0
  147. mindspore/ops/function/array_func.py +492 -153
  148. mindspore/ops/function/debug_func.py +113 -1
  149. mindspore/ops/function/fft_func.py +15 -2
  150. mindspore/ops/function/grad/grad_func.py +3 -2
  151. mindspore/ops/function/math_func.py +564 -207
  152. mindspore/ops/function/nn_func.py +817 -383
  153. mindspore/ops/function/other_func.py +3 -2
  154. mindspore/ops/function/random_func.py +402 -12
  155. mindspore/ops/function/reshard_func.py +13 -11
  156. mindspore/ops/function/sparse_unary_func.py +1 -1
  157. mindspore/ops/function/vmap_func.py +3 -2
  158. mindspore/ops/functional.py +24 -14
  159. mindspore/ops/op_info_register.py +3 -3
  160. mindspore/ops/operations/__init__.py +7 -2
  161. mindspore/ops/operations/_grad_ops.py +2 -76
  162. mindspore/ops/operations/_infer_ops.py +1 -1
  163. mindspore/ops/operations/_inner_ops.py +71 -94
  164. mindspore/ops/operations/array_ops.py +14 -146
  165. mindspore/ops/operations/comm_ops.py +63 -53
  166. mindspore/ops/operations/custom_ops.py +83 -19
  167. mindspore/ops/operations/debug_ops.py +42 -10
  168. mindspore/ops/operations/manually_defined/_inner.py +12 -0
  169. mindspore/ops/operations/manually_defined/ops_def.py +273 -20
  170. mindspore/ops/operations/math_ops.py +12 -223
  171. mindspore/ops/operations/nn_ops.py +20 -114
  172. mindspore/ops/operations/other_ops.py +7 -4
  173. mindspore/ops/operations/random_ops.py +46 -1
  174. mindspore/ops/primitive.py +18 -6
  175. mindspore/ops_generate/arg_dtype_cast.py +2 -0
  176. mindspore/ops_generate/gen_aclnn_implement.py +11 -11
  177. mindspore/ops_generate/gen_constants.py +36 -0
  178. mindspore/ops_generate/gen_ops.py +67 -52
  179. mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
  180. mindspore/ops_generate/gen_pyboost_func.py +131 -47
  181. mindspore/ops_generate/op_proto.py +10 -3
  182. mindspore/ops_generate/pyboost_utils.py +14 -1
  183. mindspore/ops_generate/template.py +43 -21
  184. mindspore/parallel/__init__.py +3 -1
  185. mindspore/parallel/_auto_parallel_context.py +31 -9
  186. mindspore/parallel/_cell_wrapper.py +85 -0
  187. mindspore/parallel/_parallel_serialization.py +47 -19
  188. mindspore/parallel/_tensor.py +127 -13
  189. mindspore/parallel/_utils.py +53 -22
  190. mindspore/parallel/algo_parameter_config.py +5 -5
  191. mindspore/parallel/checkpoint_transform.py +46 -39
  192. mindspore/parallel/cluster/process_entity/__init__.py +1 -1
  193. mindspore/parallel/cluster/process_entity/_api.py +31 -23
  194. mindspore/parallel/cluster/process_entity/_utils.py +2 -27
  195. mindspore/parallel/parameter_broadcast.py +3 -4
  196. mindspore/parallel/shard.py +162 -31
  197. mindspore/parallel/transform_safetensors.py +1146 -0
  198. mindspore/profiler/__init__.py +2 -1
  199. mindspore/profiler/common/constant.py +29 -0
  200. mindspore/profiler/common/registry.py +47 -0
  201. mindspore/profiler/common/util.py +28 -0
  202. mindspore/profiler/dynamic_profiler.py +694 -0
  203. mindspore/profiler/envprofiling.py +17 -19
  204. mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
  205. mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
  206. mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
  207. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
  208. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
  209. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
  210. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  211. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
  212. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
  213. mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
  214. mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
  215. mindspore/profiler/parser/base_timeline_generator.py +19 -25
  216. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  217. mindspore/profiler/parser/framework_parser.py +1 -391
  218. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  219. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  220. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  221. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  222. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  223. mindspore/profiler/parser/profiler_info.py +78 -6
  224. mindspore/profiler/profiler.py +153 -0
  225. mindspore/profiler/profiling.py +285 -413
  226. mindspore/rewrite/__init__.py +1 -2
  227. mindspore/rewrite/common/namespace.py +4 -4
  228. mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
  229. mindspore/run_check/_check_version.py +39 -104
  230. mindspore/safeguard/rewrite_obfuscation.py +591 -247
  231. mindspore/swresample-4.dll +0 -0
  232. mindspore/swscale-6.dll +0 -0
  233. mindspore/tinyxml2.dll +0 -0
  234. mindspore/train/__init__.py +4 -3
  235. mindspore/train/_utils.py +105 -19
  236. mindspore/train/amp.py +171 -53
  237. mindspore/train/callback/__init__.py +2 -2
  238. mindspore/train/callback/_callback.py +4 -4
  239. mindspore/train/callback/_checkpoint.py +97 -31
  240. mindspore/train/callback/_cluster_monitor.py +1 -1
  241. mindspore/train/callback/_flops_collector.py +1 -0
  242. mindspore/train/callback/_loss_monitor.py +3 -3
  243. mindspore/train/callback/_on_request_exit.py +145 -31
  244. mindspore/train/callback/_summary_collector.py +5 -5
  245. mindspore/train/callback/_tft_register.py +375 -0
  246. mindspore/train/dataset_helper.py +15 -3
  247. mindspore/train/metrics/metric.py +3 -3
  248. mindspore/train/metrics/roc.py +4 -4
  249. mindspore/train/mind_ir_pb2.py +44 -39
  250. mindspore/train/model.py +154 -58
  251. mindspore/train/serialization.py +342 -128
  252. mindspore/turbojpeg.dll +0 -0
  253. mindspore/utils/__init__.py +21 -0
  254. mindspore/utils/utils.py +60 -0
  255. mindspore/version.py +1 -1
  256. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
  257. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +260 -254
  258. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +1 -1
  259. mindspore/include/c_api/ms/abstract.h +0 -67
  260. mindspore/include/c_api/ms/attribute.h +0 -197
  261. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  262. mindspore/include/c_api/ms/base/macros.h +0 -32
  263. mindspore/include/c_api/ms/base/status.h +0 -33
  264. mindspore/include/c_api/ms/base/types.h +0 -283
  265. mindspore/include/c_api/ms/context.h +0 -102
  266. mindspore/include/c_api/ms/graph.h +0 -160
  267. mindspore/include/c_api/ms/node.h +0 -606
  268. mindspore/include/c_api/ms/tensor.h +0 -161
  269. mindspore/include/c_api/ms/value.h +0 -84
  270. mindspore/mindspore_shared_lib.dll +0 -0
  271. mindspore/nn/extend/basic.py +0 -140
  272. mindspore/nn/extend/embedding.py +0 -143
  273. mindspore/nn/extend/layer/normalization.py +0 -109
  274. mindspore/nn/extend/pooling.py +0 -117
  275. mindspore/nn/layer/embedding_service.py +0 -531
  276. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  277. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  278. mindspore/ops/extend/__init__.py +0 -53
  279. mindspore/ops/extend/array_func.py +0 -218
  280. mindspore/ops/extend/math_func.py +0 -76
  281. mindspore/ops/extend/nn_func.py +0 -308
  282. mindspore/ops/silent_check.py +0 -162
  283. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  284. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  285. mindspore/train/callback/_mindio_ttp.py +0 -443
  286. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
  287. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
@@ -15,13 +15,14 @@
15
15
  """obfuscate network based on rewrite interfaces."""
16
16
  import os
17
17
  import re
18
- import secrets
19
18
  from pathlib import Path
19
+ from string import Template
20
+ import numpy as np
20
21
 
22
+ import mindspore as ms
21
23
  from mindspore import ops, nn
22
- from mindspore.common.tensor import Tensor
23
- from mindspore import log as logger
24
- from mindspore import load_checkpoint, save_checkpoint
24
+ from mindspore import load_checkpoint, save_checkpoint, log
25
+ from mindspore.ops import functional as F
25
26
  from mindspore.rewrite import SymbolTree, Node, NodeType, ScopedValue
26
27
  from mindspore.rewrite.parsers import ClassDefParser
27
28
  from mindspore.rewrite.parsers import ModuleParser
@@ -29,36 +30,127 @@ from mindspore.rewrite.parsers import ModuleParser
29
30
  OBF_RATIOS_LENGTH = 1
30
31
  MAX_OBF_RATIOS_NUM = 50
31
32
  OBF_RATIOS_WIDTH = 0
32
- OBF_RATIOS_INSERT_INDEX = 0
33
+
34
+ _supported_ops = {
35
+ 'mul': ops.Mul,
36
+ 'matmul': ops.MatMul,
37
+ 'invert': ops.Inv
38
+ }
39
+
40
+ _supported_config_type = [
41
+ 'obf_metadata_config',
42
+ 'weight_obf_config',
43
+ 'network_obf_config'
44
+ ]
45
+
46
+ _supported_metadata_type = [
47
+ 'random',
48
+ 'rearrange'
49
+ ]
50
+
51
+ obf_medatadata_template = {
52
+ 'name': 'obf_metadata',
53
+ 'shape': [1,],
54
+ 'type': 'random',
55
+ 'save_metadata': True,
56
+ 'metadata_op': 'invert'
57
+ }
58
+
59
+ weight_obf_template = {
60
+ 'target': '',
61
+ 'weight_obf_ops': [{'name': 'mul', 'input_x': 'weight', 'input_y': 'obf_metadata'}]
62
+ }
63
+
64
+ network_obf_template = {
65
+ 'module': '',
66
+ 'target': '',
67
+ 'insert_new_input': [{'name': 'obf_metadata'}],
68
+ 'insert_ops': [{'name': 'mul', 'input_x': 'weight', 'input_y': 'obf_metadata'}]
69
+ }
70
+
71
+
72
+ def _transform_target_modules(target_modules):
73
+ """transform target_modules to obf config"""
74
+ obf_config = {}
75
+ path = target_modules[0]
76
+ target_list = target_modules[1].split('|')
77
+ max_layers = 12
78
+ layers = []
79
+ obf_medatadata = obf_medatadata_template.copy()
80
+ if len(target_modules) >= 3:
81
+ obfuscate_layers = target_modules[2].split(':')
82
+ if obfuscate_layers[1] != 'all':
83
+ max_layers = int(obfuscate_layers[1])
84
+ layers = [i for i in range(0, max_layers)]
85
+ path_new = path.replace("blocks", "blocks/${layer}")
86
+ network_obf_template['insert_ops'][0]['input_y'] = "obf_metadata_${layer}"
87
+ weight_obf_template['weight_obf_ops'][0]['input_y'] = "obf_metadata_${layer}"
88
+ weight_obf_template['name'] = "obf_metadata_${layer}"
89
+ obf_medatadata['layers'] = layers
90
+ else:
91
+ path_new = path
92
+ obf_config['obf_metadata_config'] = []
93
+ obf_config['weight_obf_config'] = []
94
+ obf_config['network_obf_config'] = []
95
+ obf_config['obf_metadata_config'].append(obf_medatadata)
96
+
97
+ for name in target_list:
98
+ target_weight = path_new + '/' + name + '/weight'
99
+ target_bias = path_new + '/' + name + '/bias'
100
+ weight_obf = weight_obf_template.copy()
101
+ weight_obf['target'] = target_weight
102
+ bias_obf = weight_obf_template.copy()
103
+ bias_obf['target'] = target_bias
104
+ network_obf = network_obf_template.copy()
105
+ network_obf['module'] = '/' + path_new
106
+ network_obf['target'] = name
107
+ if not layers:
108
+ weight_obf['layers'] = layers
109
+ bias_obf['layers'] = layers
110
+ network_obf['layers'] = layers
111
+ obf_config['weight_obf_config'].append(weight_obf)
112
+ obf_config['weight_obf_config'].append(bias_obf)
113
+ obf_config['network_obf_config'].append(network_obf)
114
+ return obf_config
115
+
116
+
117
+ def _get_op(op_name):
118
+ if op_name is None:
119
+ return None
120
+ if op_name not in _supported_ops:
121
+ raise KeyError(f"'op name' must be in {list(_supported_ops.keys())}, but got {op_name}.")
122
+ return _supported_ops[op_name]()
33
123
 
34
124
 
35
- def obfuscate_ckpt(network, ckpt_files, target_modules=None, saved_path='./', obfuscate_scale=100):
125
+ def obfuscate_ckpt(network, ckpt_files, target_modules=None, obf_config=None, saved_path='./', obfuscate_scale=100):
36
126
  """
37
- obfuscate the plaintext checkpoint files. Usually used in conjunction with
38
- :func:`mindspore.load_obf_params_into_net`.
39
- interface.
127
+ Obfuscate the plaintext checkpoint files according to the obfuscation config.
40
128
 
41
129
  Args:
42
130
  network (nn.Cell): The original network that needs to be obfuscated.
43
131
  ckpt_files (str): The directory path of original ckpt files.
44
- target_modules (list[str]): The target module of network that need to be obfuscated. The first string
45
- represents the network path of target module in original network, which should be in form of ``'A/B/C'``.
46
- The second string represents the obfuscation target module, which should be in form of ``'D|E|F'``. For
47
- example, thr target_modules of GPT2 can be ``['backbone/blocks/attention', 'dense1|dense2|dense3']``.
48
- If target_modules has the third value, it should be in the format of 'obfuscate_layers:all' or
49
- 'obfuscate_layers:int', which represents the number of layers need to be obfuscated of duplicate layers
50
- (such as transformer layers or resnet blocks). If target_modules is ``None``, the function would search
51
- target modules by itself. If found, the searched target module would be used, otherwise suggested target
52
- modules would be given with warning log. Default: ``None``.
132
+ target_modules (list[str]): The target ops that need to be obfuscated in the network. The first string
133
+ represents the network path of the target ops in the original network, which should be in form of
134
+ ``"A/B/C"``. The second string represents the names of multiple target ops in the same path, which
135
+ should be in form of ``"D|E|F"``. For example, the target_modules of GPT2 can be ``['backbone/blocks
136
+ /attention', 'dense1|dense2|dense3']``. If target_modules has the third value, it should be in the
137
+ format of 'obfuscate_layers:all' or 'obfuscate_layers:int', which represents the number of duplicate
138
+ layers (such as transformer layers or resnet blocks) that need to be obfuscated.
139
+ Default: ``None``.
140
+ obf_config (dict): The configuration of model obfuscation policies. Default: ``None``.
53
141
  saved_path (str): The directory path for saving obfuscated ckpt files. Default: ``'./'``.
54
142
  obfuscate_scale (Union[float, int]): Obfuscate scale of weights. The generated random obf_ratios will be in
55
143
  range of (1 / obfuscate_scale, obfuscate_scale). Default: 100.
56
144
 
145
+ Returns:
146
+ dict[str], obf_metadata, which is the necessary data that needs to be loaded when running the obfuscated network.
147
+
57
148
  Raises:
58
149
  TypeError: If `network` is not nn.Cell.
59
150
  TypeError: If `ckpt_files` is not string or `saved_path` is not string.
60
151
  TypeError: If `target_modules` is not list.
61
152
  TypeError: If target_modules's elements are not string.
153
+ TypeError: If obf_config is not dict.
62
154
  ValueError: If `ckpt_files` does not exist or `saved_path` does not exist.
63
155
  ValueError: If the number of elements of `target_modules` is less than ``2``.
64
156
  ValueError: If the first string of `target_modules` contains characters other than uppercase and lowercase
@@ -68,54 +160,91 @@ def obfuscate_ckpt(network, ckpt_files, target_modules=None, saved_path='./', ob
68
160
  ValueError: If the third string of `target_modules` is not in the format of 'obfuscate_layers:all' or
69
161
  'obfuscate_layers:int'.
70
162
 
71
- Returns:
72
- list[float], obf_ratios, which is the necessary data that needs to be load when running obfuscated network.
73
-
74
163
  Examples:
75
164
  >>> from mindspore import obfuscate_ckpt, save_checkpoint
76
165
  >>> # Refer to https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
77
166
  >>> net = LeNet5()
78
167
  >>> save_checkpoint(net, './test_net.ckpt')
79
168
  >>> target_modules = ['', 'fc1|fc2']
80
- >>> obfuscate_ckpt(net, target_modules, './', './')
169
+ >>> obfuscate_ckpt(net, './', target_modules=target_modules, saved_path='./')
81
170
  """
171
+ def _gen_obfuscate_tensor(tensor_shape, tensor_type='rearrange'):
172
+ obf_tensor = None
173
+ if tensor_type == 'rearrange':
174
+ if len(tensor_shape) == 1:
175
+ obf_tensor = ms.Tensor(np.random.permutation(tensor_shape[0]), dtype=ms.int32)
176
+ if len(tensor_shape) == 2:
177
+ tensor = ms.Tensor(np.identity(tensor_shape[0]), dtype=ms.int32)
178
+ p = ms.Tensor(np.random.permutation(tensor_shape[1]), dtype=ms.int32)
179
+ obf_tensor = tensor[:, p]
180
+ if tensor_type == 'random':
181
+ obf_tensor = ms.Tensor(np.random.randint(1, obfuscate_scale, size=tensor_shape), dtype=ms.float16)
182
+ return obf_tensor
183
+
184
+ def _gen_obf_metadata(config):
185
+ name = config.get('name')
186
+ if name is None:
187
+ return False
188
+ save_metadata = config.get('save_metadata', False)
189
+ metadata_op_name = config.get('metadata_op')
190
+ layers = config.get('layers')
191
+ if not layers:
192
+ if not obf_metadata.get(name):
193
+ obf_tensor = _gen_obfuscate_tensor(config.get('shape'), config.get('type'))
194
+ obf_metadata[name] = obf_tensor
195
+ if save_metadata:
196
+ saved_obf_tensor = obf_tensor
197
+ if metadata_op_name is not None:
198
+ metadata_op = _get_op(metadata_op_name)
199
+ saved_obf_tensor = metadata_op(saved_obf_tensor)
200
+ if saved_obf_tensor is not None:
201
+ saved_metadata[name] = saved_obf_tensor.asnumpy()
202
+ else:
203
+ for layer in layers:
204
+ strTemplate = Template(name)
205
+ obf_name = strTemplate.safe_substitute({"layer": str(layer)})
206
+ obf_tensor = _gen_obfuscate_tensor(config.get('shape'), config.get('type'))
207
+ obf_metadata[obf_name] = obf_tensor
208
+ if save_metadata:
209
+ saved_obf_tensor = obf_tensor
210
+ if metadata_op_name is not None:
211
+ metadata_op = _get_op(metadata_op_name)
212
+ saved_obf_tensor = metadata_op(saved_obf_tensor)
213
+ if saved_obf_tensor is not None:
214
+ saved_metadata[obf_name] = saved_obf_tensor.asnumpy()
215
+ return True
216
+
82
217
  if not isinstance(network, nn.Cell):
83
218
  raise TypeError("network must be nn.Cell, but got {}.".format(type(network)))
84
219
  _check_dir_path('ckpt_files', ckpt_files)
85
220
  _check_dir_path('saved_path', saved_path)
86
- # Try to find default target modules
87
- if target_modules is None:
88
- to_split_modules = _get_default_target_modules(ckpt_files)
89
- else:
90
- if len(target_modules) >= 1 and target_modules[0] == '/':
91
- target_modules[0] = ''
92
- to_split_modules = target_modules
93
- if not _check_valid_target(network, to_split_modules):
94
- raise ValueError("The obfuscate module path {} is not exist, please check the input 'target_modules'."
95
- .format(to_split_modules))
221
+
222
+ if obf_config is None:
223
+ if not _check_valid_target(network, target_modules):
224
+ raise ValueError("{} is not exist, please check the input 'target_modules'.".format(target_modules))
225
+ log.warning("'target_modules and obf_ratios' will be deprecated and "
226
+ "removed in a future version, use 'obf_config' instead.")
227
+ obf_config = _transform_target_modules(target_modules)
228
+ if not isinstance(obf_config, dict):
229
+ raise TypeError("obf_config type should be dict, but got {}.".format(type(obf_config)))
230
+ if not obf_config or not _check_valid_obf_config(obf_config, 'obf_metadata_config')\
231
+ or not _check_valid_obf_config(obf_config, 'weight_obf_config'):
232
+ raise ValueError("'obf_config' is empty or not valid, please check the input.")
233
+ obf_metadata = {}
234
+ obf_metadata_config = obf_config.get('obf_metadata_config', [])
235
+ saved_metadata = {}
236
+ for config in obf_metadata_config:
237
+ _gen_obf_metadata(config)
96
238
  if (not isinstance(obfuscate_scale, (float, int))) or (obfuscate_scale <= 1):
97
239
  raise ValueError("obfuscate_scale must be float or int, and larger than 1, but got {}."
98
240
  .format(obfuscate_scale))
99
- # generate and save obf_ratios to saved_path
100
- path_list = to_split_modules[0].split('/')
101
- target_list = to_split_modules[1].split('|')
102
- global OBF_RATIOS_LENGTH
103
- number_of_ratios = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH
104
- if number_of_ratios > MAX_OBF_RATIOS_NUM:
105
- OBF_RATIOS_LENGTH = MAX_OBF_RATIOS_NUM // OBF_RATIOS_WIDTH
106
- number_of_ratios = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH
107
- obf_ratios = []
108
- secrets_generator = secrets.SystemRandom()
109
- for _ in range(number_of_ratios):
110
- secure_float = secrets_generator.uniform(1 / obfuscate_scale, obfuscate_scale)
111
- obf_ratios.append(secure_float)
112
241
  # start obfuscate ckpt
113
242
  ckpt_dir_files = os.listdir(ckpt_files)
114
243
  for ckpt_name in ckpt_dir_files:
115
- sub_path = os.path.abspath(ckpt_files) + '/' + ckpt_name
244
+ sub_path = os.path.realpath(ckpt_files) + '/' + ckpt_name
116
245
  if Path(sub_path).is_dir():
117
246
  sub_ckpt_file_list = os.listdir(sub_path)
118
- new_saved_path = os.path.abspath(saved_path) + '/' + ckpt_name
247
+ new_saved_path = os.path.realpath(saved_path) + '/' + ckpt_name
119
248
  if not os.path.exists(new_saved_path):
120
249
  try:
121
250
  os.mkdir(new_saved_path, mode=0o700)
@@ -124,71 +253,148 @@ def obfuscate_ckpt(network, ckpt_files, target_modules=None, saved_path='./', ob
124
253
  for sub_ckpt_name in sub_ckpt_file_list:
125
254
  if not sub_ckpt_name.endswith('.ckpt'):
126
255
  continue
127
- _obfuscate_single_ckpt(os.path.abspath(sub_path) + '/' + sub_ckpt_name, obf_ratios, path_list,
128
- target_list, new_saved_path)
256
+ _obfuscate_single_ckpt(os.path.realpath(sub_path) + '/' + sub_ckpt_name, obf_metadata,
257
+ obf_config, new_saved_path)
129
258
  else:
130
259
  if not ckpt_name.endswith('.ckpt'):
131
260
  continue
132
- _obfuscate_single_ckpt(os.path.abspath(ckpt_files) + '/' + ckpt_name, obf_ratios, path_list,
133
- target_list, saved_path)
134
- return obf_ratios
261
+ _obfuscate_single_ckpt(os.path.realpath(ckpt_files) + '/' + ckpt_name,
262
+ obf_metadata, obf_config, saved_path)
263
+ return saved_metadata
135
264
 
136
265
 
137
- def _obfuscate_single_ckpt(ckpt_name, obf_ratios, path_list, target_list, saved_path):
266
+ def _obfuscate_single_ckpt(ckpt_name, obf_metadata, obf_config, saved_path):
138
267
  """Obfuscate single ckpt file"""
139
- module_has_been_obfuscated = set()
268
+ def _get_op_input_name(obf_op, name_key='input_x', layer=0):
269
+ op_name = obf_op.get('name')
270
+ input_name = obf_op.get(name_key)
271
+ if input_name is None:
272
+ log.error("can not find input: {} for op: {}.".format(name_key, op_name))
273
+ return None
274
+ strTemplate = Template(input_name)
275
+ input_name = strTemplate.safe_substitute({"layer": str(layer)})
276
+ return input_name
277
+
278
+ def _get_op_input(input_name, obf_param):
279
+ op_input = obf_metadata.get(input_name, None) if input_name.startswith('obf_metadata') else obf_param
280
+ return op_input
281
+
282
+ def _obfuscate_param(param, obf_metadata, obf_ops, layer=0):
283
+ param_dtype = F.dtype(param)
284
+ obf_param = param
285
+ for i in range(len(obf_ops)):
286
+ op_name = obf_ops[i].get('name')
287
+ if not isinstance(op_name, str):
288
+ raise TypeError('{} should be str type, but got {}'.format(op_name, type(op_name)))
289
+ if op_name == 'mul':
290
+ input_x = obf_param
291
+ input_y_name = _get_op_input_name(obf_ops[i], 'input_y', layer)
292
+ input_y = obf_metadata.get(input_y_name)
293
+ if input_x is None or input_y is None:
294
+ log.error("input_x or input_y is None")
295
+ return None
296
+ input_y = F.cast(input_y, param_dtype)
297
+ obf_param = ops.mul(input_x, input_y)
298
+ elif op_name == 'permuate':
299
+ input_x_name = _get_op_input_name(obf_ops[i], 'input_x', layer)
300
+ p = obf_metadata.get(input_x_name, None)
301
+ if p is None or obf_param is None:
302
+ log.error("input_x or param is None")
303
+ return None
304
+ obf_param = obf_param[p]
305
+ elif op_name == 'matmul':
306
+ input_x_name = _get_op_input_name(obf_ops[i], 'input_x', layer)
307
+ input_y_name = _get_op_input_name(obf_ops[i], 'input_y', layer)
308
+ input_x = _get_op_input(input_x_name, obf_param)
309
+ input_y = _get_op_input(input_y_name, obf_param)
310
+ if input_x is None or input_y is None:
311
+ log.error("the input_x or input_y of op: {} is None.".format(op_name))
312
+ return None
313
+ input_x = ops.transpose(input_x, (1, 0)) if obf_ops[i].get('transpose_a', False) else input_x
314
+ input_y = ops.transpose(input_y, (1, 0)) if obf_ops[i].get('transpose_b', False) else input_y
315
+ obf_param = ops.matmul(F.cast(input_x, param_dtype), F.cast(input_y, param_dtype))
316
+ else:
317
+ log.error("unsupported op, op must be matmul or permuate or mul, but got {}."
318
+ .format(op_name))
319
+ return None
320
+ return obf_param
321
+
140
322
  try:
141
323
  ckpt_param = load_checkpoint(ckpt_name)
142
324
  except (ValueError, TypeError, OSError):
143
- logger.error("Load checkpoint failed for file {}.".format(ckpt_name))
144
- return None
145
- obf_ratios_index = -1
325
+ log.error("Load checkpoint failed for file {}.".format(ckpt_name))
326
+ return False
327
+
328
+ weight_obf_config = obf_config.get('weight_obf_config', [])
146
329
  for item in ckpt_param:
147
- module = _get_valid_module(item, path_list, target_list)
148
- if module:
149
- layer_index = _judge_layer_index(item)
150
- if layer_index >= OBF_RATIOS_LENGTH:
151
- continue
152
- if module not in module_has_been_obfuscated:
153
- module_has_been_obfuscated.add(module)
154
- obf_ratios_index += 1
155
- ratio_total_index = layer_index * OBF_RATIOS_WIDTH + obf_ratios_index % OBF_RATIOS_WIDTH
156
- ckpt_param[item].set_data(ckpt_param[item].value() / obf_ratios[ratio_total_index])
330
+ item_split = item.split('.')
331
+ param_path = '/'.join(item_split[:len(item_split)])
332
+ for obf_target in weight_obf_config:
333
+ if not isinstance(obf_target, dict):
334
+ raise TypeError('{} should be dict type, but got {}'.format(obf_target, type(obf_target)))
335
+ target = obf_target.get('target', None)
336
+ layers = obf_target.get('layers', [])
337
+ obf_ops = obf_target.get('weight_obf_ops', None)
338
+ if not target or not obf_ops:
339
+ raise KeyError("target or obf_ops is None.")
340
+ if not layers:
341
+ if target == param_path:
342
+ obf_param = _obfuscate_param(ckpt_param[item].value(), obf_metadata, obf_ops)
343
+ if obf_param is None:
344
+ log.error("obfuscate weight {} failed.".format(item))
345
+ return False
346
+ ckpt_param[item].set_data(obf_param)
347
+ for layer in layers:
348
+ strTemplate = Template(target)
349
+ target_path = strTemplate.safe_substitute({"layer": str(layer)})
350
+ if target_path == param_path:
351
+ obf_param = _obfuscate_param(ckpt_param[item].value(), obf_metadata, obf_ops, layer)
352
+ if obf_param is None:
353
+ log.error("obfuscate weight {} failed.".format(item))
354
+ return False
355
+ ckpt_param[item].set_data(obf_param)
356
+
157
357
  # save the obfuscated model to saved_path
158
358
  obf_param_list = []
159
359
  for item in ckpt_param:
160
360
  obf_param_list.append({'name': item, 'data': ckpt_param[item]})
161
361
  ckpt_file_name = ckpt_name.split('/')[-1]
162
362
  obf_ckpt_file_name = ckpt_file_name.split('.')[0] + '_obf' + '.ckpt'
163
- save_checkpoint(obf_param_list, os.path.abspath(saved_path) + '/' + obf_ckpt_file_name)
164
- return None
363
+ save_checkpoint(obf_param_list, os.path.realpath(saved_path) + '/' + obf_ckpt_file_name)
364
+ return True
165
365
 
166
366
 
167
- def load_obf_params_into_net(network, target_modules, obf_ratios, data_parallel_num=1, **kwargs):
367
+ def load_obf_params_into_net(network, target_modules=None, obf_ratios=None, obf_config=None,
368
+ data_parallel_num=1, **kwargs):
168
369
  """
169
- load obfuscate ratios into obfuscated network. Usually used in conjunction with :func:`mindspore.obfuscate_ckpt`
170
- interface.
370
+ Modify model structure according to obfuscation config and load obfuscated checkpoint into obfuscated network.
171
371
 
172
372
  Args:
173
373
  network (nn.Cell): The original network that need to be obfuscated.
174
- target_modules (list[str]): The target module of network that need to be obfuscated. The first string
175
- represents the network path of target module in original network, which should be in form of ``'A/B/C'``.
176
- The second string represents the obfuscation target module, which should be in form of ``'D|E|F'``. For
177
- example, thr target_modules of GPT2 can be ``['backbone/blocks/attention', 'dense1|dense2|dense3']``.
178
- If target_modules has the third value, it should be in the format of 'obfuscate_layers:all' or
179
- 'obfuscate_layers:int', which represents the number of layers need to be obfuscated of duplicate layers
180
- (such as transformer layers or resnet blocks).
374
+ target_modules (list[str]): The target ops that need to be obfuscated in the network. The first string
375
+ represents the network path of the target ops in the original network, which should be in form of
376
+ ``"A/B/C"``. The second string represents the names of multiple target ops in the same path, which
377
+ should be in form of ``"D|E|F"``. For example, thr target_modules of GPT2 can be ``['backbone
378
+ /blocks/attention', 'dense1|dense2|dense3']``. If target_modules has the third value, it should be
379
+ in the format of 'obfuscate_layers:all' or 'obfuscate_layers:int', which represents the number of
380
+ layers need to be obfuscated of duplicate layers (such as transformer layers or resnet blocks).
381
+ Default: ``None``.
382
+ obf_ratios (Tensor): The obf ratios generated when execute :func:`mindspore.obfuscate_ckpt`. Default: ``None``.
383
+ obf_config (dict): The configuration of model obfuscation polices. Default: ``None``.
181
384
  data_parallel_num (int): The data parallel number of parallel training. Default: 1.
182
- obf_ratios (Tensor): The obf ratios generated when execute :func:`mindspore.obfuscate_ckpt`.
183
385
  kwargs (dict): Configuration options dictionary.
184
386
 
185
387
  - ignored_func_decorators (list[str]): The name list of function decorators in network's python code.
186
388
  - ignored_class_decorators (list[str]): The name list of class decorators in network's python code.
187
389
 
390
+ Returns:
391
+ nn.Cell, new_net, which is the obfuscated network.
392
+
188
393
  Raises:
189
394
  TypeError: If `network` is not nn.Cell.
190
395
  TypeError: If `obf_ratios` is not Tensor.
191
396
  TypeError: If `target_modules` is not list.
397
+ TypeError: If `obf_config` is not dict.
192
398
  TypeError: If target_modules's elements are not string.
193
399
  ValueError: If the number of elements of `target_modules` is less than ``2``.
194
400
  ValueError: If `obf_ratios` is empty Tensor.
@@ -209,40 +415,33 @@ def load_obf_params_into_net(network, target_modules, obf_ratios, data_parallel_
209
415
  >>> save_checkpoint(net, './test_net.ckpt')
210
416
  >>> target_modules = ['', 'fc1|fc2']
211
417
  >>> # obfuscate ckpt files
212
- >>> obfuscate_ckpt(net, target_modules, './', './')
418
+ >>> obfuscate_ckpt(net, './', target_modules=target_modules, saved_path='./')
213
419
  >>> # load obf ckpt into network
214
420
  >>> new_net = LeNet5()
215
421
  >>> load_checkpoint('./test_net_obf.ckpt', new_net)
216
- >>> obf_ratios = Tensor(np.load('./obf_ratios.npy'), mstype.float16)
217
- >>> obf_net = load_obf_params_into_net(new_net, target_modules, obf_ratios)
422
+ >>> obf_net = load_obf_params_into_net(new_net, target_modules)
218
423
  """
219
424
  if not isinstance(network, nn.Cell):
220
425
  raise TypeError("network must be nn.Cell, but got {}.".format(type(network)))
221
- if not isinstance(obf_ratios, Tensor):
222
- raise TypeError("obf_ratios must be MindSpore Tensor, but got {}.".format(type(obf_ratios)))
223
- if obf_ratios.size == 0:
224
- raise ValueError("obf_ratios can not be empty.")
225
- if not _check_valid_target(network, target_modules):
226
- raise ValueError("{} is not exist, please check the input 'target_modules'.".format(target_modules))
426
+ if obf_config is None:
427
+ if not _check_valid_target(network, target_modules):
428
+ raise ValueError("{} is not exist, please check the input 'target_modules'.".format(target_modules))
429
+ log.warning("'target_modules and obf_ratios' will be deprecated and "
430
+ "removed in a future version, use 'obf_config' instead.")
431
+ obf_config = _transform_target_modules(target_modules)
432
+
433
+ if not isinstance(obf_config, dict):
434
+ raise TypeError('{} should be dict type, but got {}'.format(obf_config, type(obf_config)))
435
+
436
+ if not obf_config or not _check_valid_obf_config(obf_config, 'network_obf_config'):
437
+ raise ValueError("'obf_config' is empty or not valid, please check the input.")
438
+
227
439
  if (not isinstance(data_parallel_num, int)) or (data_parallel_num <= 0):
228
440
  raise ValueError("data_parallel_num must be positive number, but got {}.".format(data_parallel_num))
229
- if len(target_modules) >= 1 and target_modules[0] == '/':
230
- target_modules[0] = ''
231
- path_list = target_modules[0].split('/')
232
- path_len = len(path_list)
233
- target_list = []
234
- for _ in range(path_len):
235
- target_list.append([])
236
- target_list.append(target_modules[1].split('|'))
237
- global MAX_OBF_RATIOS_NUM, OBF_RATIOS_LENGTH
238
- number_of_ratios = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH
239
- if number_of_ratios > MAX_OBF_RATIOS_NUM:
240
- OBF_RATIOS_LENGTH = MAX_OBF_RATIOS_NUM // OBF_RATIOS_WIDTH
241
- number_of_ratios = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH
242
- MAX_OBF_RATIOS_NUM = number_of_ratios
243
- rewrite_network = _obfuscate_network(network, path_list, target_list, data_parallel_num=data_parallel_num, **kwargs)
244
- setattr(rewrite_network, 'obf_ratios', obf_ratios)
245
- return rewrite_network
441
+
442
+ network_obf_config = obf_config.get('network_obf_config', [])
443
+ new_net = _obfuscate_network(network, network_obf_config, data_parallel_num=data_parallel_num, **kwargs)
444
+ return new_net
246
445
 
247
446
 
248
447
  def _check_dir_path(name, dir_path):
@@ -255,15 +454,6 @@ def _check_dir_path(name, dir_path):
255
454
  raise TypeError("{} must be a directory path, but got {}.".format(name, dir_path))
256
455
 
257
456
 
258
- def _judge_layer_index(layer_name):
259
- """Judge the layer index of target layers"""
260
- split_name = layer_name.split('.')
261
- for split_str in split_name[:]:
262
- if split_str.isdigit():
263
- return int(split_str)
264
- return 0
265
-
266
-
267
457
  def _check_valid_target(network, target_modules):
268
458
  """check whether the input 'target_modules' exists"""
269
459
  if not isinstance(target_modules, list):
@@ -314,7 +504,7 @@ def _check_valid_target(network, target_modules):
314
504
  OBF_RATIOS_WIDTH = 0
315
505
  for target in target_list:
316
506
  if not hasattr(net, target):
317
- logger.warning("{} does not exist in the path {}".format(target, target_modules[0]))
507
+ log.warning("{} does not exist in the path {}".format(target, target_modules[0]))
318
508
  else:
319
509
  OBF_RATIOS_WIDTH += 1
320
510
  if OBF_RATIOS_WIDTH == 0:
@@ -323,6 +513,118 @@ def _check_valid_target(network, target_modules):
323
513
  return True
324
514
 
325
515
 
516
+ def _check_ops_info(ops_info):
517
+ """check ops info config"""
518
+ for op in ops_info:
519
+ op_name = op.get('name')
520
+ if not isinstance(op_name, str):
521
+ raise TypeError("op_name type should be str, but got {}.".format(type(op_name)))
522
+ input_x_name = op.get('input_x')
523
+ if not isinstance(input_x_name, str):
524
+ raise TypeError("input_x_name type should be str, but got {}.".format(type(input_x_name)))
525
+ input_y_name = op.get('input_y')
526
+ if not isinstance(input_y_name, str):
527
+ raise TypeError("input_y_name type should be str, but got {}.".format(type(input_y_name)))
528
+ if not isinstance(op.get('transpose_a', False), bool):
529
+ raise TypeError("transpose_a type should be bool, but got {}.".format(type(op.get('transpose_a'))))
530
+ if not isinstance(op.get('transpose_b', False), bool):
531
+ raise TypeError("transpose_b type should be bool, but got {}.".format(type(op.get('transpose_b'))))
532
+
533
+
534
+ def _check_new_input_info(insert_new_input):
535
+ """check new input config"""
536
+ if not isinstance(insert_new_input, list):
537
+ raise TypeError("obf_config[][]['insert_new_input'] type should be list, but got {}."
538
+ .format(type(insert_new_input)))
539
+ for new_input in insert_new_input:
540
+ input_name = new_input.get('name')
541
+ if not isinstance(input_name, str):
542
+ raise TypeError("obf_config[][]['insert_new_input'][]['name'] type should be str, but got {}."
543
+ .format(type(input_name)))
544
+
545
+
546
+ def _check_obf_metadata_config(config):
547
+ """check obf metadata config"""
548
+ name = config.get('name')
549
+ if not name or not isinstance(name, str):
550
+ raise TypeError("obf_config[][]['name'] type should be str, but got {}.".format(type(name)))
551
+ shape = config.get('shape')
552
+ if not shape or not isinstance(shape, list):
553
+ raise TypeError("obf_config[][]['shape'] type should be list, but got {}.".format(type(shape)))
554
+ for item in shape:
555
+ if not isinstance(item, int):
556
+ raise TypeError("shape[] type should be int, but got {}.".format(type(item)))
557
+ save_metadata = config.get('save_metadata', True)
558
+ if not isinstance(save_metadata, bool):
559
+ raise TypeError("obf_config[][]['save_metadata'] type should be bool, but got {}."
560
+ .format(type(save_metadata)))
561
+ metadata_type = config.get('type')
562
+ if metadata_type is not None:
563
+ if not isinstance(metadata_type, str) or metadata_type not in _supported_metadata_type:
564
+ raise TypeError("obf_config[][]['type'] should be str and must in {}, but got {}."
565
+ .format(str(_supported_metadata_type), type(metadata_type)))
566
+
567
+
568
+ def _check_weight_obf_config(config):
569
+ """check weight obfuscation config"""
570
+ target = config.get('target')
571
+ if not target or not isinstance(target, str):
572
+ raise TypeError("obf_config[][]['target'] type should be str, but got {}.".format(type(target)))
573
+ weight_obf_ops = config.get('weight_obf_ops', [])
574
+ if not isinstance(weight_obf_ops, list):
575
+ raise TypeError("obf_config[][]['weight_obf_ops'] type should be list, but got {}."
576
+ .format(type(weight_obf_ops)))
577
+ _check_ops_info(weight_obf_ops)
578
+
579
+
580
+ def _check_network_obf_config(config):
581
+ """check network obfuscation config"""
582
+ target = config.get('target')
583
+ if not target or not isinstance(target, str):
584
+ raise TypeError("obf_config[][]['target'] type should be str, but got {}.".format(type(target)))
585
+ module = config.get('module')
586
+ if not module or not isinstance(module, str):
587
+ raise TypeError("obf_config[][]['module'] type should be str, but got {}.".format(type(module)))
588
+ insert_new_input = config.get('insert_new_input', [])
589
+ _check_new_input_info(insert_new_input)
590
+ insert_ops = config.get('insert_ops', [])
591
+ if not isinstance(insert_ops, list):
592
+ raise TypeError("obf_config[][]['insert_ops'] type should be list, but got {}.".format(type(insert_ops)))
593
+ _check_ops_info(insert_ops)
594
+
595
+
596
+ def _check_valid_obf_config(obf_config, config_type):
597
+ """check obfuscation config"""
598
+ if not isinstance(config_type, str) or config_type not in _supported_config_type:
599
+ raise TypeError("config_type must be str, and in {}, but got {}."
600
+ .format(str(_supported_config_type), config_type))
601
+ for config_type_item in obf_config.keys():
602
+ if not isinstance(config_type_item, str) or config_type_item not in _supported_config_type:
603
+ raise TypeError("config_type must be str, and in {}, but got {}."
604
+ .format(str(_supported_config_type), config_type_item))
605
+ config_list = obf_config.get(config_type)
606
+ if not isinstance(config_list, list):
607
+ raise TypeError("obf_config[] type of should be list, but got {}.".format(type(config_list)))
608
+
609
+ for config in config_list:
610
+ if not isinstance(config, dict):
611
+ raise TypeError("obf_config[][] type should be dict, but got {}.".format(type(config)))
612
+ if config_type == 'obf_metadata_config':
613
+ _check_obf_metadata_config(config)
614
+ elif config_type == 'weight_obf_config':
615
+ _check_weight_obf_config(config)
616
+ elif config_type == 'network_obf_config':
617
+ _check_network_obf_config(config)
618
+ layers = config.get('layers')
619
+ if layers is not None:
620
+ if not isinstance(layers, list):
621
+ raise TypeError("obf_config[][]['layers'] type should be list, but got {}.".format(type(layers)))
622
+ for layer in layers:
623
+ if not isinstance(layer, int):
624
+ raise TypeError("obf_config[][]['layers'][] type should be int, but got {}.".format(type(layer)))
625
+ return True
626
+
627
+
326
628
  def _update_max_obf_ratios_num(target_modules):
327
629
  """Update MAX_OBF_RATIOS_NUM"""
328
630
  if len(target_modules) >= 3:
@@ -341,160 +643,163 @@ def _update_max_obf_ratios_num(target_modules):
341
643
  MAX_OBF_RATIOS_NUM = int(obfuscate_layers[1]) * OBF_RATIOS_WIDTH
342
644
 
343
645
 
344
- def _get_default_target_modules(ckpt_files):
345
- """Get the default or suggested target modules, if the target modules is None."""
346
-
347
- def _split_to_path_and_target(module, target):
348
- # split module into path list and target list
349
- target_index = module.index(target)
350
- path = module[:target_index - 1]
351
- target = module[target_index:].split('/')[0]
352
- return path, target
353
-
354
- def _find_default_obfuscate_modules(net_path):
355
- # find modules including the default paths
356
- default_module = {'attention'}
357
- for module in default_module:
358
- if module in net_path and module not in candidate_modules:
359
- candidate_modules.append(net_path)
360
- # find the default targets in the default module
361
- default_target = {'dense', 'query', 'key', 'value'}
362
- for target in default_target:
363
- for candidate in candidate_modules:
364
- if target in candidate:
365
- path, target = _split_to_path_and_target(candidate, target)
366
- if path not in paths:
367
- paths.append(path)
368
- if target not in targets:
369
- targets.append(target)
370
-
371
- def _find_suggested_obfuscate_modules(net_path):
372
- default_target = {'dense', 'query', 'key', 'value'}
373
- for target in default_target:
374
- # find the suggest modules
375
- if target in net_path:
376
- path, target = _split_to_path_and_target(net_path, target)
377
- if [path, target] not in suggest_modules:
378
- suggest_modules.append([path, target])
379
-
380
- # store the potential candidate_modules
381
- candidate_modules = []
382
- suggest_modules = []
383
- paths = []
384
- targets = []
385
- ckpt_dir_files = os.listdir(ckpt_files)
386
- for ckpt_name in ckpt_dir_files:
387
- if not ckpt_name.endswith('.ckpt'):
388
- continue
389
- try:
390
- ckpt_param = load_checkpoint(os.path.abspath(ckpt_files) + '/' + ckpt_name)
391
- except (ValueError, TypeError, OSError):
392
- logger.error("Load checkpoint failed for file {}.".format(os.path.abspath(ckpt_files) + '/' + ckpt_name))
393
- return None
394
- for item in ckpt_param:
395
- param_path = _remove_digit(item)
396
- param_path = '/'.join(param_path)
397
- # find candidate modules including the default paths and append candidate_modules
398
- _find_default_obfuscate_modules(param_path)
399
- # give the suggested modules and find the default targets in the default module
400
- _find_suggested_obfuscate_modules(param_path)
401
- if paths and targets:
402
- target_modules = [paths[0], '|'.join(targets)]
403
- logger.warning("The default obfuscate modules is obtained:{}".format(target_modules))
404
- return target_modules
405
- # logging the suggested target module
406
- logger.warning("The default obfuscate modules can not be obtained. The suggested possible paths are given below: {}"
407
- .format(suggest_modules))
408
- raise ValueError("Can not get the default path, please specify the path in the form of ['A/B/C', 'D1|D2']")
409
-
410
-
411
- def _get_valid_module(item, path_list, target_list):
412
- """get the valid module"""
413
- number_path = len(path_list)
414
- net_path = _remove_digit(item)
415
- net_path = '/'.join(net_path[:number_path])
416
- tar_path = '/'.join(path_list)
417
- # update the weights with obf_ratios in target module
418
- if net_path == tar_path:
419
- for target in target_list:
420
- if target in item.split('.'):
421
- target_index = item.split('.').index(target)
422
- module = ''.join(item.split('.')[:target_index + 1])
423
- return module
424
- return None
425
-
426
-
427
646
  def _remove_digit(item):
428
647
  """remove digit in the parameter path"""
429
- param_path = item.split('.')
430
- for tmp_str in param_path[:]:
648
+ item_split = item.split('_')
649
+ for tmp_str in item_split[:]:
431
650
  if tmp_str.isdigit():
432
- param_path.remove(tmp_str)
433
- return param_path
651
+ item_split.remove(tmp_str)
652
+ return '_'.join(item_split)
434
653
 
435
654
 
436
- def _obfuscate_network(model, path_list, target_list, data_parallel_num=1, **kwargs):
437
- """obfuscate original network, including add mul operation and add inputs for passing obf_ratio."""
655
+ def _remove_scope(item):
656
+ """remove scope of name values"""
657
+ item_split = item.split('.')
658
+ for tmp_str in item_split[:]:
659
+ if tmp_str == 'self':
660
+ item_split.remove(tmp_str)
661
+ return '.'.join(item_split)
438
662
 
439
- def _insert_input(stree: SymbolTree, arg_name: str = 'y_obf'):
663
+
664
+ def _obfuscate_network(model, obf_config=None, data_parallel_num=1, **kwargs):
665
+ """obfuscate original network, including add deobfuscation ops and add inputs for passing obf_metadata."""
666
+
667
+ def _insert_input(stree: SymbolTree, arg_name: str = 'obf_metadata'):
440
668
  """add inputs for passing obf_ratio"""
441
669
  last_input = None
442
670
  for node in stree.nodes():
443
671
  if node.get_node_type() == NodeType.Input:
444
672
  last_input = node
445
673
  position = stree.after(last_input)
446
- # the insert input node name would be 'input_y_obf'
674
+ # the insert input node name would be 'input_obf_metadata'
447
675
  new_input_node = last_input.create_input(arg_name)
448
676
  stree.insert(position, new_input_node)
449
677
 
450
- def _insert_mul(stree: SymbolTree, node: Node, index: int):
451
- """add mul operation for original network"""
452
- arg_list = node.get_targets().copy()
453
- input_y_node = stree.get_node("input_y_obf")
454
- v: str = input_y_node.get_targets()[0].value
455
- sv: ScopedValue = ScopedValue.create_naming_value(v + f'[{index}]')
456
- arg_list.append(sv)
457
- target_list = node.get_targets().copy()
458
- if data_parallel_num > 1:
459
- logger.info("Data parallel number is: {}".format(data_parallel_num))
460
- new_mul_node = node.create_call_cell(cell=ops.Mul().shard(((data_parallel_num, 1), ())),
461
- targets=target_list, args=arg_list, name='mul')
678
+ def _update_subnet(substree: SymbolTree, subnode: Node):
679
+ """update the network once the subnet is obfuscated"""
680
+ input_y_node = substree.get_node("input_obf_metadata")
681
+ if input_y_node is None:
682
+ log.error("can not find input node: obf_metadata for net: {}.".format(subnode.get_name()))
683
+ return False
684
+ if hasattr(subnode, 'get_handler'):
685
+ subnode.get_handler().append_kwarg({"obf_metadata": input_y_node.get_targets()[0]})
462
686
  else:
463
- new_mul_node = node.create_call_cell(cell=ops.Mul(), targets=target_list, args=arg_list, name='mul')
464
- position = stree.after(node)
465
- stree.insert(position, new_mul_node)
687
+ subnode.append_kwarg({"obf_metadata": input_y_node.get_targets()[0]})
688
+ return True
466
689
 
467
- def _insert_mul_by_name(stree: SymbolTree, after_name_list: list):
690
+ def _insert_ops(stree: SymbolTree, node: Node, insert_ops: list):
691
+ """add mul operation for original network"""
692
+ current_node = node
693
+ for insert_op in insert_ops:
694
+ arg_list = current_node.get_targets().copy()
695
+ obf_metadata = stree.get_node("input_obf_metadata")
696
+ if obf_metadata is None:
697
+ raise ValueError("can not find input node: obf_metadata for net: {}.".format(current_node.get_name()))
698
+ v: str = obf_metadata.get_targets()[0].value
699
+ index = insert_op['input_y']
700
+ sv: ScopedValue = ScopedValue.create_naming_value(v + f'["{index}"]')
701
+ arg_list.append(sv)
702
+ target_list = current_node.get_targets().copy()
703
+ name = insert_op['name']
704
+ if data_parallel_num > 1:
705
+ new_node = current_node.create_call_cell(cell=_get_op(name).shard(((data_parallel_num, 1), ())),
706
+ targets=target_list, args=arg_list, name=name)
707
+ else:
708
+ new_node = current_node.create_call_cell(cell=_get_op(name), targets=target_list, args=arg_list,
709
+ name=name)
710
+ position = stree.after(current_node)
711
+ stree.insert(position, new_node)
712
+ current_node = new_node
713
+
714
+ def _insert_ops_by_name(stree: SymbolTree, after_name_list: list, module: str):
468
715
  """add mul operation after the target nodes according the name of them"""
469
716
  if not after_name_list:
470
717
  return
471
718
  for node in stree.nodes():
472
719
  for after_name in after_name_list:
473
720
  if node.get_name() == after_name:
474
- global OBF_RATIOS_INSERT_INDEX
475
- if OBF_RATIOS_INSERT_INDEX < MAX_OBF_RATIOS_NUM:
476
- _insert_mul(stree, node, OBF_RATIOS_INSERT_INDEX)
477
- OBF_RATIOS_INSERT_INDEX += 1
478
-
479
- def _update_subnet(substree: SymbolTree, subnode: Node):
480
- """update the network once the subnet is obfuscated"""
481
- input_y_node = substree.get_node("input_y_obf")
482
- if input_y_node is None:
483
- return
484
- subnode.get_handler().append_kwarg({"y_obf": input_y_node.get_targets()[0]})
485
-
486
- def _traverse(stree, i=0):
487
- """traverse and obfuscate the original network"""
488
- if len(path_list) == i:
489
- return
721
+ insert_ops = insert_ops_map[module+'/'+after_name]
722
+ _insert_ops(stree, node, insert_ops)
723
+
724
+ def _process_controlflow_node(node: Node, stree: SymbolTree, full_path: str, path: str, targets: dict):
725
+ ctrl = node.get_handler() if hasattr(node, 'get_handler') else node
726
+ cell_loop_name = ''
727
+ find_cell_loop = False
728
+ if hasattr(ctrl, "loop_vars") and ctrl.loop_vars:
729
+ cell_loop_name = ctrl.loop_vars[0]
730
+ inputs = ctrl.get_inputs()
731
+ for input in inputs:
732
+ if input.get_node_type() == NodeType.CellContainer:
733
+ find_cell_loop = True
734
+ full_node_name = input.get_name()
735
+ node_name = _remove_digit(_remove_scope(full_node_name))
736
+ if not _process_cellcontainer_node(input, full_path+'/'+full_node_name,
737
+ path+'/'+node_name, targets):
738
+ log.error("_process_cellcontainer_node for node: {} failed.".format(node_name))
739
+ return False
740
+ for c_node in ctrl.nodes():
741
+ c_node_name = c_node.get_name()
742
+ c_node_type = c_node.get_node_type()
743
+ if c_node.get_node_type() == NodeType.ControlFlow:
744
+ if not _process_controlflow_node(c_node, stree, full_path+'/'+c_node_name, path, targets):
745
+ return False
746
+ elif c_node.get_node_type() == NodeType.Tree and _is_target_module(path + '/' + c_node_name, targets):
747
+ sub_stree = SymbolTree(c_node.symbol_tree)
748
+ _insert_input(sub_stree, arg_name='obf_metadata')
749
+ _insert_ops_by_name(sub_stree, after_name_list=targets.get(path + '/' + c_node_name, None),
750
+ module=path + '/' + c_node_name)
751
+ if not _traverse(sub_stree, full_path+'/'+c_node_name, path+'/'+c_node_name, targets):
752
+ log.error("_traverse for node: {} failed.".format(c_node_name))
753
+ return False
754
+ if not _update_subnet(sub_stree, c_node):
755
+ log.error("_update_subnet for node: {} failed.".format(c_node_name))
756
+ return False
757
+ elif find_cell_loop and c_node_type == NodeType.CallFunction and c_node_name.startswith(cell_loop_name):
758
+ input_y_node = stree.get_node("input_obf_metadata")
759
+ if input_y_node is None:
760
+ log.error("input_y_node for node: {} is None.".format(c_node_name))
761
+ return False
762
+ c_node.append_kwarg({"obf_metadata": input_y_node.get_targets()[0]})
763
+ return True
764
+
765
+ def _process_cellcontainer_node(node: Node, full_path: str, path: str, targets: dict):
766
+ cellcontainer = node.get_handler() if hasattr(node, 'get_handler') else node
767
+ for i in range(len(cellcontainer.nodes())):
768
+ cell_node = cellcontainer.nodes()[i]
769
+ # insert input for each sub_stree in cell_container
770
+ if _is_target_module(path, targets) and cell_node.get_node_type() == NodeType.Tree:
771
+ sub_stree = SymbolTree(cell_node.symbol_tree)
772
+ _insert_input(sub_stree, arg_name='obf_metadata')
773
+ _insert_ops_by_name(sub_stree, after_name_list=targets.get(path, None), module=path)
774
+ if not _traverse(sub_stree, full_path + '/' + str(i), path + '/' + str(i), targets):
775
+ return False
776
+ return True
777
+
778
+ def _is_target_module(path, targets):
779
+ for target_module in targets.keys():
780
+ if target_module.startswith(path):
781
+ return True
782
+ return False
783
+
784
+ def _traverse(stree: SymbolTree, full_path: str, path: str, targets: dict):
490
785
  for node in stree.nodes():
491
786
  node_name = node.get_name()
492
- if node.get_node_type() == NodeType.Tree and node_name.startswith(path_list[i]):
787
+ if node.get_node_type() == NodeType.ControlFlow:
788
+ if not _process_controlflow_node(node, stree, full_path + '/' + node_name, path, targets):
789
+ log.error("process controlflow node: {} failed.".format(node.get_name()))
790
+ return False
791
+ elif node.get_node_type() == NodeType.Tree and _is_target_module(path + '/' + node_name, targets):
493
792
  sub_stree = node.get_sub_tree()
494
- _traverse(sub_stree, i + 1)
495
- _insert_input(sub_stree, arg_name='y_obf')
496
- _insert_mul_by_name(sub_stree, after_name_list=target_list[i + 1])
497
- _update_subnet(sub_stree, node)
793
+ _insert_input(sub_stree, arg_name='obf_metadata')
794
+ _insert_ops_by_name(sub_stree, after_name_list=targets.get(path + '/' + node_name, None),
795
+ module=path + '/' + node_name)
796
+ if not _traverse(sub_stree, full_path + '/' + node_name, path + '/' + node_name, targets):
797
+ log.error("traverse sub_stree for node: {} failed.".format(node.get_name()))
798
+ return False
799
+ if not _update_subnet(sub_stree, node):
800
+ log.error("update subnet for node: {} failed.".format(node.get_name()))
801
+ return False
802
+ return True
498
803
 
499
804
  def _register_denied_func_decorators(fn):
500
805
  """set the function decorators which should be denied for parse"""
@@ -523,9 +828,48 @@ def _obfuscate_network(model, path_list, target_list, data_parallel_num=1, **kwa
523
828
  if kw_class_dec and not isinstance(kw_class_dec[0], str):
524
829
  raise TypeError('elements of {} should be str, but got {}'.format(kw_class_dec, type(kw_class_dec[0])))
525
830
 
831
+ targets = {}
832
+ insert_ops_map = {}
833
+ for obf_item in obf_config:
834
+ module = obf_item.get('module', None)
835
+ target = obf_item.get('target', None)
836
+ insert_ops_info = obf_item.get('insert_ops', None)
837
+ layers = obf_item.get('layers', [])
838
+ if not layers:
839
+ real_insert_ops_info = []
840
+ if not targets.get(module, None):
841
+ targets[module] = []
842
+ if target not in targets[module]:
843
+ targets[module].append(target)
844
+ target_path = module + '/' + target
845
+ for op_info in insert_ops_info:
846
+ real_op_info = op_info.copy()
847
+ real_insert_ops_info.append(real_op_info)
848
+ insert_ops_map[target_path] = real_insert_ops_info
849
+ for layer in layers:
850
+ real_insert_ops_info = []
851
+ strTemplate = Template(module)
852
+ real_module = strTemplate.safe_substitute({"layer": str(layer)})
853
+ if not targets.get(real_module, None):
854
+ targets[real_module] = []
855
+ if target not in targets[real_module]:
856
+ targets[real_module].append(target)
857
+ target_path = real_module + '/' + target
858
+ for op_info in insert_ops_info:
859
+ real_op_info = op_info.copy()
860
+ strTemplate = Template(real_op_info['input_x'])
861
+ real_op_info['input_x'] = strTemplate.safe_substitute({"layer": str(layer)})
862
+ strTemplate = Template(real_op_info['input_y'])
863
+ real_op_info['input_y'] = strTemplate.safe_substitute({"layer": str(layer)})
864
+ real_insert_ops_info.append(real_op_info)
865
+ insert_ops_map[target_path] = real_insert_ops_info
866
+
867
+ root_path = ""
526
868
  main_stree = SymbolTree.create(model)
527
- _traverse(main_stree, 0)
528
- _insert_input(main_stree, arg_name='y_obf')
529
- _insert_mul_by_name(main_stree, after_name_list=target_list[0])
869
+ _insert_input(main_stree, arg_name='obf_metadata')
870
+ _insert_ops_by_name(main_stree, after_name_list=targets.get(root_path, None), module=root_path)
871
+ if not _traverse(main_stree, full_path=root_path, path=root_path, targets=targets):
872
+ log.error("_traverse for root_path: {} failed.".format(root_path))
873
+ return None
530
874
  new_net = main_stree.get_network()
531
875
  return new_net