scitex-2.0.0-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (572)
  1. scitex/__init__.py +73 -0
  2. scitex/__main__.py +89 -0
  3. scitex/__version__.py +14 -0
  4. scitex/_sh.py +59 -0
  5. scitex/ai/_LearningCurveLogger.py +583 -0
  6. scitex/ai/__Classifiers.py +101 -0
  7. scitex/ai/__init__.py +55 -0
  8. scitex/ai/_gen_ai/_Anthropic.py +173 -0
  9. scitex/ai/_gen_ai/_BaseGenAI.py +336 -0
  10. scitex/ai/_gen_ai/_DeepSeek.py +175 -0
  11. scitex/ai/_gen_ai/_Google.py +161 -0
  12. scitex/ai/_gen_ai/_Groq.py +97 -0
  13. scitex/ai/_gen_ai/_Llama.py +142 -0
  14. scitex/ai/_gen_ai/_OpenAI.py +230 -0
  15. scitex/ai/_gen_ai/_PARAMS.py +565 -0
  16. scitex/ai/_gen_ai/_Perplexity.py +191 -0
  17. scitex/ai/_gen_ai/__init__.py +32 -0
  18. scitex/ai/_gen_ai/_calc_cost.py +78 -0
  19. scitex/ai/_gen_ai/_format_output_func.py +183 -0
  20. scitex/ai/_gen_ai/_genai_factory.py +71 -0
  21. scitex/ai/act/__init__.py +8 -0
  22. scitex/ai/act/_define.py +11 -0
  23. scitex/ai/classification/__init__.py +7 -0
  24. scitex/ai/classification/classification_reporter.py +1137 -0
  25. scitex/ai/classification/classifier_server.py +131 -0
  26. scitex/ai/classification/classifiers.py +101 -0
  27. scitex/ai/classification_reporter.py +1161 -0
  28. scitex/ai/classifier_server.py +131 -0
  29. scitex/ai/clustering/__init__.py +11 -0
  30. scitex/ai/clustering/_pca.py +115 -0
  31. scitex/ai/clustering/_umap.py +376 -0
  32. scitex/ai/early_stopping.py +149 -0
  33. scitex/ai/feature_extraction/__init__.py +56 -0
  34. scitex/ai/feature_extraction/vit.py +148 -0
  35. scitex/ai/genai/__init__.py +277 -0
  36. scitex/ai/genai/anthropic.py +177 -0
  37. scitex/ai/genai/anthropic_provider.py +320 -0
  38. scitex/ai/genai/anthropic_refactored.py +109 -0
  39. scitex/ai/genai/auth_manager.py +200 -0
  40. scitex/ai/genai/base_genai.py +336 -0
  41. scitex/ai/genai/base_provider.py +291 -0
  42. scitex/ai/genai/calc_cost.py +78 -0
  43. scitex/ai/genai/chat_history.py +307 -0
  44. scitex/ai/genai/cost_tracker.py +276 -0
  45. scitex/ai/genai/deepseek.py +188 -0
  46. scitex/ai/genai/deepseek_provider.py +251 -0
  47. scitex/ai/genai/format_output_func.py +183 -0
  48. scitex/ai/genai/genai_factory.py +71 -0
  49. scitex/ai/genai/google.py +169 -0
  50. scitex/ai/genai/google_provider.py +228 -0
  51. scitex/ai/genai/groq.py +104 -0
  52. scitex/ai/genai/groq_provider.py +248 -0
  53. scitex/ai/genai/image_processor.py +250 -0
  54. scitex/ai/genai/llama.py +155 -0
  55. scitex/ai/genai/llama_provider.py +214 -0
  56. scitex/ai/genai/mock_provider.py +127 -0
  57. scitex/ai/genai/model_registry.py +304 -0
  58. scitex/ai/genai/openai.py +230 -0
  59. scitex/ai/genai/openai_provider.py +293 -0
  60. scitex/ai/genai/params.py +565 -0
  61. scitex/ai/genai/perplexity.py +202 -0
  62. scitex/ai/genai/perplexity_provider.py +205 -0
  63. scitex/ai/genai/provider_base.py +302 -0
  64. scitex/ai/genai/provider_factory.py +370 -0
  65. scitex/ai/genai/response_handler.py +235 -0
  66. scitex/ai/layer/_Pass.py +21 -0
  67. scitex/ai/layer/__init__.py +10 -0
  68. scitex/ai/layer/_switch.py +8 -0
  69. scitex/ai/loss/_L1L2Losses.py +34 -0
  70. scitex/ai/loss/__init__.py +12 -0
  71. scitex/ai/loss/multi_task_loss.py +47 -0
  72. scitex/ai/metrics/__init__.py +9 -0
  73. scitex/ai/metrics/_bACC.py +51 -0
  74. scitex/ai/metrics/silhoute_score_block.py +496 -0
  75. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/__init__.py +0 -0
  76. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/__init__.py +3 -0
  77. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger.py +207 -0
  78. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger2020.py +238 -0
  79. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/ranger913A.py +215 -0
  80. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/rangerqh.py +184 -0
  81. scitex/ai/optim/Ranger_Deep_Learning_Optimizer/setup.py +24 -0
  82. scitex/ai/optim/__init__.py +13 -0
  83. scitex/ai/optim/_get_set.py +31 -0
  84. scitex/ai/optim/_optimizers.py +71 -0
  85. scitex/ai/plt/__init__.py +21 -0
  86. scitex/ai/plt/_conf_mat.py +592 -0
  87. scitex/ai/plt/_learning_curve.py +194 -0
  88. scitex/ai/plt/_optuna_study.py +111 -0
  89. scitex/ai/plt/aucs/__init__.py +2 -0
  90. scitex/ai/plt/aucs/example.py +60 -0
  91. scitex/ai/plt/aucs/pre_rec_auc.py +223 -0
  92. scitex/ai/plt/aucs/roc_auc.py +246 -0
  93. scitex/ai/sampling/undersample.py +29 -0
  94. scitex/ai/sk/__init__.py +11 -0
  95. scitex/ai/sk/_clf.py +58 -0
  96. scitex/ai/sk/_to_sktime.py +100 -0
  97. scitex/ai/sklearn/__init__.py +26 -0
  98. scitex/ai/sklearn/clf.py +58 -0
  99. scitex/ai/sklearn/to_sktime.py +100 -0
  100. scitex/ai/training/__init__.py +7 -0
  101. scitex/ai/training/early_stopping.py +150 -0
  102. scitex/ai/training/learning_curve_logger.py +555 -0
  103. scitex/ai/utils/__init__.py +22 -0
  104. scitex/ai/utils/_check_params.py +50 -0
  105. scitex/ai/utils/_default_dataset.py +46 -0
  106. scitex/ai/utils/_format_samples_for_sktime.py +26 -0
  107. scitex/ai/utils/_label_encoder.py +134 -0
  108. scitex/ai/utils/_merge_labels.py +22 -0
  109. scitex/ai/utils/_sliding_window_data_augmentation.py +11 -0
  110. scitex/ai/utils/_under_sample.py +51 -0
  111. scitex/ai/utils/_verify_n_gpus.py +16 -0
  112. scitex/ai/utils/grid_search.py +148 -0
  113. scitex/context/__init__.py +9 -0
  114. scitex/context/_suppress_output.py +38 -0
  115. scitex/db/_BaseMixins/_BaseBackupMixin.py +30 -0
  116. scitex/db/_BaseMixins/_BaseBatchMixin.py +31 -0
  117. scitex/db/_BaseMixins/_BaseBlobMixin.py +81 -0
  118. scitex/db/_BaseMixins/_BaseConnectionMixin.py +43 -0
  119. scitex/db/_BaseMixins/_BaseImportExportMixin.py +39 -0
  120. scitex/db/_BaseMixins/_BaseIndexMixin.py +29 -0
  121. scitex/db/_BaseMixins/_BaseMaintenanceMixin.py +33 -0
  122. scitex/db/_BaseMixins/_BaseQueryMixin.py +52 -0
  123. scitex/db/_BaseMixins/_BaseRowMixin.py +32 -0
  124. scitex/db/_BaseMixins/_BaseSchemaMixin.py +44 -0
  125. scitex/db/_BaseMixins/_BaseTableMixin.py +66 -0
  126. scitex/db/_BaseMixins/_BaseTransactionMixin.py +52 -0
  127. scitex/db/_BaseMixins/__init__.py +30 -0
  128. scitex/db/_PostgreSQL.py +126 -0
  129. scitex/db/_PostgreSQLMixins/_BackupMixin.py +166 -0
  130. scitex/db/_PostgreSQLMixins/_BatchMixin.py +82 -0
  131. scitex/db/_PostgreSQLMixins/_BlobMixin.py +231 -0
  132. scitex/db/_PostgreSQLMixins/_ConnectionMixin.py +92 -0
  133. scitex/db/_PostgreSQLMixins/_ImportExportMixin.py +59 -0
  134. scitex/db/_PostgreSQLMixins/_IndexMixin.py +64 -0
  135. scitex/db/_PostgreSQLMixins/_MaintenanceMixin.py +175 -0
  136. scitex/db/_PostgreSQLMixins/_QueryMixin.py +108 -0
  137. scitex/db/_PostgreSQLMixins/_RowMixin.py +75 -0
  138. scitex/db/_PostgreSQLMixins/_SchemaMixin.py +126 -0
  139. scitex/db/_PostgreSQLMixins/_TableMixin.py +176 -0
  140. scitex/db/_PostgreSQLMixins/_TransactionMixin.py +57 -0
  141. scitex/db/_PostgreSQLMixins/__init__.py +34 -0
  142. scitex/db/_SQLite3.py +2136 -0
  143. scitex/db/_SQLite3Mixins/_BatchMixin.py +243 -0
  144. scitex/db/_SQLite3Mixins/_BlobMixin.py +229 -0
  145. scitex/db/_SQLite3Mixins/_ConnectionMixin.py +108 -0
  146. scitex/db/_SQLite3Mixins/_ImportExportMixin.py +80 -0
  147. scitex/db/_SQLite3Mixins/_IndexMixin.py +32 -0
  148. scitex/db/_SQLite3Mixins/_MaintenanceMixin.py +176 -0
  149. scitex/db/_SQLite3Mixins/_QueryMixin.py +83 -0
  150. scitex/db/_SQLite3Mixins/_RowMixin.py +75 -0
  151. scitex/db/_SQLite3Mixins/_TableMixin.py +183 -0
  152. scitex/db/_SQLite3Mixins/_TransactionMixin.py +71 -0
  153. scitex/db/_SQLite3Mixins/__init__.py +30 -0
  154. scitex/db/__init__.py +14 -0
  155. scitex/db/_delete_duplicates.py +397 -0
  156. scitex/db/_inspect.py +163 -0
  157. scitex/decorators/__init__.py +54 -0
  158. scitex/decorators/_auto_order.py +172 -0
  159. scitex/decorators/_batch_fn.py +127 -0
  160. scitex/decorators/_cache_disk.py +32 -0
  161. scitex/decorators/_cache_mem.py +12 -0
  162. scitex/decorators/_combined.py +98 -0
  163. scitex/decorators/_converters.py +282 -0
  164. scitex/decorators/_deprecated.py +26 -0
  165. scitex/decorators/_not_implemented.py +30 -0
  166. scitex/decorators/_numpy_fn.py +86 -0
  167. scitex/decorators/_pandas_fn.py +121 -0
  168. scitex/decorators/_preserve_doc.py +19 -0
  169. scitex/decorators/_signal_fn.py +95 -0
  170. scitex/decorators/_timeout.py +55 -0
  171. scitex/decorators/_torch_fn.py +136 -0
  172. scitex/decorators/_wrap.py +39 -0
  173. scitex/decorators/_xarray_fn.py +88 -0
  174. scitex/dev/__init__.py +15 -0
  175. scitex/dev/_analyze_code_flow.py +284 -0
  176. scitex/dev/_reload.py +59 -0
  177. scitex/dict/_DotDict.py +442 -0
  178. scitex/dict/__init__.py +18 -0
  179. scitex/dict/_listed_dict.py +42 -0
  180. scitex/dict/_pop_keys.py +36 -0
  181. scitex/dict/_replace.py +13 -0
  182. scitex/dict/_safe_merge.py +62 -0
  183. scitex/dict/_to_str.py +32 -0
  184. scitex/dsp/__init__.py +72 -0
  185. scitex/dsp/_crop.py +122 -0
  186. scitex/dsp/_demo_sig.py +331 -0
  187. scitex/dsp/_detect_ripples.py +212 -0
  188. scitex/dsp/_ensure_3d.py +18 -0
  189. scitex/dsp/_hilbert.py +78 -0
  190. scitex/dsp/_listen.py +702 -0
  191. scitex/dsp/_misc.py +30 -0
  192. scitex/dsp/_mne.py +32 -0
  193. scitex/dsp/_modulation_index.py +79 -0
  194. scitex/dsp/_pac.py +319 -0
  195. scitex/dsp/_psd.py +102 -0
  196. scitex/dsp/_resample.py +65 -0
  197. scitex/dsp/_time.py +36 -0
  198. scitex/dsp/_transform.py +68 -0
  199. scitex/dsp/_wavelet.py +212 -0
  200. scitex/dsp/add_noise.py +111 -0
  201. scitex/dsp/example.py +253 -0
  202. scitex/dsp/filt.py +155 -0
  203. scitex/dsp/norm.py +18 -0
  204. scitex/dsp/params.py +51 -0
  205. scitex/dsp/reference.py +43 -0
  206. scitex/dsp/template.py +25 -0
  207. scitex/dsp/utils/__init__.py +15 -0
  208. scitex/dsp/utils/_differential_bandpass_filters.py +120 -0
  209. scitex/dsp/utils/_ensure_3d.py +18 -0
  210. scitex/dsp/utils/_ensure_even_len.py +10 -0
  211. scitex/dsp/utils/_zero_pad.py +48 -0
  212. scitex/dsp/utils/filter.py +408 -0
  213. scitex/dsp/utils/pac.py +177 -0
  214. scitex/dt/__init__.py +8 -0
  215. scitex/dt/_linspace.py +130 -0
  216. scitex/etc/__init__.py +15 -0
  217. scitex/etc/wait_key.py +34 -0
  218. scitex/gen/_DimHandler.py +196 -0
  219. scitex/gen/_TimeStamper.py +244 -0
  220. scitex/gen/__init__.py +95 -0
  221. scitex/gen/_alternate_kwarg.py +13 -0
  222. scitex/gen/_cache.py +11 -0
  223. scitex/gen/_check_host.py +34 -0
  224. scitex/gen/_ci.py +12 -0
  225. scitex/gen/_close.py +222 -0
  226. scitex/gen/_embed.py +78 -0
  227. scitex/gen/_inspect_module.py +257 -0
  228. scitex/gen/_is_ipython.py +12 -0
  229. scitex/gen/_less.py +48 -0
  230. scitex/gen/_list_packages.py +139 -0
  231. scitex/gen/_mat2py.py +88 -0
  232. scitex/gen/_norm.py +170 -0
  233. scitex/gen/_paste.py +18 -0
  234. scitex/gen/_print_config.py +84 -0
  235. scitex/gen/_shell.py +48 -0
  236. scitex/gen/_src.py +111 -0
  237. scitex/gen/_start.py +451 -0
  238. scitex/gen/_symlink.py +55 -0
  239. scitex/gen/_symlog.py +27 -0
  240. scitex/gen/_tee.py +238 -0
  241. scitex/gen/_title2path.py +60 -0
  242. scitex/gen/_title_case.py +88 -0
  243. scitex/gen/_to_even.py +84 -0
  244. scitex/gen/_to_odd.py +34 -0
  245. scitex/gen/_to_rank.py +39 -0
  246. scitex/gen/_transpose.py +37 -0
  247. scitex/gen/_type.py +78 -0
  248. scitex/gen/_var_info.py +73 -0
  249. scitex/gen/_wrap.py +17 -0
  250. scitex/gen/_xml2dict.py +76 -0
  251. scitex/gen/misc.py +730 -0
  252. scitex/gen/path.py +0 -0
  253. scitex/general/__init__.py +5 -0
  254. scitex/gists/_SigMacro_processFigure_S.py +128 -0
  255. scitex/gists/_SigMacro_toBlue.py +172 -0
  256. scitex/gists/__init__.py +12 -0
  257. scitex/io/_H5Explorer.py +292 -0
  258. scitex/io/__init__.py +82 -0
  259. scitex/io/_cache.py +101 -0
  260. scitex/io/_flush.py +24 -0
  261. scitex/io/_glob.py +103 -0
  262. scitex/io/_json2md.py +113 -0
  263. scitex/io/_load.py +168 -0
  264. scitex/io/_load_configs.py +146 -0
  265. scitex/io/_load_modules/__init__.py +38 -0
  266. scitex/io/_load_modules/_catboost.py +66 -0
  267. scitex/io/_load_modules/_con.py +20 -0
  268. scitex/io/_load_modules/_db.py +24 -0
  269. scitex/io/_load_modules/_docx.py +42 -0
  270. scitex/io/_load_modules/_eeg.py +110 -0
  271. scitex/io/_load_modules/_hdf5.py +196 -0
  272. scitex/io/_load_modules/_image.py +19 -0
  273. scitex/io/_load_modules/_joblib.py +19 -0
  274. scitex/io/_load_modules/_json.py +18 -0
  275. scitex/io/_load_modules/_markdown.py +103 -0
  276. scitex/io/_load_modules/_matlab.py +37 -0
  277. scitex/io/_load_modules/_numpy.py +39 -0
  278. scitex/io/_load_modules/_optuna.py +155 -0
  279. scitex/io/_load_modules/_pandas.py +69 -0
  280. scitex/io/_load_modules/_pdf.py +31 -0
  281. scitex/io/_load_modules/_pickle.py +24 -0
  282. scitex/io/_load_modules/_torch.py +16 -0
  283. scitex/io/_load_modules/_txt.py +126 -0
  284. scitex/io/_load_modules/_xml.py +49 -0
  285. scitex/io/_load_modules/_yaml.py +23 -0
  286. scitex/io/_mv_to_tmp.py +19 -0
  287. scitex/io/_path.py +286 -0
  288. scitex/io/_reload.py +78 -0
  289. scitex/io/_save.py +539 -0
  290. scitex/io/_save_modules/__init__.py +66 -0
  291. scitex/io/_save_modules/_catboost.py +22 -0
  292. scitex/io/_save_modules/_csv.py +89 -0
  293. scitex/io/_save_modules/_excel.py +49 -0
  294. scitex/io/_save_modules/_hdf5.py +249 -0
  295. scitex/io/_save_modules/_html.py +48 -0
  296. scitex/io/_save_modules/_image.py +140 -0
  297. scitex/io/_save_modules/_joblib.py +25 -0
  298. scitex/io/_save_modules/_json.py +25 -0
  299. scitex/io/_save_modules/_listed_dfs_as_csv.py +57 -0
  300. scitex/io/_save_modules/_listed_scalars_as_csv.py +42 -0
  301. scitex/io/_save_modules/_matlab.py +24 -0
  302. scitex/io/_save_modules/_mp4.py +29 -0
  303. scitex/io/_save_modules/_numpy.py +57 -0
  304. scitex/io/_save_modules/_optuna_study_as_csv_and_pngs.py +38 -0
  305. scitex/io/_save_modules/_pickle.py +45 -0
  306. scitex/io/_save_modules/_plotly.py +27 -0
  307. scitex/io/_save_modules/_text.py +23 -0
  308. scitex/io/_save_modules/_torch.py +26 -0
  309. scitex/io/_save_modules/_yaml.py +29 -0
  310. scitex/life/__init__.py +10 -0
  311. scitex/life/_monitor_rain.py +49 -0
  312. scitex/linalg/__init__.py +17 -0
  313. scitex/linalg/_distance.py +63 -0
  314. scitex/linalg/_geometric_median.py +64 -0
  315. scitex/linalg/_misc.py +73 -0
  316. scitex/nn/_AxiswiseDropout.py +27 -0
  317. scitex/nn/_BNet.py +126 -0
  318. scitex/nn/_BNet_Res.py +164 -0
  319. scitex/nn/_ChannelGainChanger.py +44 -0
  320. scitex/nn/_DropoutChannels.py +50 -0
  321. scitex/nn/_Filters.py +489 -0
  322. scitex/nn/_FreqGainChanger.py +110 -0
  323. scitex/nn/_GaussianFilter.py +48 -0
  324. scitex/nn/_Hilbert.py +111 -0
  325. scitex/nn/_MNet_1000.py +157 -0
  326. scitex/nn/_ModulationIndex.py +221 -0
  327. scitex/nn/_PAC.py +414 -0
  328. scitex/nn/_PSD.py +40 -0
  329. scitex/nn/_ResNet1D.py +120 -0
  330. scitex/nn/_SpatialAttention.py +25 -0
  331. scitex/nn/_Spectrogram.py +161 -0
  332. scitex/nn/_SwapChannels.py +50 -0
  333. scitex/nn/_TransposeLayer.py +19 -0
  334. scitex/nn/_Wavelet.py +183 -0
  335. scitex/nn/__init__.py +63 -0
  336. scitex/os/__init__.py +8 -0
  337. scitex/os/_mv.py +50 -0
  338. scitex/parallel/__init__.py +8 -0
  339. scitex/parallel/_run.py +151 -0
  340. scitex/path/__init__.py +33 -0
  341. scitex/path/_clean.py +52 -0
  342. scitex/path/_find.py +108 -0
  343. scitex/path/_get_module_path.py +51 -0
  344. scitex/path/_get_spath.py +35 -0
  345. scitex/path/_getsize.py +18 -0
  346. scitex/path/_increment_version.py +87 -0
  347. scitex/path/_mk_spath.py +51 -0
  348. scitex/path/_path.py +19 -0
  349. scitex/path/_split.py +23 -0
  350. scitex/path/_this_path.py +19 -0
  351. scitex/path/_version.py +101 -0
  352. scitex/pd/__init__.py +41 -0
  353. scitex/pd/_find_indi.py +126 -0
  354. scitex/pd/_find_pval.py +113 -0
  355. scitex/pd/_force_df.py +154 -0
  356. scitex/pd/_from_xyz.py +71 -0
  357. scitex/pd/_ignore_SettingWithCopyWarning.py +34 -0
  358. scitex/pd/_melt_cols.py +81 -0
  359. scitex/pd/_merge_columns.py +221 -0
  360. scitex/pd/_mv.py +63 -0
  361. scitex/pd/_replace.py +62 -0
  362. scitex/pd/_round.py +93 -0
  363. scitex/pd/_slice.py +63 -0
  364. scitex/pd/_sort.py +91 -0
  365. scitex/pd/_to_numeric.py +53 -0
  366. scitex/pd/_to_xy.py +59 -0
  367. scitex/pd/_to_xyz.py +110 -0
  368. scitex/plt/__init__.py +36 -0
  369. scitex/plt/_subplots/_AxesWrapper.py +182 -0
  370. scitex/plt/_subplots/_AxisWrapper.py +249 -0
  371. scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin.py +414 -0
  372. scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin.py +896 -0
  373. scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin.py +368 -0
  374. scitex/plt/_subplots/_AxisWrapperMixins/_TrackingMixin.py +185 -0
  375. scitex/plt/_subplots/_AxisWrapperMixins/__init__.py +16 -0
  376. scitex/plt/_subplots/_FigWrapper.py +226 -0
  377. scitex/plt/_subplots/_SubplotsWrapper.py +171 -0
  378. scitex/plt/_subplots/__init__.py +111 -0
  379. scitex/plt/_subplots/_export_as_csv.py +232 -0
  380. scitex/plt/_subplots/_export_as_csv_formatters/__init__.py +61 -0
  381. scitex/plt/_subplots/_export_as_csv_formatters/_format_bar.py +90 -0
  382. scitex/plt/_subplots/_export_as_csv_formatters/_format_barh.py +49 -0
  383. scitex/plt/_subplots/_export_as_csv_formatters/_format_boxplot.py +46 -0
  384. scitex/plt/_subplots/_export_as_csv_formatters/_format_contour.py +39 -0
  385. scitex/plt/_subplots/_export_as_csv_formatters/_format_errorbar.py +125 -0
  386. scitex/plt/_subplots/_export_as_csv_formatters/_format_eventplot.py +72 -0
  387. scitex/plt/_subplots/_export_as_csv_formatters/_format_fill.py +34 -0
  388. scitex/plt/_subplots/_export_as_csv_formatters/_format_fill_between.py +36 -0
  389. scitex/plt/_subplots/_export_as_csv_formatters/_format_hist.py +79 -0
  390. scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow.py +59 -0
  391. scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow2d.py +32 -0
  392. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot.py +79 -0
  393. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_box.py +75 -0
  394. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_conf_mat.py +64 -0
  395. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_ecdf.py +44 -0
  396. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_fillv.py +70 -0
  397. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_heatmap.py +66 -0
  398. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_image.py +95 -0
  399. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_joyplot.py +67 -0
  400. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_kde.py +52 -0
  401. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_line.py +46 -0
  402. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_ci.py +46 -0
  403. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_mean_std.py +46 -0
  404. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_median_iqr.py +46 -0
  405. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_raster.py +44 -0
  406. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_rectangle.py +103 -0
  407. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_scatter_hist.py +82 -0
  408. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_shaded_line.py +58 -0
  409. scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_violin.py +117 -0
  410. scitex/plt/_subplots/_export_as_csv_formatters/_format_scatter.py +30 -0
  411. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_barplot.py +51 -0
  412. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_boxplot.py +93 -0
  413. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_heatmap.py +94 -0
  414. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_histplot.py +92 -0
  415. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_jointplot.py +65 -0
  416. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_kdeplot.py +59 -0
  417. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_lineplot.py +58 -0
  418. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_pairplot.py +45 -0
  419. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_scatterplot.py +70 -0
  420. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_stripplot.py +75 -0
  421. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_swarmplot.py +75 -0
  422. scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_violinplot.py +155 -0
  423. scitex/plt/_subplots/_export_as_csv_formatters/_format_violin.py +64 -0
  424. scitex/plt/_subplots/_export_as_csv_formatters/_format_violinplot.py +77 -0
  425. scitex/plt/_subplots/_export_as_csv_formatters/test_formatters.py +210 -0
  426. scitex/plt/_subplots/_export_as_csv_formatters/verify_formatters.py +342 -0
  427. scitex/plt/_subplots/_export_as_csv_formatters.py +115 -0
  428. scitex/plt/_tpl.py +28 -0
  429. scitex/plt/ax/__init__.py +114 -0
  430. scitex/plt/ax/_plot/__init__.py +53 -0
  431. scitex/plt/ax/_plot/_plot_circular_hist.py +124 -0
  432. scitex/plt/ax/_plot/_plot_conf_mat.py +136 -0
  433. scitex/plt/ax/_plot/_plot_cube.py +57 -0
  434. scitex/plt/ax/_plot/_plot_ecdf.py +84 -0
  435. scitex/plt/ax/_plot/_plot_fillv.py +55 -0
  436. scitex/plt/ax/_plot/_plot_heatmap.py +266 -0
  437. scitex/plt/ax/_plot/_plot_image.py +94 -0
  438. scitex/plt/ax/_plot/_plot_joyplot.py +76 -0
  439. scitex/plt/ax/_plot/_plot_raster.py +172 -0
  440. scitex/plt/ax/_plot/_plot_rectangle.py +69 -0
  441. scitex/plt/ax/_plot/_plot_scatter_hist.py +133 -0
  442. scitex/plt/ax/_plot/_plot_shaded_line.py +142 -0
  443. scitex/plt/ax/_plot/_plot_statistical_shaded_line.py +221 -0
  444. scitex/plt/ax/_plot/_plot_violin.py +343 -0
  445. scitex/plt/ax/_style/__init__.py +38 -0
  446. scitex/plt/ax/_style/_add_marginal_ax.py +44 -0
  447. scitex/plt/ax/_style/_add_panel.py +92 -0
  448. scitex/plt/ax/_style/_extend.py +64 -0
  449. scitex/plt/ax/_style/_force_aspect.py +37 -0
  450. scitex/plt/ax/_style/_format_label.py +23 -0
  451. scitex/plt/ax/_style/_hide_spines.py +84 -0
  452. scitex/plt/ax/_style/_map_ticks.py +182 -0
  453. scitex/plt/ax/_style/_rotate_labels.py +215 -0
  454. scitex/plt/ax/_style/_sci_note.py +279 -0
  455. scitex/plt/ax/_style/_set_log_scale.py +299 -0
  456. scitex/plt/ax/_style/_set_meta.py +261 -0
  457. scitex/plt/ax/_style/_set_n_ticks.py +37 -0
  458. scitex/plt/ax/_style/_set_size.py +16 -0
  459. scitex/plt/ax/_style/_set_supxyt.py +116 -0
  460. scitex/plt/ax/_style/_set_ticks.py +276 -0
  461. scitex/plt/ax/_style/_set_xyt.py +121 -0
  462. scitex/plt/ax/_style/_share_axes.py +264 -0
  463. scitex/plt/ax/_style/_shift.py +139 -0
  464. scitex/plt/ax/_style/_show_spines.py +333 -0
  465. scitex/plt/color/_PARAMS.py +70 -0
  466. scitex/plt/color/__init__.py +52 -0
  467. scitex/plt/color/_add_hue_col.py +41 -0
  468. scitex/plt/color/_colors.py +205 -0
  469. scitex/plt/color/_get_colors_from_cmap.py +134 -0
  470. scitex/plt/color/_interpolate.py +29 -0
  471. scitex/plt/color/_vizualize_colors.py +54 -0
  472. scitex/plt/utils/__init__.py +44 -0
  473. scitex/plt/utils/_calc_bacc_from_conf_mat.py +46 -0
  474. scitex/plt/utils/_calc_nice_ticks.py +101 -0
  475. scitex/plt/utils/_close.py +68 -0
  476. scitex/plt/utils/_colorbar.py +96 -0
  477. scitex/plt/utils/_configure_mpl.py +295 -0
  478. scitex/plt/utils/_histogram_utils.py +132 -0
  479. scitex/plt/utils/_im2grid.py +70 -0
  480. scitex/plt/utils/_is_valid_axis.py +78 -0
  481. scitex/plt/utils/_mk_colorbar.py +65 -0
  482. scitex/plt/utils/_mk_patches.py +26 -0
  483. scitex/plt/utils/_scientific_captions.py +638 -0
  484. scitex/plt/utils/_scitex_config.py +223 -0
  485. scitex/reproduce/__init__.py +14 -0
  486. scitex/reproduce/_fix_seeds.py +45 -0
  487. scitex/reproduce/_gen_ID.py +55 -0
  488. scitex/reproduce/_gen_timestamp.py +35 -0
  489. scitex/res/__init__.py +5 -0
  490. scitex/resource/__init__.py +13 -0
  491. scitex/resource/_get_processor_usages.py +281 -0
  492. scitex/resource/_get_specs.py +280 -0
  493. scitex/resource/_log_processor_usages.py +190 -0
  494. scitex/resource/_utils/__init__.py +31 -0
  495. scitex/resource/_utils/_get_env_info.py +481 -0
  496. scitex/resource/limit_ram.py +33 -0
  497. scitex/scholar/__init__.py +24 -0
  498. scitex/scholar/_local_search.py +454 -0
  499. scitex/scholar/_paper.py +244 -0
  500. scitex/scholar/_pdf_downloader.py +325 -0
  501. scitex/scholar/_search.py +393 -0
  502. scitex/scholar/_vector_search.py +370 -0
  503. scitex/scholar/_web_sources.py +457 -0
  504. scitex/stats/__init__.py +31 -0
  505. scitex/stats/_calc_partial_corr.py +17 -0
  506. scitex/stats/_corr_test_multi.py +94 -0
  507. scitex/stats/_corr_test_wrapper.py +115 -0
  508. scitex/stats/_describe_wrapper.py +90 -0
  509. scitex/stats/_multiple_corrections.py +63 -0
  510. scitex/stats/_nan_stats.py +93 -0
  511. scitex/stats/_p2stars.py +116 -0
  512. scitex/stats/_p2stars_wrapper.py +56 -0
  513. scitex/stats/_statistical_tests.py +73 -0
  514. scitex/stats/desc/__init__.py +40 -0
  515. scitex/stats/desc/_describe.py +189 -0
  516. scitex/stats/desc/_nan.py +289 -0
  517. scitex/stats/desc/_real.py +94 -0
  518. scitex/stats/multiple/__init__.py +14 -0
  519. scitex/stats/multiple/_bonferroni_correction.py +72 -0
  520. scitex/stats/multiple/_fdr_correction.py +400 -0
  521. scitex/stats/multiple/_multicompair.py +28 -0
  522. scitex/stats/tests/__corr_test.py +277 -0
  523. scitex/stats/tests/__corr_test_multi.py +343 -0
  524. scitex/stats/tests/__corr_test_single.py +277 -0
  525. scitex/stats/tests/__init__.py +22 -0
  526. scitex/stats/tests/_brunner_munzel_test.py +192 -0
  527. scitex/stats/tests/_nocorrelation_test.py +28 -0
  528. scitex/stats/tests/_smirnov_grubbs.py +98 -0
  529. scitex/str/__init__.py +113 -0
  530. scitex/str/_clean_path.py +75 -0
  531. scitex/str/_color_text.py +52 -0
  532. scitex/str/_decapitalize.py +58 -0
  533. scitex/str/_factor_out_digits.py +281 -0
  534. scitex/str/_format_plot_text.py +498 -0
  535. scitex/str/_grep.py +48 -0
  536. scitex/str/_latex.py +155 -0
  537. scitex/str/_latex_fallback.py +471 -0
  538. scitex/str/_mask_api.py +39 -0
  539. scitex/str/_mask_api_key.py +8 -0
  540. scitex/str/_parse.py +158 -0
  541. scitex/str/_print_block.py +47 -0
  542. scitex/str/_print_debug.py +68 -0
  543. scitex/str/_printc.py +62 -0
  544. scitex/str/_readable_bytes.py +38 -0
  545. scitex/str/_remove_ansi.py +23 -0
  546. scitex/str/_replace.py +134 -0
  547. scitex/str/_search.py +125 -0
  548. scitex/str/_squeeze_space.py +36 -0
  549. scitex/tex/__init__.py +10 -0
  550. scitex/tex/_preview.py +103 -0
  551. scitex/tex/_to_vec.py +116 -0
  552. scitex/torch/__init__.py +18 -0
  553. scitex/torch/_apply_to.py +34 -0
  554. scitex/torch/_nan_funcs.py +77 -0
  555. scitex/types/_ArrayLike.py +44 -0
  556. scitex/types/_ColorLike.py +21 -0
  557. scitex/types/__init__.py +14 -0
  558. scitex/types/_is_listed_X.py +70 -0
  559. scitex/utils/__init__.py +22 -0
  560. scitex/utils/_compress_hdf5.py +116 -0
  561. scitex/utils/_email.py +120 -0
  562. scitex/utils/_grid.py +148 -0
  563. scitex/utils/_notify.py +247 -0
  564. scitex/utils/_search.py +121 -0
  565. scitex/web/__init__.py +38 -0
  566. scitex/web/_search_pubmed.py +438 -0
  567. scitex/web/_summarize_url.py +158 -0
  568. scitex-2.0.0.dist-info/METADATA +307 -0
  569. scitex-2.0.0.dist-info/RECORD +572 -0
  570. scitex-2.0.0.dist-info/WHEEL +6 -0
  571. scitex-2.0.0.dist-info/licenses/LICENSE +7 -0
  572. scitex-2.0.0.dist-info/top_level.txt +1 -0
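
The hunks below are the added contents of the three vendored Ranger optimizer sources under scitex/ai/optim/Ranger_Deep_Learning_Optimizer/ranger/: ranger.py (+207 lines), ranger2020.py (+238), and ranger913A.py (+215), matching entries 77-79 in the listing above.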
@@ -0,0 +1,207 @@
+ # Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer.
+
+ # https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
+ # and/or
+ # https://github.com/lessw2020/Best-Deep-Learning-Optimizers
+
+ # Ranger has now been used to capture 12 records on the FastAI leaderboard.
+
+ # This version = 20.4.11
+
+ # Credits:
+ # Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization
+ # RAdam --> https://github.com/LiyuanLucasLiu/RAdam
+ # Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code.
+ # Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610
+
+ # summary of changes:
+ # 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init.
+ # full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights),
+ # supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues.
+ # changes 8/31/19 - fix references to *self*.N_sma_threshold;
+ # changed eps to 1e-5 as better default than 1e-8.
+
+ import math
+ import torch
+ from torch.optim.optimizer import Optimizer, required
+
+
+ class Ranger(Optimizer):
+
+     def __init__(
+         self,
+         params,
+         lr=1e-3,  # lr
+         alpha=0.5,
+         k=6,
+         N_sma_threshhold=5,  # Ranger options
+         betas=(0.95, 0.999),
+         eps=1e-5,
+         weight_decay=0,  # Adam options
+         # Gradient centralization on or off, applied to conv layers only or conv + fc layers
+         use_gc=True,
+         gc_conv_only=False,
+     ):
+
+         # parameter checks
+         if not 0.0 <= alpha <= 1.0:
+             raise ValueError(f"Invalid slow update rate: {alpha}")
+         if not 1 <= k:
+             raise ValueError(f"Invalid lookahead steps: {k}")
+         if not lr > 0:
+             raise ValueError(f"Invalid Learning Rate: {lr}")
+         if not eps > 0:
+             raise ValueError(f"Invalid eps: {eps}")
+
+         # parameter comments:
+         # beta1 (momentum) of .95 seems to work better than .90...
+         # N_sma_threshold of 5 seems better in testing than 4.
+         # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you.
+
+         # prep defaults and init torch.optim base
+         defaults = dict(
+             lr=lr,
+             alpha=alpha,
+             k=k,
+             step_counter=0,
+             betas=betas,
+             N_sma_threshhold=N_sma_threshhold,
+             eps=eps,
+             weight_decay=weight_decay,
+         )
+         super().__init__(params, defaults)
+
+         # adjustable threshold
+         self.N_sma_threshhold = N_sma_threshhold
+
+         # look ahead params
+
+         self.alpha = alpha
+         self.k = k
+
+         # radam buffer for state
+         self.radam_buffer = [[None, None, None] for ind in range(10)]
+
+         # gc on or off
+         self.use_gc = use_gc
+
+         # level of gradient centralization
+         self.gc_gradient_threshold = 3 if gc_conv_only else 1
+
+         print(
+             f"Ranger optimizer loaded. \nGradient Centralization usage = {self.use_gc}"
+         )
+         if self.use_gc and self.gc_gradient_threshold == 1:
+             print(f"GC applied to both conv and fc layers")
+         elif self.use_gc and self.gc_gradient_threshold == 3:
+             print(f"GC applied to conv layers only")
+
+     def __setstate__(self, state):
+         print("set state called")
+         super(Ranger, self).__setstate__(state)
+
+     def step(self, closure=None):
+         loss = None
+         # note - below is commented out b/c I have other work that passes back the loss as a float, and thus not a callable closure.
+         # Uncomment if you need to use the actual closure...
+
+         # if closure is not None:
+         #     loss = closure()
+
+         # Evaluate averages and grad, update param tensors
+         for group in self.param_groups:
+
+             for p in group["params"]:
+                 if p.grad is None:
+                     continue
+                 grad = p.grad.data.float()
+
+                 if grad.is_sparse:
+                     raise RuntimeError(
+                         "Ranger optimizer does not support sparse gradients"
+                     )
+
+                 p_data_fp32 = p.data.float()
+
+                 state = self.state[p]  # get state dict for this param
+
+                 if (
+                     len(state) == 0
+                 ):  # if first time to run...init dictionary with our desired entries
+                     # if self.first_run_check==0:
+                     #     self.first_run_check=1
+                     #     print("Initializing slow buffer...should not see this at load from saved model!")
+                     state["step"] = 0
+                     state["exp_avg"] = torch.zeros_like(p_data_fp32)
+                     state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
+
+                     # look ahead weight storage now in state dict
+                     state["slow_buffer"] = torch.empty_like(p.data)
+                     state["slow_buffer"].copy_(p.data)
+
+                 else:
+                     state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
+                     state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)
+
+                 # begin computations
+                 exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+                 beta1, beta2 = group["betas"]
+
+                 # GC operation for Conv layers and FC layers
+                 if grad.dim() > self.gc_gradient_threshold:
+                     grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True))
+
+                 state["step"] += 1
+
+                 # compute variance mov avg
+                 exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+                 # compute mean moving avg
+                 exp_avg.mul_(beta1).add_(1 - beta1, grad)
+
+                 buffered = self.radam_buffer[int(state["step"] % 10)]
+
+                 if state["step"] == buffered[0]:
+                     N_sma, step_size = buffered[1], buffered[2]
+                 else:
+                     buffered[0] = state["step"]
+                     beta2_t = beta2 ** state["step"]
+                     N_sma_max = 2 / (1 - beta2) - 1
+                     N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t)
+                     buffered[1] = N_sma
+                     if N_sma > self.N_sma_threshhold:
+                         step_size = math.sqrt(
+                             (1 - beta2_t)
+                             * (N_sma - 4)
+                             / (N_sma_max - 4)
+                             * (N_sma - 2)
+                             / N_sma
+                             * N_sma_max
+                             / (N_sma_max - 2)
+                         ) / (1 - beta1 ** state["step"])
+                     else:
+                         step_size = 1.0 / (1 - beta1 ** state["step"])
+                     buffered[2] = step_size
+
+                 if group["weight_decay"] != 0:
+                     p_data_fp32.add_(-group["weight_decay"] * group["lr"], p_data_fp32)
+
+                 # apply lr
+                 if N_sma > self.N_sma_threshhold:
+                     denom = exp_avg_sq.sqrt().add_(group["eps"])
+                     p_data_fp32.addcdiv_(-step_size * group["lr"], exp_avg, denom)
+                 else:
+                     p_data_fp32.add_(-step_size * group["lr"], exp_avg)
+
+                 p.data.copy_(p_data_fp32)
+
+                 # integrated look ahead...
+                 # we do it at the param level instead of group level
+                 if state["step"] % group["k"] == 0:
+                     # get access to slow param tensor
+                     slow_p = state["slow_buffer"]
+                     # (fast weights - slow weights) * alpha
+                     slow_p.add_(self.alpha, p.data - slow_p)
+                     # copy interpolated weights to RAdam param tensor
+                     p.data.copy_(slow_p)
+
+         return loss
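
The buffered step_size branch above is RAdam's variance rectification. Writing N_sma_max as rho_infinity, N_sma as rho_t, and t for state["step"], the code evaluates, for reference:

    \rho_\infty = \frac{2}{1 - \beta_2} - 1,
    \qquad
    \rho_t = \rho_\infty - \frac{2 t \beta_2^t}{1 - \beta_2^t}

    \text{step\_size}_t =
    \begin{cases}
      \dfrac{1}{1 - \beta_1^t}
      \sqrt{(1 - \beta_2^t)\,
            \dfrac{(\rho_t - 4)(\rho_t - 2)\,\rho_\infty}
                  {(\rho_\infty - 4)(\rho_\infty - 2)\,\rho_t}}
        & \text{if } \rho_t > \text{N\_sma\_threshhold}, \\[1ex]
      \dfrac{1}{1 - \beta_1^t} & \text{otherwise.}
    \end{cases}

The sqrt(1 - beta_2^t) factor folds the second-moment bias correction into the step size (denom is the uncorrected sqrt of exp_avg_sq plus eps), and below the threshold the update falls back to plain bias-corrected momentum, with no adaptive denominator.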
@@ -0,0 +1,238 @@
+ # Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer.
+
+ # https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
+ # and/or
+ # https://github.com/lessw2020/Best-Deep-Learning-Optimizers
+
+ # Ranger has been used to capture 12 records on the FastAI leaderboard.
+
+ # This version = 2020.9.4
+
+
+ # Credits:
+ # Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization
+ # RAdam --> https://github.com/LiyuanLucasLiu/RAdam
+ # Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code.
+ # Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610
+
+ # summary of changes:
+ # 9/4/20 - updated addcmul_ signature to avoid warning. Integrates latest changes from GC developer (he did the work for this), and verified on performance on private dataset.
+ # 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init.
+ # full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights),
+ # supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues.
+ # changes 8/31/19 - fix references to *self*.N_sma_threshold;
+ # changed eps to 1e-5 as better default than 1e-8.
+
+ import math
+ import torch
+ from torch.optim.optimizer import Optimizer, required
+
+
+ def centralized_gradient(x, use_gc=True, gc_conv_only=False):
+     """credit - https://github.com/Yonghongwei/Gradient-Centralization"""
+     if use_gc:
+         if gc_conv_only:
+             if len(list(x.size())) > 3:
+                 x.add_(-x.mean(dim=tuple(range(1, len(list(x.size())))), keepdim=True))
+         else:
+             if len(list(x.size())) > 1:
+                 x.add_(-x.mean(dim=tuple(range(1, len(list(x.size())))), keepdim=True))
+     return x
+
+
+ class Ranger(Optimizer):
+
+     def __init__(
+         self,
+         params,
+         lr=1e-3,  # lr
+         alpha=0.5,
+         k=6,
+         N_sma_threshhold=5,  # Ranger options
+         betas=(0.95, 0.999),
+         eps=1e-5,
+         weight_decay=0,  # Adam options
+         # Gradient centralization on or off, applied to conv layers only or conv + fc layers
+         use_gc=True,
+         gc_conv_only=False,
+         gc_loc=True,
+     ):
+
+         # parameter checks
+         if not 0.0 <= alpha <= 1.0:
+             raise ValueError(f"Invalid slow update rate: {alpha}")
+         if not 1 <= k:
+             raise ValueError(f"Invalid lookahead steps: {k}")
+         if not lr > 0:
+             raise ValueError(f"Invalid Learning Rate: {lr}")
+         if not eps > 0:
+             raise ValueError(f"Invalid eps: {eps}")
+
+         # parameter comments:
+         # beta1 (momentum) of .95 seems to work better than .90...
+         # N_sma_threshold of 5 seems better in testing than 4.
+         # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you.
+
+         # prep defaults and init torch.optim base
+         defaults = dict(
+             lr=lr,
+             alpha=alpha,
+             k=k,
+             step_counter=0,
+             betas=betas,
+             N_sma_threshhold=N_sma_threshhold,
+             eps=eps,
+             weight_decay=weight_decay,
+         )
+         super().__init__(params, defaults)
+
+         # adjustable threshold
+         self.N_sma_threshhold = N_sma_threshhold
+
+         # look ahead params
+
+         self.alpha = alpha
+         self.k = k
+
+         # radam buffer for state
+         self.radam_buffer = [[None, None, None] for ind in range(10)]
+
+         # gc on or off
+         self.gc_loc = gc_loc
+         self.use_gc = use_gc
+         self.gc_conv_only = gc_conv_only
+         # level of gradient centralization
+         # self.gc_gradient_threshold = 3 if gc_conv_only else 1
+
+         print(
+             f"Ranger optimizer loaded. \nGradient Centralization usage = {self.use_gc}"
+         )
+         if self.use_gc and self.gc_conv_only == False:
+             print(f"GC applied to both conv and fc layers")
+         elif self.use_gc and self.gc_conv_only == True:
+             print(f"GC applied to conv layers only")
+
+     def __setstate__(self, state):
+         print("set state called")
+         super(Ranger, self).__setstate__(state)
+
+     def step(self, closure=None):
+         loss = None
+         # note - below is commented out b/c I have other work that passes back the loss as a float, and thus not a callable closure.
+         # Uncomment if you need to use the actual closure...
+
+         # if closure is not None:
+         #     loss = closure()
+
+         # Evaluate averages and grad, update param tensors
+         for group in self.param_groups:
+
+             for p in group["params"]:
+                 if p.grad is None:
+                     continue
+                 grad = p.grad.data.float()
+
+                 if grad.is_sparse:
+                     raise RuntimeError(
+                         "Ranger optimizer does not support sparse gradients"
+                     )
+
+                 p_data_fp32 = p.data.float()
+
+                 state = self.state[p]  # get state dict for this param
+
+                 if (
+                     len(state) == 0
+                 ):  # if first time to run...init dictionary with our desired entries
+                     # if self.first_run_check==0:
+                     #     self.first_run_check=1
+                     #     print("Initializing slow buffer...should not see this at load from saved model!")
+                     state["step"] = 0
+                     state["exp_avg"] = torch.zeros_like(p_data_fp32)
+                     state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
+
+                     # look ahead weight storage now in state dict
+                     state["slow_buffer"] = torch.empty_like(p.data)
+                     state["slow_buffer"].copy_(p.data)
+
+                 else:
+                     state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
+                     state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)
+
+                 # begin computations
+                 exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+                 beta1, beta2 = group["betas"]
+
+                 # GC operation for Conv layers and FC layers
+                 # if grad.dim() > self.gc_gradient_threshold:
+                 #     grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True))
+                 if self.gc_loc:
+                     grad = centralized_gradient(
+                         grad, use_gc=self.use_gc, gc_conv_only=self.gc_conv_only
+                     )
+
+                 state["step"] += 1
+
+                 # compute variance mov avg
+                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+
+                 # compute mean moving avg
+                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+
+                 buffered = self.radam_buffer[int(state["step"] % 10)]
+
+                 if state["step"] == buffered[0]:
+                     N_sma, step_size = buffered[1], buffered[2]
+                 else:
+                     buffered[0] = state["step"]
+                     beta2_t = beta2 ** state["step"]
+                     N_sma_max = 2 / (1 - beta2) - 1
+                     N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t)
+                     buffered[1] = N_sma
+                     if N_sma > self.N_sma_threshhold:
+                         step_size = math.sqrt(
+                             (1 - beta2_t)
+                             * (N_sma - 4)
+                             / (N_sma_max - 4)
+                             * (N_sma - 2)
+                             / N_sma
+                             * N_sma_max
+                             / (N_sma_max - 2)
+                         ) / (1 - beta1 ** state["step"])
+                     else:
+                         step_size = 1.0 / (1 - beta1 ** state["step"])
+                     buffered[2] = step_size
+
+                 # if group['weight_decay'] != 0:
+                 #     p_data_fp32.add_(-group['weight_decay']
+                 #                      * group['lr'], p_data_fp32)
+
+                 # apply lr
+                 if N_sma > self.N_sma_threshhold:
+                     denom = exp_avg_sq.sqrt().add_(group["eps"])
+                     G_grad = exp_avg / denom
+                 else:
+                     G_grad = exp_avg
+
+                 if group["weight_decay"] != 0:
+                     G_grad.add_(p_data_fp32, alpha=group["weight_decay"])
+                 # GC operation
+                 if self.gc_loc == False:
+                     G_grad = centralized_gradient(
+                         G_grad, use_gc=self.use_gc, gc_conv_only=self.gc_conv_only
+                     )
+
+                 p_data_fp32.add_(G_grad, alpha=-step_size * group["lr"])
+                 p.data.copy_(p_data_fp32)
+
+                 # integrated look ahead...
+                 # we do it at the param level instead of group level
+                 if state["step"] % group["k"] == 0:
+                     # get access to slow param tensor
+                     slow_p = state["slow_buffer"]
+                     # (fast weights - slow weights) * alpha
+                     slow_p.add_(p.data - slow_p, alpha=self.alpha)
+                     # copy interpolated weights to RAdam param tensor
+                     p.data.copy_(slow_p)
+
+         return loss
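
On the "9/4/20 - updated addcmul_ signature" changelog entry above: ranger.py uses PyTorch's old positional-scalar overloads addcmul_(scalar, t1, t2) and add_(scalar, tensor), which later PyTorch releases deprecate; ranger2020.py is the migrated version using the keyword forms. A minimal standalone sketch of the two spellings (not part of the package):

    import torch

    beta1, beta2 = 0.95, 0.999
    grad = torch.randn(3)
    exp_avg = torch.zeros(3)
    exp_avg_sq = torch.zeros(3)

    # Deprecated positional-scalar overloads, as written in ranger.py (20.4.11):
    #     exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
    #     exp_avg.mul_(beta1).add_(1 - beta1, grad)

    # Supported keyword forms, as written in ranger2020.py:
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)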
@@ -0,0 +1,215 @@
+ # Ranger deep learning optimizer - RAdam + Lookahead + calibrated adaptive LR combined.
+ # https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
+
+ # Ranger has now been used to capture 12 records on the FastAI leaderboard.
+
+ # This version = 9.13.19A
+
+ # Credits:
+ # RAdam --> https://github.com/LiyuanLucasLiu/RAdam
+ # Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code.
+ # Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610
+ # Calibrated anisotropic adaptive learning rates - https://arxiv.org/abs/1908.00700v2
+
+ # summary of changes:
+ # full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights),
+ # supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues.
+ # changes 8/31/19 - fix references to *self*.N_sma_threshold;
+ # changed eps to 1e-5 as better default than 1e-8.
+
+ import math
+ import torch
+ from torch.optim.optimizer import Optimizer, required
+ import itertools as it
+
+
+ class RangerVA(Optimizer):
+
+     def __init__(
+         self,
+         params,
+         lr=1e-3,
+         alpha=0.5,
+         k=6,
+         n_sma_threshhold=5,
+         betas=(0.95, 0.999),
+         eps=1e-5,
+         weight_decay=0,
+         amsgrad=True,
+         transformer="softplus",
+         smooth=50,
+         grad_transformer="square",
+     ):
+         # parameter checks
+         if not 0.0 <= alpha <= 1.0:
+             raise ValueError(f"Invalid slow update rate: {alpha}")
+         if not 1 <= k:
+             raise ValueError(f"Invalid lookahead steps: {k}")
+         if not lr > 0:
+             raise ValueError(f"Invalid Learning Rate: {lr}")
+         if not eps > 0:
+             raise ValueError(f"Invalid eps: {eps}")
+
+         # parameter comments:
+         # beta1 (momentum) of .95 seems to work better than .90...
+         # N_sma_threshold of 5 seems better in testing than 4.
+         # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you.
+
+         # prep defaults and init torch.optim base
+         defaults = dict(
+             lr=lr,
+             alpha=alpha,
+             k=k,
+             step_counter=0,
+             betas=betas,
+             n_sma_threshhold=n_sma_threshhold,
+             eps=eps,
+             weight_decay=weight_decay,
+             smooth=smooth,
+             transformer=transformer,
+             grad_transformer=grad_transformer,
+             amsgrad=amsgrad,
+         )
+         super().__init__(params, defaults)
+
+         # adjustable threshold
+         self.n_sma_threshhold = n_sma_threshhold
+
+         # look ahead params
+         self.alpha = alpha
+         self.k = k
+
+         # radam buffer for state
+         self.radam_buffer = [[None, None, None] for ind in range(10)]
+
+         # self.first_run_check=0
+
+         # lookahead weights
+         # 9/2/19 - lookahead param tensors have been moved to state storage.
+         # This should resolve issues with load/save where weights were left in GPU memory from first load, slowing down future runs.
+
+         # self.slow_weights = [[p.clone().detach() for p in group['params']]
+         #                      for group in self.param_groups]
+
+         # don't use grad for lookahead weights
+         # for w in it.chain(*self.slow_weights):
+         #     w.requires_grad = False
+
+     def __setstate__(self, state):
+         print("set state called")
+         super(RangerVA, self).__setstate__(state)
+
+     def step(self, closure=None):
+         loss = None
+         # note - below is commented out b/c I have other work that passes back the loss as a float, and thus not a callable closure.
+         # Uncomment if you need to use the actual closure...
+
+         # if closure is not None:
+         #     loss = closure()
+
+         # Evaluate averages and grad, update param tensors
+         for group in self.param_groups:
+
+             for p in group["params"]:
+                 if p.grad is None:
+                     continue
+                 grad = p.grad.data.float()
+                 if grad.is_sparse:
+                     raise RuntimeError(
+                         "Ranger optimizer does not support sparse gradients"
+                     )
+
+                 amsgrad = group["amsgrad"]
+                 smooth = group["smooth"]
+                 grad_transformer = group["grad_transformer"]
+
+                 p_data_fp32 = p.data.float()
+
+                 state = self.state[p]  # get state dict for this param
+
+                 if (
+                     len(state) == 0
+                 ):  # if first time to run...init dictionary with our desired entries
+                     # if self.first_run_check==0:
+                     #     self.first_run_check=1
+                     #     print("Initializing slow buffer...should not see this at load from saved model!")
+                     state["step"] = 0
+                     state["exp_avg"] = torch.zeros_like(p_data_fp32)
+                     state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
+                     if amsgrad:
+                         # Maintains max of all exp. moving avg. of sq. grad. values
+                         state["max_exp_avg_sq"] = torch.zeros_like(p.data)
+
+                     # look ahead weight storage now in state dict
+                     state["slow_buffer"] = torch.empty_like(p.data)
+                     state["slow_buffer"].copy_(p.data)
+
+                 else:
+                     state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
+                     state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)
+
+                 # begin computations
+                 exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+                 beta1, beta2 = group["betas"]
+                 if amsgrad:
+                     max_exp_avg_sq = state["max_exp_avg_sq"]
+
+                 # compute variance mov avg
+                 exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+                 # compute mean moving avg
+                 exp_avg.mul_(beta1).add_(1 - beta1, grad)
+
+                 ##transformer
+                 if grad_transformer == "square":
+                     grad_tmp = grad**2
+                 elif grad_transformer == "abs":
+                     grad_tmp = grad.abs()
+
+                 exp_avg_sq.mul_(beta2).add_((1 - beta2) * grad_tmp)
+
+                 if amsgrad:
+                     # Maintains the maximum of all 2nd moment running avg. till now
+                     torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
+                     # Use the max. for normalizing running avg. of gradient
+                     denomc = max_exp_avg_sq.clone()
+                 else:
+                     denomc = exp_avg_sq.clone()
+
+                 if grad_transformer == "square":
+                     # pdb.set_trace()
+                     denomc.sqrt_()
+
+                 state["step"] += 1
+
+                 if group["weight_decay"] != 0:
+                     p_data_fp32.add_(-group["weight_decay"] * group["lr"], p_data_fp32)
+
+                 bias_correction1 = 1 - beta1 ** state["step"]
+                 bias_correction2 = 1 - beta2 ** state["step"]
+                 step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
+
+                 # ...let's use calibrated alr
+                 if group["transformer"] == "softplus":
+                     sp = torch.nn.Softplus(smooth)
+                     denomf = sp(denomc)
+                     p_data_fp32.addcdiv_(-step_size, exp_avg, denomf)
+
+                 else:
+
+                     denom = exp_avg_sq.sqrt().add_(group["eps"])
+                     p_data_fp32.addcdiv_(-step_size * group["lr"], exp_avg, denom)
+
+                 p.data.copy_(p_data_fp32)
+
+                 # integrated look ahead...
+                 # we do it at the param level instead of group level
+                 if state["step"] % group["k"] == 0:
+                     slow_p = state["slow_buffer"]  # get access to slow param tensor
+                     slow_p.add_(
+                         self.alpha, p.data - slow_p
+                     )  # (fast weights - slow weights) * alpha
+                     p.data.copy_(
+                         slow_p
+                     )  # copy interpolated weights to RAdam param tensor
+
+         return loss
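
For completeness, a hedged usage sketch of the Ranger class shipped above. The import path is inferred from the wheel's layout in the file listing; whether ranger/__init__.py (listed at +3 lines) re-exports Ranger is an assumption, as is the toy model:

    import torch
    import torch.nn as nn

    # Assumed import path, based on the package layout above; adjust if
    # ranger/__init__.py exports the class differently.
    from scitex.ai.optim.Ranger_Deep_Learning_Optimizer.ranger import Ranger

    model = nn.Linear(10, 2)  # toy model for illustration
    # Defaults mirror the constructor above: betas=(0.95, 0.999), eps=1e-5,
    # lookahead interpolation alpha=0.5 applied every k=6 steps.
    opt = Ranger(model.parameters(), lr=1e-3, use_gc=True, gc_conv_only=False)

    for _ in range(12):  # with k=6, the slow weights sync twice in 12 steps
        opt.zero_grad()
        loss = model(torch.randn(4, 10)).pow(2).mean()
        loss.backward()
        opt.step()  # step() ignores any closure (see the commented-out code above)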