easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,340 @@
1
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Unit tests for the CSV input readers (CSVInput and CSVInputEx)."""

import os
import unittest

import tensorflow as tf
from google.protobuf import text_format

from easy_rec.python.input.csv_input import CSVInput
from easy_rec.python.input.csv_input_ex import CSVInputEx
from easy_rec.python.protos.dataset_pb2 import DatasetConfig
from easy_rec.python.protos.feature_config_pb2 import FeatureConfig
from easy_rec.python.utils import config_util
from easy_rec.python.utils import constant
from easy_rec.python.utils.test_utils import RunAsSubprocess

if tf.__version__ >= '2.0':
  from tensorflow.python.framework.ops import disable_eager_execution

  # The tests below build graph-mode pipelines (initializable iterators run
  # inside sessions), so eager execution is turned off and the TF1-compat
  # API is bound to ``tf`` for the rest of the module.
  disable_eager_execution()
  tf = tf.compat.v1
24
+
25
+
26
class CSVInputTest(tf.test.TestCase):
  """Checks that CSVInput / CSVInputEx build datasets and yield a batch."""

  def __init__(self, methodName='CSVInputTest'):
    super(CSVInputTest, self).__init__(methodName=methodName)
    self._input_path = 'data/test/test.csv'
    self._input_path_with_quote = 'data/test/test_with_quote.csv'

  @staticmethod
  def _parse_dataset_config(data_config_str):
    """Parse a text-format DatasetConfig string into a DatasetConfig proto."""
    dataset_config = DatasetConfig()
    text_format.Merge(data_config_str, dataset_config)
    return dataset_config

  @staticmethod
  def _build_feature_configs(feature_config_str):
    """Parse a text-format FeatureConfig and expand its ``shared_names``.

    Returns:
      A list whose first element is the parsed config (unchanged). For every
      name produced by ``config_util.auto_expand_names`` on each
      ``shared_names`` entry, one extra copy of the config is appended, with
      ``input_names`` set to just that name and ``shared_names`` cleared.
    """
    feature_config = FeatureConfig()
    text_format.Merge(feature_config_str, feature_config)

    # Template = the parsed config with both name lists emptied.
    template = FeatureConfig()
    template.CopyFrom(feature_config)
    template.ClearField('input_names')
    template.ClearField('shared_names')

    feature_configs = [feature_config]
    for shared_name in feature_config.shared_names:
      for expanded_name in config_util.auto_expand_names(shared_name):
        tmp_config = FeatureConfig()
        tmp_config.CopyFrom(template)
        tmp_config.input_names.append(expanded_name)
        feature_configs.append(tmp_config)
    return feature_configs

  def _fetch_one_batch(self, dataset):
    """Initialize an iterator over ``dataset`` and run one ``get_next``.

    Returns:
      The evaluated ``[features, labels]`` pair of the first batch.
    """
    iterator = dataset.make_initializable_iterator()
    # The iterator initializer is registered with the table initializers so
    # a single ``sess.run`` on that collection initializes everything.
    tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
    features, labels = iterator.get_next()
    init_op = tf.get_collection(tf.GraphKeys.TABLE_INITIALIZERS)
    gpu_options = tf.GPUOptions(allow_growth=True)
    session_config = tf.ConfigProto(
        gpu_options=gpu_options,
        allow_soft_placement=True,
        log_device_placement=False)
    with self.test_session(config=session_config) as sess:
      sess.run(init_op)
      return sess.run([features, labels])

  @RunAsSubprocess
  def test_csv_data(self):
    data_config_str = """
      input_fields {
        input_name: 'label'
        input_type: FLOAT
      }
      input_fields {
        input_name: 'field[1-3]'
        input_type: STRING
      }
      label_fields: 'label'
      batch_size: 1024
      num_epochs: 10000
      prefetch_size: 32
      auto_expand_input_fields: true
    """
    feature_config_str = """
      input_names: 'field1'
      shared_names: 'field[2-3]'
      feature_type: IdFeature
      embedding_dim: 32
      hash_bucket_size: 2000
    """
    train_input_fn = CSVInput(
        self._parse_dataset_config(data_config_str),
        self._build_feature_configs(feature_config_str),
        self._input_path).create_input()
    dataset = train_input_fn(mode=tf.estimator.ModeKeys.TRAIN)
    self._fetch_one_batch(dataset)

  @RunAsSubprocess
  def test_csv_data_flt_to_str_exception(self):
    data_config_str = """
      input_fields {
        input_name: 'label'
        input_type: FLOAT
      }
      input_fields {
        input_name: 'field1'
        input_type: STRING
      }
      input_fields {
        input_name: 'field[2-3]'
        input_type: FLOAT
      }
      label_fields: 'label'
      batch_size: 1024
      num_epochs: 10000
      prefetch_size: 32
      auto_expand_input_fields: true
    """
    feature_config_str = """
      input_names: 'field1'
      shared_names: 'field[2-3]'
      feature_type: IdFeature
      embedding_dim: 32
      hash_bucket_size: 2000
    """
    train_input_fn = CSVInput(
        self._parse_dataset_config(data_config_str),
        self._build_feature_configs(feature_config_str),
        self._input_path).create_input()
    # FLOAT inputs feeding an IdFeature without ``precision`` must be
    # rejected while the dataset is being built. assertRaises (unlike a bare
    # ``assert``) is not stripped when Python runs with -O.
    with self.assertRaises(
        Exception,
        msg='if precision is not set, exception should be reported in convert float to string'
    ):
      train_input_fn(mode=tf.estimator.ModeKeys.TRAIN)

  @RunAsSubprocess
  def test_csv_data_flt_to_str(self):
    data_config_str = """
      input_fields {
        input_name: 'label'
        input_type: FLOAT
      }
      input_fields {
        input_name: 'field1'
        input_type: STRING
      }
      input_fields {
        input_name: 'field[2-3]'
        input_type: FLOAT
      }
      label_fields: 'label'
      batch_size: 1024
      num_epochs: 10000
      prefetch_size: 32
      auto_expand_input_fields: true
    """
    feature_config_str = """
      input_names: 'field1'
      shared_names: 'field[2-3]'
      feature_type: IdFeature
      embedding_dim: 32
      hash_bucket_size: 2000
      precision: 3
    """
    train_input_fn = CSVInput(
        self._parse_dataset_config(data_config_str),
        self._build_feature_configs(feature_config_str),
        self._input_path).create_input()
    dataset = train_input_fn(mode=tf.estimator.ModeKeys.TRAIN)
    self._fetch_one_batch(dataset)

  @RunAsSubprocess
  def test_csv_input_ex(self):
    data_config_str = """
      input_fields {
        input_name: 'label'
        input_type: FLOAT
      }
      input_fields {
        input_name: 'field[1-3]'
        input_type: STRING
      }
      label_fields: 'label'
      batch_size: 1024
      num_epochs: 10000
      prefetch_size: 32
      auto_expand_input_fields: true
    """
    feature_config_str = """
      input_names: 'field1'
      shared_names: 'field[2-3]'
      feature_type: IdFeature
      embedding_dim: 32
      hash_bucket_size: 2000
    """
    train_input_fn = CSVInputEx(
        self._parse_dataset_config(data_config_str),
        self._build_feature_configs(feature_config_str),
        self._input_path_with_quote).create_input()
    dataset = train_input_fn(mode=tf.estimator.ModeKeys.TRAIN)
    self._fetch_one_batch(dataset)

  @unittest.skipIf('AVX_TEST' not in os.environ,
                   'Only execute when avx512 instructions are supported')
  @RunAsSubprocess
  def test_csv_input_ex_avx(self):
    constant.enable_avx_str_split()
    try:
      self.test_csv_input_ex()
    finally:
      # Restore the default splitter even when the wrapped test fails.
      constant.disable_avx_str_split()

  @RunAsSubprocess
  def test_csv_data_ignore_error(self):
    data_config_str = """
      input_fields {
        input_name: 'label'
        input_type: FLOAT
      }
      input_fields {
        input_name: 'field[1-3]'
        input_type: STRING
      }
      label_fields: 'label'
      batch_size: 32
      num_epochs: 10000
      prefetch_size: 32
      auto_expand_input_fields: true
      ignore_error: true
    """
    feature_config_str = """
      input_names: 'field1'
      shared_names: 'field[2-3]'
      feature_type: IdFeature
      embedding_dim: 32
      hash_bucket_size: 2000
    """
    train_input_fn = CSVInput(
        self._parse_dataset_config(data_config_str),
        self._build_feature_configs(feature_config_str),
        self._input_path_with_quote).create_input()
    dataset = train_input_fn(mode=tf.estimator.ModeKeys.TRAIN)
    self._fetch_one_batch(dataset)


if __name__ == '__main__':
  tf.test.main()
@@ -0,0 +1,19 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import json
5
+ import logging
6
+
7
+
8
def custom_early_stop_func(eval_results, func_param):
  """Report whether any evaluated step pushed a metric past a threshold.

  Args:
    eval_results: mapping from step to a mapping of metric name -> value.
    func_param: JSON string with keys ``thre`` (the threshold) and
      ``metric_name`` (which metric to compare against it).

  Returns:
    True for the first step (in ``eval_results`` iteration order) whose
    metric is strictly greater than the threshold, after logging it;
    False when no step exceeds it.
  """
  config = json.loads(func_param)
  threshold = config['thre']
  target_metric = config['metric_name']
  for eval_step, step_metrics in eval_results.items():
    metric_value = step_metrics[target_metric]
    if not metric_value > threshold:
      continue
    logging.info(
        'At step %s, metric "%s" has value %s, which is larger than %s, causing early stop.'
        % (eval_step, target_metric, metric_value, threshold))
    return True
  return False
@@ -0,0 +1,104 @@
1
+ import argparse
2
+ import logging
3
+ import os
4
+ import shutil
5
+ import sys
6
+
7
+ import tensorflow as tf
8
+
9
+ from easy_rec.python.test.odps_command import OdpsCommand
10
+ from easy_rec.python.test.odps_test_prepare import prepare
11
+ from easy_rec.python.test.odps_test_util import OdpsOSSConfig
12
+ from easy_rec.python.test.odps_test_util import delete_oss_path
13
+ from easy_rec.python.test.odps_test_util import get_oss_bucket
14
+ from easy_rec.python.utils import test_utils
15
+
16
logging.basicConfig(
    format='[%(asctime)s][%(levelname)s] %(message)s', level=logging.INFO)

# Module-wide ODPS/OSS/DataHub configuration shared by the tests below; the
# __main__ block overrides its fields from command-line flags.
odps_oss_config = OdpsOSSConfig(script_path='./samples/dh_script')
20
+
21
+
22
class TestPipelineOnEmr(tf.test.TestCase):
  """Train/eval test whose input is read from DataHub.

  NOTE(review): despite the "Emr" in the class name, the single test drives
  test_utils.test_datahub_train_eval; the name is kept so test IDs stay
  stable.
  """

  def setUp(self):
    logging.info('Testing %s.%s' % (type(self).__name__, self._testMethodName))
    self._success = True
    self._test_dir = test_utils.get_tmp_dir()
    logging.info('test datahub local dir: %s' % self._test_dir)

  def tearDown(self):
    # Remove the local directory only when the test succeeded; after a
    # failure it is left in place.
    if self._success:
      shutil.rmtree(self._test_dir)

  def test_datahub_train_eval(self):
    end = ['deep_fm/drop_table.sql']
    odps_cmd = OdpsCommand(odps_oss_config)

    # Assume failure until the helper returns a status, so that an exception
    # raised by it leaves the local directory in place (see tearDown).
    self._success = False
    try:
      self._success = test_utils.test_datahub_train_eval(
          '%s/configs/deepfm.config' % odps_oss_config.temp_dir,
          odps_oss_config, self._test_dir)
    finally:
      # Drop the test tables even if training/evaluation raised.
      odps_cmd.run_list(end)
    self.assertTrue(self._success)
44
+
45
+
46
if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--odps_config', type=str, default=None, help='odps config path')
  parser.add_argument(
      '--oss_config', type=str, default=None, help='ossutilconfig path')
  parser.add_argument(
      '--bucket_name', type=str, default=None, help='test oss bucket name')
  parser.add_argument('--arn', type=str, default=None, help='oss rolearn')
  parser.add_argument(
      '--odpscmd', type=str, default='odpscmd', help='odpscmd path')
  parser.add_argument(
      '--algo_project', type=str, default=None, help='algo project name')
  parser.add_argument(
      '--algo_res_project',
      type=str,
      default=None,
      help='algo resource project name')
  parser.add_argument(
      '--algo_version', type=str, default=None, help='algo version')
  args, unknown_args = parser.parse_known_args()

  # Forward only the flags not consumed above to tf.test.main().
  sys.argv = [sys.argv[0]] + list(unknown_args)

  if args.odps_config:
    odps_oss_config.load_odps_config(args.odps_config)
    os.environ['ODPS_CONFIG_FILE_PATH'] = args.odps_config
  if args.oss_config:
    odps_oss_config.load_oss_config(args.oss_config)
  if args.odpscmd:
    odps_oss_config.odpscmd_path = args.odpscmd
  if args.algo_project:
    odps_oss_config.algo_project = args.algo_project
  if args.algo_res_project:
    odps_oss_config.algo_res_project = args.algo_res_project
  if args.algo_version:
    odps_oss_config.algo_version = args.algo_version
  if args.arn:
    odps_oss_config.arn = args.arn
  if args.bucket_name:
    odps_oss_config.bucket_name = args.bucket_name
  prepare(odps_oss_config)
  start = [
      'deep_fm/create_external_deepfm_table.sql',
      'deep_fm/create_inner_deepfm_table.sql'
  ]
  odps_cmd = OdpsCommand(odps_oss_config)
  odps_cmd.run_list(start)
  odps_oss_config.init_dh_and_odps()
  try:
    # tf.test.main() terminates the process via sys.exit() (absl app.run),
    # so the cleanup below must sit in a finally clause to run at all.
    tf.test.main()
  finally:
    # delete oss path
    bucket = get_oss_bucket(odps_oss_config.oss_key,
                            odps_oss_config.oss_secret,
                            odps_oss_config.endpoint,
                            odps_oss_config.bucket_name)
    delete_oss_path(bucket, odps_oss_config.exp_dir,
                    odps_oss_config.bucket_name)
    # delete tmp
    shutil.rmtree(odps_oss_config.temp_dir)
@@ -0,0 +1,155 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import logging
5
+
6
+ import numpy as np
7
+ import tensorflow as tf
8
+ from google.protobuf import text_format
9
+
10
+ from easy_rec.python.compat.feature_column import feature_column
11
+ from easy_rec.python.feature_column.feature_column import FeatureColumnParser
12
+ from easy_rec.python.input.dummy_input import DummyInput
13
+ from easy_rec.python.protos.dataset_pb2 import DatasetConfig
14
+ from easy_rec.python.protos.feature_config_pb2 import FeatureConfig
15
+ from easy_rec.python.protos.feature_config_pb2 import WideOrDeep
16
+
17
# Use the TF1-style API under TF2. Compare the numeric major version rather
# than the raw version string: lexicographic ordering would treat e.g.
# '10.0' as older than '2.0'.
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
19
+
20
+
21
class EmbedTest(tf.test.TestCase):
  """Checks embeddings produced by FeatureColumnParser for a single field.

  Both tests seed the table through a constant initializer with
  consts [1..10]; per the original test notes, with embedding_dim 2 the
  embedding variable is:
    [[1, 2],
     [3, 4],
     [5, 6],
     [7, 8],
     [9, 10]]
  """

  @staticmethod
  def _make_data_config():
    """Return the DatasetConfig shared by both tests (label 'clk', input 'field1')."""
    data_config_str = '''
      input_fields {
        input_name: 'clk'
        input_type: INT32
        default_val: '0'
      }
      input_fields {
        input_name: 'field1'
        input_type: STRING
        default_val: '0'
      }
      label_fields: 'clk'
      batch_size: 1
    '''
    data_config = DatasetConfig()
    text_format.Merge(data_config_str, data_config)
    return data_config

  def test_raw_embed(self):
    # A 5-dim raw feature with 'sum' combiner. The asserted values
    # [9.5, 11.0] equal 0.1*[1,2] + 0.2*[3,4] + 0.3*[5,6] + 0.4*[7,8]
    # + 0.5*[9,10].
    feature_config_str = '''
      input_names: 'field1'
      feature_type: RawFeature
      initializer {
        constant_initializer {
          consts: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        }
      }
      separator: ',',
      raw_input_dim: 5
      embedding_dim: 2
      combiner: 'sum'
    '''
    feature_config = FeatureConfig()
    text_format.Merge(feature_config_str, feature_config)
    data_config = self._make_data_config()

    feature_configs = [feature_config]
    features = {'field1': tf.constant(['0.1,0.2,0.3,0.4,0.5'])}
    dummy_input = DummyInput(
        data_config, feature_configs, '', input_vals=features)
    field_dict, _ = dummy_input._build(tf.estimator.ModeKeys.TRAIN, {})

    wide_and_deep_dict = {'field1': WideOrDeep.WIDE_AND_DEEP}
    fc_parser = FeatureColumnParser(feature_configs, wide_and_deep_dict, 2)
    wide_cols = list(fc_parser._wide_columns.values())
    wide_features = feature_column.input_layer(field_dict, wide_cols)
    deep_cols = list(fc_parser._deep_columns.values())
    deep_features = feature_column.input_layer(field_dict, deep_cols)
    # global_variables_initializer replaces the deprecated
    # initialize_all_variables alias.
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
      sess.run(init)
      fea_val = sess.run(wide_features)
      logging.info('wide fea_val = %s', fea_val[0])
      # assertAllClose (unlike bare assert) is not stripped under -O.
      self.assertAllClose(fea_val[0][:2], [9.5, 11.0], rtol=0, atol=1e-6)
      fea_val = sess.run(deep_features)
      logging.info('deep fea_val = %s', fea_val[0])
      self.assertAllClose(fea_val[0][:2], [9.5, 11.0], rtol=0, atol=1e-6)

  def test_seq_multi_embed(self):
    # Sequence feature with 'mean' combiner; only the first two sequence
    # positions of the first example are checked.
    feature_config_str = '''
      input_names: 'field1'
      feature_type: SequenceFeature
      initializer {
        constant_initializer {
          consts: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        }
      }
      separator: '',
      seq_multi_sep: '',
      embedding_dim: 2
      num_buckets: 5
      combiner: 'mean'
    '''
    feature_config = FeatureConfig()
    text_format.Merge(feature_config_str, feature_config)
    data_config = self._make_data_config()

    feature_configs = [feature_config]
    features = {'field1': tf.constant(['0112', '132430'])}
    dummy_input = DummyInput(
        data_config, feature_configs, '', input_vals=features)
    field_dict, _ = dummy_input._build(tf.estimator.ModeKeys.TRAIN, {})

    wide_and_deep_dict = {'field1': WideOrDeep.DEEP}
    fc_parser = FeatureColumnParser(feature_configs, wide_and_deep_dict)
    builder = feature_column._LazyBuilder(field_dict)
    hist_embedding, hist_seq_len = \
        fc_parser.sequence_columns['field1']._get_sequence_dense_tensor(builder)

    init = tf.global_variables_initializer()
    with tf.Session() as sess:
      sess.run(init)
      fea_val, len_val = sess.run([hist_embedding, hist_seq_len])
      logging.info('length_val = %s', len_val)
      logging.info('deep fea_val = %s', fea_val)
      self.assertAllClose(
          fea_val[0][:2], [[2.0, 3.0], [4.0, 5.0]], rtol=0, atol=1e-6)
152
+
153
+
154
if __name__ == '__main__':
  # Hand control to TensorFlow's unittest-based test runner.
  tf.test.main()