sagemaker-core 1.0.62__py3-none-any.whl → 2.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (358)
  1. sagemaker/__init__.py +2 -0
  2. sagemaker/core/__init__.py +16 -0
  3. sagemaker/core/_studio.py +116 -0
  4. sagemaker/core/_version.py +11 -0
  5. sagemaker/core/accept_types.py +131 -0
  6. sagemaker/core/analytics.py +744 -0
  7. sagemaker/core/apiutils/__init__.py +13 -0
  8. sagemaker/core/apiutils/_base_types.py +228 -0
  9. sagemaker/core/apiutils/_boto_functions.py +130 -0
  10. sagemaker/core/apiutils/_utils.py +34 -0
  11. sagemaker/core/base_deserializers.py +35 -0
  12. sagemaker/core/base_serializers.py +35 -0
  13. sagemaker/core/clarify/__init__.py +2898 -0
  14. sagemaker/core/collection.py +467 -0
  15. sagemaker/core/common_utils.py +2399 -0
  16. sagemaker/core/compute_resource_requirements/__init__.py +18 -0
  17. sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
  18. sagemaker/core/config/__init__.py +181 -0
  19. sagemaker/core/config/config.py +238 -0
  20. sagemaker/core/config/config_manager.py +595 -0
  21. sagemaker/core/config/config_schema.py +1220 -0
  22. sagemaker/core/config/config_utils.py +297 -0
  23. {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
  24. sagemaker/core/constants.py +73 -0
  25. sagemaker/core/content_types.py +137 -0
  26. sagemaker/core/debugger/__init__.py +39 -0
  27. sagemaker/core/debugger/debugger.py +945 -0
  28. sagemaker/core/debugger/framework_profile.py +292 -0
  29. sagemaker/core/debugger/metrics_config.py +468 -0
  30. sagemaker/core/debugger/profiler.py +42 -0
  31. sagemaker/core/debugger/profiler_config.py +190 -0
  32. sagemaker/core/debugger/profiler_constants.py +40 -0
  33. sagemaker/core/debugger/utils.py +148 -0
  34. sagemaker/core/deprecations.py +254 -0
  35. sagemaker/core/deserializers/__init__.py +10 -0
  36. sagemaker/core/deserializers/base.py +424 -0
  37. sagemaker/core/deserializers/implementations.py +157 -0
  38. sagemaker/core/drift_check_baselines.py +106 -0
  39. sagemaker/core/enums.py +51 -0
  40. sagemaker/core/environment_variables.py +101 -0
  41. sagemaker/core/exceptions.py +108 -0
  42. sagemaker/core/experiments/__init__.py +53 -0
  43. sagemaker/core/experiments/_api_types.py +251 -0
  44. sagemaker/core/experiments/_environment.py +124 -0
  45. sagemaker/core/experiments/_helper.py +294 -0
  46. sagemaker/core/experiments/_metrics.py +333 -0
  47. sagemaker/core/experiments/_run_context.py +58 -0
  48. sagemaker/core/experiments/_utils.py +216 -0
  49. sagemaker/core/experiments/experiment.py +247 -0
  50. sagemaker/core/experiments/run.py +970 -0
  51. sagemaker/core/experiments/trial.py +296 -0
  52. sagemaker/core/experiments/trial_component.py +387 -0
  53. sagemaker/core/explainer/__init__.py +24 -0
  54. sagemaker/core/explainer/clarify_explainer_config.py +298 -0
  55. sagemaker/core/explainer/explainer_config.py +44 -0
  56. sagemaker/core/fw_utils.py +1220 -0
  57. sagemaker/core/git_utils.py +415 -0
  58. sagemaker/core/helper/pipeline_variable.py +82 -0
  59. sagemaker/core/helper/session_helper.py +2977 -0
  60. sagemaker/core/hyperparameters.py +172 -0
  61. sagemaker/core/image_retriever/__init__.py +3 -0
  62. sagemaker/core/image_retriever/image_retriever.py +640 -0
  63. sagemaker/core/image_retriever/image_retriever_utils.py +509 -0
  64. sagemaker/core/image_retriever/test.py +7 -0
  65. sagemaker/core/image_uri_config/autogluon.json +1335 -0
  66. sagemaker/core/image_uri_config/blazingtext.json +50 -0
  67. sagemaker/core/image_uri_config/chainer.json +104 -0
  68. sagemaker/core/image_uri_config/clarify.json +39 -0
  69. sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
  70. sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
  71. sagemaker/core/image_uri_config/data-wrangler.json +91 -0
  72. sagemaker/core/image_uri_config/debugger.json +34 -0
  73. sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
  74. sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
  75. sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
  76. sagemaker/core/image_uri_config/djl-lmi.json +136 -0
  77. sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
  78. sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
  79. sagemaker/core/image_uri_config/factorization-machines.json +50 -0
  80. sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
  81. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +770 -0
  82. sagemaker/core/image_uri_config/huggingface-llm.json +1267 -0
  83. sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
  84. sagemaker/core/image_uri_config/huggingface-neuronx.json +686 -0
  85. sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
  86. sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
  87. sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
  88. sagemaker/core/image_uri_config/huggingface-vllm-neuronx.json +38 -0
  89. sagemaker/core/image_uri_config/huggingface.json +2287 -0
  90. sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
  91. sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
  92. sagemaker/core/image_uri_config/image-classification.json +50 -0
  93. sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
  94. sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
  95. sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
  96. sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
  97. sagemaker/core/image_uri_config/ipinsights.json +50 -0
  98. sagemaker/core/image_uri_config/kmeans.json +50 -0
  99. sagemaker/core/image_uri_config/knn.json +50 -0
  100. sagemaker/core/image_uri_config/lda.json +26 -0
  101. sagemaker/core/image_uri_config/linear-learner.json +50 -0
  102. sagemaker/core/image_uri_config/model-monitor.json +42 -0
  103. sagemaker/core/image_uri_config/mxnet.json +1154 -0
  104. sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
  105. sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
  106. sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
  107. sagemaker/core/image_uri_config/ntm.json +50 -0
  108. sagemaker/core/image_uri_config/object-detection.json +50 -0
  109. sagemaker/core/image_uri_config/object2vec.json +50 -0
  110. sagemaker/core/image_uri_config/pca.json +50 -0
  111. sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
  112. sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
  113. sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
  114. sagemaker/core/image_uri_config/pytorch.json +3101 -0
  115. sagemaker/core/image_uri_config/randomcutforest.json +50 -0
  116. sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
  117. sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
  118. sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
  119. sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
  120. sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
  121. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +252 -0
  122. sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
  123. sagemaker/core/image_uri_config/seq2seq.json +50 -0
  124. sagemaker/core/image_uri_config/sklearn.json +494 -0
  125. sagemaker/core/image_uri_config/spark.json +280 -0
  126. sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
  127. sagemaker/core/image_uri_config/stabilityai.json +53 -0
  128. sagemaker/core/image_uri_config/tensorflow.json +5086 -0
  129. sagemaker/core/image_uri_config/vw.json +25 -0
  130. sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
  131. sagemaker/core/image_uri_config/xgboost.json +972 -0
  132. sagemaker/core/image_uris.py +816 -0
  133. sagemaker/core/inference_config.py +144 -0
  134. sagemaker/core/inference_recommender/__init__.py +18 -0
  135. sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
  136. sagemaker/core/inputs.py +366 -0
  137. sagemaker/core/instance_group.py +61 -0
  138. sagemaker/core/instance_types.py +164 -0
  139. sagemaker/core/instance_types_gpu_info.py +43 -0
  140. sagemaker/core/interactive_apps/__init__.py +41 -0
  141. sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
  142. sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
  143. sagemaker/core/interactive_apps/tensorboard.py +149 -0
  144. sagemaker/core/iterators.py +197 -0
  145. sagemaker/core/job.py +380 -0
  146. sagemaker/core/jumpstart/__init__.py +156 -0
  147. sagemaker/core/jumpstart/accessors.py +390 -0
  148. sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
  149. sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
  150. sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
  151. sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
  152. sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
  153. sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
  154. sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
  155. sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
  156. sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
  157. sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
  158. sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
  159. sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
  160. sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
  161. sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
  162. sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
  163. sagemaker/core/jumpstart/cache.py +663 -0
  164. sagemaker/core/jumpstart/configs.py +50 -0
  165. sagemaker/core/jumpstart/constants.py +198 -0
  166. sagemaker/core/jumpstart/deserializers.py +81 -0
  167. sagemaker/core/jumpstart/document.py +76 -0
  168. sagemaker/core/jumpstart/enums.py +168 -0
  169. sagemaker/core/jumpstart/exceptions.py +236 -0
  170. sagemaker/core/jumpstart/factory/utils.py +833 -0
  171. sagemaker/core/jumpstart/filters.py +597 -0
  172. sagemaker/core/jumpstart/hub/constants.py +16 -0
  173. sagemaker/core/jumpstart/hub/hub.py +291 -0
  174. sagemaker/core/jumpstart/hub/interfaces.py +936 -0
  175. sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
  176. sagemaker/core/jumpstart/hub/parsers.py +288 -0
  177. sagemaker/core/jumpstart/hub/types.py +35 -0
  178. sagemaker/core/jumpstart/hub/utils.py +260 -0
  179. sagemaker/core/jumpstart/models.py +501 -0
  180. sagemaker/core/jumpstart/notebook_utils.py +575 -0
  181. sagemaker/core/jumpstart/parameters.py +20 -0
  182. sagemaker/core/jumpstart/payload_utils.py +239 -0
  183. sagemaker/core/jumpstart/region_config.json +171 -0
  184. sagemaker/core/jumpstart/search.py +171 -0
  185. sagemaker/core/jumpstart/serializers.py +81 -0
  186. sagemaker/core/jumpstart/session_utils.py +234 -0
  187. sagemaker/core/jumpstart/types.py +3044 -0
  188. sagemaker/core/jumpstart/utils.py +1731 -0
  189. sagemaker/core/jumpstart/validators.py +257 -0
  190. sagemaker/core/lambda_helper.py +312 -0
  191. sagemaker/core/lineage/__init__.py +42 -0
  192. sagemaker/core/lineage/_api_types.py +239 -0
  193. sagemaker/core/lineage/_utils.py +49 -0
  194. sagemaker/core/lineage/action.py +345 -0
  195. sagemaker/core/lineage/artifact.py +646 -0
  196. sagemaker/core/lineage/association.py +190 -0
  197. sagemaker/core/lineage/context.py +505 -0
  198. sagemaker/core/lineage/lineage_trial_component.py +191 -0
  199. sagemaker/core/lineage/query.py +732 -0
  200. sagemaker/core/lineage/visualizer.py +346 -0
  201. sagemaker/core/local/__init__.py +18 -0
  202. sagemaker/core/local/data.py +423 -0
  203. sagemaker/core/local/entities.py +678 -0
  204. sagemaker/core/local/exceptions.py +17 -0
  205. sagemaker/core/local/image.py +1243 -0
  206. sagemaker/core/local/local_session.py +739 -0
  207. sagemaker/core/local/utils.py +246 -0
  208. sagemaker/core/logs.py +181 -0
  209. sagemaker/core/metadata_properties.py +56 -0
  210. sagemaker/core/metric_definitions.py +91 -0
  211. sagemaker/core/mlflow/__init__.py +38 -0
  212. sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
  213. sagemaker/core/model_card/__init__.py +26 -0
  214. sagemaker/core/model_life_cycle.py +51 -0
  215. sagemaker/core/model_metrics.py +160 -0
  216. sagemaker/core/model_monitor/__init__.py +66 -0
  217. sagemaker/core/model_monitor/clarify_model_monitoring.py +1497 -0
  218. sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
  219. sagemaker/core/model_monitor/data_capture_config.py +115 -0
  220. sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
  221. sagemaker/core/model_monitor/dataset_format.py +102 -0
  222. sagemaker/core/model_monitor/model_monitoring.py +4266 -0
  223. sagemaker/core/model_monitor/monitoring_alert.py +76 -0
  224. sagemaker/core/model_monitor/monitoring_files.py +506 -0
  225. sagemaker/core/model_monitor/utils.py +793 -0
  226. sagemaker/core/model_registry.py +480 -0
  227. sagemaker/core/model_uris.py +97 -0
  228. sagemaker/core/modules/__init__.py +19 -0
  229. sagemaker/core/modules/configs.py +239 -0
  230. sagemaker/core/modules/constants.py +37 -0
  231. sagemaker/core/modules/distributed.py +182 -0
  232. sagemaker/core/modules/local_core/local_container.py +605 -0
  233. sagemaker/core/modules/templates.py +83 -0
  234. sagemaker/core/modules/train/__init__.py +14 -0
  235. sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
  236. sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
  237. sagemaker/core/modules/train/container_drivers/common/utils.py +205 -0
  238. sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
  239. sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  240. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
  241. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  242. sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  243. sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
  244. sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
  245. sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
  246. sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
  247. sagemaker/core/modules/types.py +19 -0
  248. sagemaker/core/modules/utils.py +194 -0
  249. sagemaker/core/network.py +185 -0
  250. sagemaker/core/parameter.py +173 -0
  251. sagemaker/core/payloads.py +185 -0
  252. sagemaker/core/processing.py +1599 -0
  253. sagemaker/core/remote_function/__init__.py +19 -0
  254. sagemaker/core/remote_function/checkpoint_location.py +47 -0
  255. sagemaker/core/remote_function/client.py +1310 -0
  256. sagemaker/core/remote_function/core/__init__.py +0 -0
  257. sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
  258. sagemaker/core/remote_function/core/pipeline_variables.py +347 -0
  259. sagemaker/core/remote_function/core/serialization.py +410 -0
  260. sagemaker/core/remote_function/core/stored_function.py +223 -0
  261. sagemaker/core/remote_function/custom_file_filter.py +128 -0
  262. sagemaker/core/remote_function/errors.py +102 -0
  263. sagemaker/core/remote_function/invoke_function.py +167 -0
  264. sagemaker/core/remote_function/job.py +2121 -0
  265. sagemaker/core/remote_function/logging_config.py +38 -0
  266. sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
  267. sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
  268. sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  269. sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
  270. sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
  271. sagemaker/core/remote_function/spark_config.py +149 -0
  272. sagemaker/core/resource_requirements.py +168 -0
  273. {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
  274. sagemaker/core/s3/__init__.py +41 -0
  275. sagemaker/core/s3/client.py +367 -0
  276. sagemaker/core/s3/utils.py +175 -0
  277. sagemaker/core/script_uris.py +93 -0
  278. sagemaker/core/serializers/__init__.py +11 -0
  279. sagemaker/core/serializers/base.py +510 -0
  280. sagemaker/core/serializers/implementations.py +159 -0
  281. sagemaker/core/serializers/utils.py +223 -0
  282. sagemaker/core/serverless_inference_config.py +63 -0
  283. sagemaker/core/session_settings.py +55 -0
  284. sagemaker/core/shapes/__init__.py +3 -0
  285. sagemaker/core/shapes/model_card_shapes.py +159 -0
  286. {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
  287. sagemaker/core/spark/__init__.py +16 -0
  288. sagemaker/core/spark/defaults.py +16 -0
  289. sagemaker/core/spark/processing.py +1380 -0
  290. sagemaker/core/telemetry/__init__.py +23 -0
  291. sagemaker/core/telemetry/constants.py +82 -0
  292. sagemaker/core/telemetry/telemetry_logging.py +285 -0
  293. sagemaker/core/tools/__init__.py +1 -0
  294. {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
  295. {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
  296. {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
  297. {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
  298. sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
  299. {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
  300. {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
  301. {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
  302. {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
  303. {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
  304. sagemaker/core/training/__init__.py +14 -0
  305. sagemaker/core/training/configs.py +345 -0
  306. sagemaker/core/training/constants.py +37 -0
  307. sagemaker/core/training/utils.py +77 -0
  308. sagemaker/core/training_compiler/__init__.py +16 -0
  309. sagemaker/core/training_compiler/config.py +197 -0
  310. sagemaker/core/training_compiler_config.py +197 -0
  311. sagemaker/core/transformer.py +793 -0
  312. sagemaker/core/user_agent.py +76 -0
  313. sagemaker/core/utilities/__init__.py +24 -0
  314. sagemaker/core/utilities/cache.py +169 -0
  315. sagemaker/core/utilities/search_expression.py +133 -0
  316. sagemaker/core/utils/__init__.py +48 -0
  317. sagemaker/core/utils/code_injection/__init__.py +0 -0
  318. {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
  319. {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
  320. {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
  321. sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
  322. {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
  323. {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
  324. sagemaker/core/workflow/__init__.py +152 -0
  325. sagemaker/core/workflow/conditions.py +313 -0
  326. sagemaker/core/workflow/entities.py +58 -0
  327. sagemaker/core/workflow/execution_variables.py +89 -0
  328. sagemaker/core/workflow/functions.py +193 -0
  329. sagemaker/core/workflow/parameters.py +222 -0
  330. sagemaker/core/workflow/pipeline_context.py +394 -0
  331. sagemaker/core/workflow/pipeline_definition_config.py +31 -0
  332. sagemaker/core/workflow/properties.py +285 -0
  333. sagemaker/core/workflow/step_outputs.py +65 -0
  334. sagemaker/core/workflow/utilities.py +514 -0
  335. sagemaker/lineage/__init__.py +33 -0
  336. sagemaker/lineage/action.py +28 -0
  337. sagemaker/lineage/artifact.py +28 -0
  338. sagemaker/lineage/context.py +28 -0
  339. sagemaker/lineage/lineage_trial_component.py +28 -0
  340. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/METADATA +28 -9
  341. sagemaker_core-2.3.1.dist-info/RECORD +351 -0
  342. sagemaker_core-2.3.1.dist-info/top_level.txt +1 -0
  343. sagemaker_core/_version.py +0 -3
  344. sagemaker_core/helper/session_helper.py +0 -769
  345. sagemaker_core/resources/__init__.py +0 -1
  346. sagemaker_core/shapes/__init__.py +0 -1
  347. sagemaker_core/tools/__init__.py +0 -1
  348. sagemaker_core-1.0.62.dist-info/RECORD +0 -35
  349. sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
  350. {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
  351. {sagemaker_core/helper → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
  352. {sagemaker_core/main → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
  353. {sagemaker_core/main/code_injection → sagemaker/core/modules/local_core}/__init__.py +0 -0
  354. {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
  355. {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
  356. {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
  357. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/WHEEL +0 -0
  358. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,424 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """Implements base methods for deserializing data returned from an inference endpoint."""
14
+ from __future__ import absolute_import
15
+
16
+ import csv
17
+
18
+ import abc
19
+ import codecs
20
+ import io
21
+ import json
22
+
23
+ import numpy as np
24
+ from six import with_metaclass
25
+
26
+ # Lazy import to avoid circular dependency with amazon modules
27
+ # from sagemaker.core.serializers.utils import read_records
28
+ from sagemaker.core.common_utils import DeferredError
29
+
30
+ try:
31
+ import pandas
32
+ except ImportError as e:
33
+ pandas = DeferredError(e)
34
+
35
+
36
class BaseDeserializer(abc.ABC):
    """Interface that every deserializer used with an inference endpoint implements.

    A concrete subclass must provide two things: a ``deserialize`` method that
    converts the response body into a Python object, and an ``ACCEPT`` attribute
    or property naming the content types it expects from the endpoint.
    """

    @abc.abstractmethod
    def deserialize(self, stream, content_type):
        """Convert a response body returned by an inference endpoint into an object.

        Args:
            stream (botocore.response.StreamingBody): The response body to read.
            content_type (str): The MIME type reported for the response body.

        Returns:
            object: Whatever object the concrete deserializer produces.
        """

    @property
    @abc.abstractmethod
    def ACCEPT(self):
        """Content types this deserializer expects the inference endpoint to return."""
59
+
60
+
61
class SimpleBaseDeserializer(BaseDeserializer):
    """Abstract base class for creation of new deserializers.

    This class extends the API of :class:`~sagemaker.deserializers.BaseDeserializer` with more
    user-friendly options for setting the ACCEPT content type header, in situations where it can be
    provided at init and freely updated.

    ``BaseDeserializer`` already derives from ``abc.ABC`` (metaclass ``abc.ABCMeta``), so
    inheriting from it directly gives this class the same metaclass and MRO that the former
    ``six.with_metaclass(abc.ABCMeta, BaseDeserializer)`` shim produced.
    """

    def __init__(self, accept="*/*"):
        """Initialize a ``SimpleBaseDeserializer`` instance.

        Args:
            accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
                is expected from the inference endpoint (default: "*/*").
        """
        super().__init__()
        # Stored as given; normalized to a tuple only when read through ``ACCEPT``.
        self.accept = accept

    @property
    def ACCEPT(self):
        """The tuple of possible content types that are expected from the inference endpoint.

        A single string ``accept`` is wrapped in a one-element tuple; any other value is
        returned unchanged.
        """
        if isinstance(self.accept, str):
            return (self.accept,)
        return self.accept
85
+
86
+
87
class StringDeserializer(SimpleBaseDeserializer):
    """Turn the response body of an inference endpoint into a decoded ``str``."""

    def __init__(self, encoding="UTF-8", accept="application/json"):
        """Create a deserializer that decodes the response bytes as text.

        Args:
            encoding (str): Codec used to decode the response bytes (default: UTF-8).
            accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
                is expected from the inference endpoint (default: "application/json").
        """
        super().__init__(accept=accept)
        self.encoding = encoding

    def deserialize(self, stream, content_type):
        """Read the whole response body and decode it with ``self.encoding``.

        The stream is closed afterwards, whether or not reading or decoding succeeds.

        Args:
            stream (botocore.response.StreamingBody): Data to be deserialized.
            content_type (str): The MIME type of the data (not used).

        Returns:
            str: The decoded response body.
        """
        try:
            raw_bytes = stream.read()
            return raw_bytes.decode(self.encoding)
        finally:
            stream.close()
115
+
116
+
117
class BytesDeserializer(SimpleBaseDeserializer):
    """Return the raw bytes of an inference endpoint's response body."""

    def deserialize(self, stream, content_type):
        """Read the full response body as ``bytes`` and close the stream.

        Args:
            stream (botocore.response.StreamingBody): A stream of bytes.
            content_type (str): The MIME type of the data (not used).

        Returns:
            bytes: Everything read from the stream.
        """
        try:
            payload = stream.read()
        finally:
            stream.close()
        return payload
134
+
135
+
136
class CSVDeserializer(SimpleBaseDeserializer):
    """Deserialize a stream of bytes into a list of lists.

    Consider using :class:`~sagemaker.deserializers.NumpyDeserializer` or
    :class:`~sagemaker.deserializers.PandasDeserializer` instead, if you'd like to convert text/csv
    responses directly into other data types.
    """

    def __init__(self, encoding="utf-8", accept="text/csv"):
        """Initialize a ``CSVDeserializer`` instance.

        Args:
            encoding (str): The string encoding to use (default: "utf-8").
            accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
                is expected from the inference endpoint (default: "text/csv").
        """
        super().__init__(accept=accept)
        self.encoding = encoding

    def deserialize(self, stream, content_type):
        """Deserialize data from an inference endpoint into a list of lists.

        Args:
            stream (botocore.response.StreamingBody): Data to be deserialized.
            content_type (str): The MIME type of the data (not used).

        Returns:
            list: The data deserialized into a list of lists representing the
                contents of a CSV file. Each row is a list of strings.
        """
        try:
            decoded_string = stream.read().decode(self.encoding)
        finally:
            stream.close()
        # Pass csv.reader a file-like object opened with newline="", as the csv
        # docs recommend, rather than str.splitlines(). splitlines() strips the
        # line terminators, so newlines inside quoted fields were silently lost.
        # It also split records on characters such as "\x0b" or "\u2028", which
        # are not CSV record separators.
        return list(csv.reader(io.StringIO(decoded_string, newline="")))
171
+
172
+
173
class StreamDeserializer(SimpleBaseDeserializer):
    """Hand back the response stream from an inference endpoint together with its content type.

    Nothing is read or closed here. The caller receives the stream as-is and is
    responsible for closing it once they are done reading it.
    """

    def deserialize(self, stream, content_type):
        """Pair the untouched response stream with its MIME type.

        Args:
            stream (botocore.response.StreamingBody): A stream of bytes.
            content_type (str): The MIME type of the data.

        Returns:
            tuple: ``(stream, content_type)``.
        """
        result = (stream, content_type)
        return result
191
+
192
+
193
class NumpyDeserializer(SimpleBaseDeserializer):
    """Deserialize a stream of data in .npy, .npz or UTF-8 CSV/JSON format to a numpy array.

    Note that when using application/x-npz archive format, the result will usually be a
    dictionary-like object containing multiple arrays (as per ``numpy.load()``) - instead of a
    single array.
    """

    def __init__(self, dtype=None, accept="application/x-npy", allow_pickle=False):
        """Initialize a ``NumpyDeserializer`` instance.

        Args:
            dtype (str): The dtype of the data (default: None). Applied to CSV and JSON
                payloads only; .npy/.npz payloads keep the dtype stored in the file.
            accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
                is expected from the inference endpoint (default: "application/x-npy").
            allow_pickle (bool): Allow loading pickled object arrays (default: False).
        """
        super().__init__(accept=accept)
        self.dtype = dtype
        self.allow_pickle = allow_pickle

    def deserialize(self, stream, content_type):
        """Deserialize data from an inference endpoint into a NumPy array.

        Args:
            stream (botocore.response.StreamingBody): Data to be deserialized.
            content_type (str): The MIME type of the data. One of "text/csv",
                "application/json", "application/x-npy" or "application/x-npz".

        Returns:
            numpy.ndarray: The data deserialized into a NumPy array. For
                "application/x-npz" this is the mapping-like object returned by
                ``numpy.load``.

        Raises:
            ValueError: If ``content_type`` is not supported, or if a .npy/.npz payload
                cannot be loaded by ``numpy.load`` (for example because it contains
                pickled objects and ``allow_pickle`` is False).
        """
        # A single try/finally closes the stream exactly once, whichever branch runs.
        try:
            if content_type == "text/csv":
                return np.genfromtxt(
                    codecs.getreader("utf-8")(stream), delimiter=",", dtype=self.dtype
                )
            if content_type == "application/json":
                return np.array(json.load(codecs.getreader("utf-8")(stream)), dtype=self.dtype)
            if content_type in ("application/x-npy", "application/x-npz"):
                # numpy.load reads from an in-memory copy of the body, so closing the
                # stream afterwards does not affect the returned object.
                return self._load_numpy_payload(stream.read())
        finally:
            stream.close()

        raise ValueError("%s cannot read content type %s." % (__class__.__name__, content_type))

    def _load_numpy_payload(self, payload):
        """Load .npy/.npz bytes with ``numpy.load``, adding an allow_pickle hint on failure.

        Args:
            payload (bytes): The raw .npy or .npz content.

        Returns:
            The object returned by ``numpy.load``.

        Raises:
            ValueError: If ``numpy.load`` rejects the payload. The original error is
                chained as ``__cause__`` and its message is included.
        """
        try:
            return np.load(io.BytesIO(payload), allow_pickle=self.allow_pickle)
        except ValueError as ve:
            if self.allow_pickle:
                # Pickling is already allowed, so the hint would be misleading;
                # surface numpy's own error unchanged.
                raise
            raise ValueError(
                "Please set the param allow_pickle=True to deserialize pickle objects "
                "in NumpyDeserializer (original error: %s)" % ve
            ) from ve
253
+
254
+
255
class JSONDeserializer(SimpleBaseDeserializer):
    """Parse a JSON response body from an inference endpoint into Python objects."""

    def __init__(self, accept="application/json"):
        """Create a JSON deserializer.

        Args:
            accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
                is expected from the inference endpoint (default: "application/json").
        """
        super().__init__(accept=accept)

    def deserialize(self, stream, content_type):
        """Decode the body as UTF-8 and parse it with ``json.load``.

        The stream is closed afterwards, whether or not parsing succeeds.

        Args:
            stream (botocore.response.StreamingBody): Data to be deserialized.
            content_type (str): The MIME type of the data (not used).

        Returns:
            object: The parsed JSON value (dict, list, str, number, bool or None).
        """
        try:
            utf8_reader = codecs.getreader("utf-8")
            text_stream = utf8_reader(stream)
            return json.load(text_stream)
        finally:
            stream.close()
281
+
282
+
283
class PandasDeserializer(SimpleBaseDeserializer):
    """Deserialize CSV or JSON data from an inference endpoint into a pandas dataframe."""

    def __init__(self, accept=("text/csv", "application/json")):
        """Initialize a ``PandasDeserializer`` instance.

        Args:
            accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
                is expected from the inference endpoint (default: ("text/csv", "application/json")).
        """
        super().__init__(accept=accept)

    def deserialize(self, stream, content_type):
        """Deserialize CSV or JSON data from an inference endpoint into a pandas dataframe.

        If the data is JSON, the data should be formatted in the 'columns' orient.
        See https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_json.html

        The stream is always closed before returning or raising, matching the
        other deserializers in this module.

        Args:
            stream (botocore.response.StreamingBody): Data to be deserialized.
            content_type (str): The MIME type of the data.

        Returns:
            pandas.DataFrame: The data deserialized into a pandas DataFrame.

        Raises:
            ValueError: If ``content_type`` is neither "text/csv" nor "application/json".
        """
        try:
            if content_type == "text/csv":
                return pandas.read_csv(stream)

            if content_type == "application/json":
                return pandas.read_json(stream)
        finally:
            # Previously the stream was never closed here, leaking the
            # underlying HTTP connection until garbage collection.
            stream.close()

        raise ValueError("%s cannot read content type %s." % (__class__.__name__, content_type))
315
+
316
+
317
class JSONLinesDeserializer(SimpleBaseDeserializer):
    """Deserialize JSON lines data from an inference endpoint."""

    def __init__(self, accept="application/jsonlines"):
        """Initialize a ``JSONLinesDeserializer`` instance.

        Args:
            accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
                is expected from the inference endpoint (default: "application/jsonlines").
        """
        super(JSONLinesDeserializer, self).__init__(accept=accept)

    def deserialize(self, stream, content_type):
        """Deserialize JSON lines data from an inference endpoint.

        The payload is decoded as UTF-8 and every non-blank line is parsed as one JSON value.
        Blank lines (including a trailing newline) are skipped, so an empty payload yields
        an empty list.

        See https://docs.python.org/3/library/json.html#py-to-json-table to
        understand how JSON values are converted to Python objects.

        Args:
            stream (botocore.response.StreamingBody): Data to be deserialized.
            content_type (str): The MIME type of the data.

        Returns:
            list: A list of JSON serializable objects.

        Raises:
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        try:
            body = stream.read().decode("utf-8")
        finally:
            stream.close()
        # Split on "\n" only (not str.splitlines): JSON strings may legally contain
        # unescaped U+2028/U+2029/U+0085, which splitlines would treat as line breaks.
        return [json.loads(line) for line in body.split("\n") if line.strip()]
348
+
349
+
350
class TorchTensorDeserializer(SimpleBaseDeserializer):
    """Deserialize a NumPy-encoded stream into a ``torch.Tensor``.

    The payload is decoded by a ``NumpyDeserializer`` as "application/x-npy" data and the
    resulting array is converted with ``torch.from_numpy``.
    """

    def __init__(self, accept="tensor/pt"):
        """Initialize a ``TorchTensorDeserializer`` instance.

        Args:
            accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
                is expected from the inference endpoint (default: "tensor/pt").

        Raises:
            Exception: If PyTorch cannot be imported.
        """
        super(TorchTensorDeserializer, self).__init__(accept=accept)
        self.numpy_deserializer = NumpyDeserializer()
        try:
            from torch import from_numpy
        except ImportError as e:
            # Chain the ImportError so the real reason (missing module, bad install) is kept.
            raise Exception("Unable to import pytorch.") from e
        self.convert_npy_to_tensor = from_numpy

    def deserialize(self, stream, content_type="tensor/pt"):
        """Deserialize streamed data to a ``torch.Tensor``.

        See https://pytorch.org/docs/stable/generated/torch.from_numpy.html

        Args:
            stream (botocore.response.StreamingBody): Data to be deserialized.
            content_type (str): The MIME type of the data. It does not select the decoder:
                the payload is always decoded as "application/x-npy".

        Returns:
            torch.Tensor: The data deserialized into a torch Tensor.

        Raises:
            ValueError: If the data cannot be decoded or converted to a tensor. The original
                error is attached as ``__cause__``.
        """
        try:
            numpy_array = self.numpy_deserializer.deserialize(
                stream=stream, content_type="application/x-npy"
            )
            return self.convert_npy_to_tensor(numpy_array)
        except Exception as e:
            # Implicit concatenation instead of a backslash continuation, which would embed
            # the next line's indentation into the message.
            raise ValueError(
                "Unable to deserialize your data to torch.Tensor. "
                "Please provide custom deserializer in InferenceSpec."
            ) from e
393
+
394
+
395
+ # TODO fix the unit test for this deserializer
396
class RecordDeserializer(SimpleBaseDeserializer):
    """Turn RecordIO-encoded Protobuf responses from an inference endpoint into records."""

    def __init__(self, accept="application/x-recordio-protobuf"):
        """Create a ``RecordDeserializer``.

        Args:
            accept (union[str, tuple[str]]): MIME type, or tuple of MIME types, that the
                inference endpoint is expected to return (default:
                "application/x-recordio-protobuf").
        """
        super().__init__(accept=accept)

    def deserialize(self, data, content_type):
        """Decode every record contained in ``data``.

        ``data`` is closed once decoding finishes, whether or not it succeeds.

        Args:
            data (object): A closable object holding the RecordIO Protobuf payload
                (presumably a ``botocore.response.StreamingBody``; the exact requirements
                are defined by ``read_records`` - TODO confirm).
            content_type (str): The MIME type of the data. Not consulted by this method.

        Returns:
            list: The records produced by ``read_records``.
        """
        try:
            # Deferred import: importing the serializers utilities at module load time
            # would create a circular dependency.
            from sagemaker.core.serializers.utils import read_records

            records = read_records(data)
        finally:
            data.close()
        return records
@@ -0,0 +1,157 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """Implements methods for deserializing data returned from an inference endpoint."""
14
+ from __future__ import absolute_import
15
+
16
+
17
+ from typing import List, Optional
18
+
19
+ # base_deserializers was refactored from deserializers.
20
+ # this import ensures backward compatibility.
21
+ from sagemaker.core.deserializers.base import ( # noqa: F401 # pylint: disable=W0611
22
+ BaseDeserializer,
23
+ BytesDeserializer,
24
+ CSVDeserializer,
25
+ DeferredError,
26
+ JSONDeserializer,
27
+ JSONLinesDeserializer,
28
+ NumpyDeserializer,
29
+ PandasDeserializer,
30
+ SimpleBaseDeserializer,
31
+ StreamDeserializer,
32
+ StringDeserializer,
33
+ TorchTensorDeserializer,
34
+ RecordDeserializer,
35
+ )
36
+
37
+ from sagemaker.core.deprecations import deprecated_class
38
+ from sagemaker.core.jumpstart import artifacts, utils as jumpstart_utils
39
+ from sagemaker.core.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
40
+ from sagemaker.core.jumpstart.enums import JumpStartModelType
41
+ from sagemaker.core.helper.session_helper import Session
42
+
43
+
44
def retrieve_options(
    region: Optional[str] = None,
    model_id: Optional[str] = None,
    model_version: Optional[str] = None,
    hub_arn: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> List[BaseDeserializer]:
    """Look up every deserializer supported by the JumpStart model described by the arguments.

    Args:
        region (str): AWS Region in which to look up the supported deserializers
            (default: None).
        model_id (str): JumpStart model ID whose supported deserializers are wanted
            (default: None).
        model_version (str): Version of that JumpStart model (default: None).
        hub_arn (str): ARN of the SageMaker Hub from which model details are read
            (default: None).
        tolerate_vulnerable_model (bool): When True, model versions whose scripts have
            dependencies with known security vulnerabilities are tolerated instead of raising
            an exception (default: False).
        tolerate_deprecated_model (bool): When True, deprecated models are tolerated instead
            of raising an exception (default: False).
        sagemaker_session (Session): SageMaker session used for SageMaker interactions
            (default: ``DEFAULT_JUMPSTART_SAGEMAKER_SESSION`` from
            ``sagemaker.core.jumpstart.constants``).

    Returns:
        List[BaseDeserializer]: The deserializers the model supports.

    Raises:
        ValueError: If ``model_id``/``model_version`` are not accepted as JumpStart model
            input by ``jumpstart_utils.is_jumpstart_model_input``.
    """
    if jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
        return artifacts._retrieve_deserializer_options(
            region=region,
            model_id=model_id,
            model_version=model_version,
            hub_arn=hub_arn,
            tolerate_vulnerable_model=tolerate_vulnerable_model,
            tolerate_deprecated_model=tolerate_deprecated_model,
            sagemaker_session=sagemaker_session,
        )

    raise ValueError(
        "Must specify JumpStart `model_id` and `model_version` when retrieving deserializers."
    )
96
+
97
+
98
def retrieve_default(
    region: Optional[str] = None,
    model_id: Optional[str] = None,
    model_version: Optional[str] = None,
    hub_arn: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    config_name: Optional[str] = None,
) -> BaseDeserializer:
    """Look up the default deserializer of the JumpStart model described by the arguments.

    Args:
        region (str): AWS Region in which to look up the default deserializer
            (default: None).
        model_id (str): JumpStart model ID whose default deserializer is wanted
            (default: None).
        model_version (str): Version of that JumpStart model (default: None).
        hub_arn (str): ARN of the SageMaker Hub from which model details are read
            (default: None).
        tolerate_vulnerable_model (bool): When True, model versions whose scripts have
            dependencies with known security vulnerabilities are tolerated instead of raising
            an exception (default: False).
        tolerate_deprecated_model (bool): When True, deprecated models are tolerated instead
            of raising an exception (default: False).
        sagemaker_session (Session): SageMaker session used for SageMaker interactions
            (default: ``DEFAULT_JUMPSTART_SAGEMAKER_SESSION`` from
            ``sagemaker.core.jumpstart.constants``).
        model_type (JumpStartModelType): Kind of JumpStart model being looked up
            (default: ``JumpStartModelType.OPEN_WEIGHTS``).
        config_name (Optional[str]): Name of the JumpStart model config to apply
            (default: None).

    Returns:
        BaseDeserializer: The deserializer the model uses by default.

    Raises:
        ValueError: If ``model_id``/``model_version`` are not accepted as JumpStart model
            input by ``jumpstart_utils.is_jumpstart_model_input``.
    """
    if jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
        return artifacts._retrieve_default_deserializer(
            region=region,
            model_id=model_id,
            model_version=model_version,
            hub_arn=hub_arn,
            tolerate_vulnerable_model=tolerate_vulnerable_model,
            tolerate_deprecated_model=tolerate_deprecated_model,
            sagemaker_session=sagemaker_session,
            model_type=model_type,
            config_name=config_name,
        )

    raise ValueError(
        "Must specify JumpStart `model_id` and `model_version` when retrieving deserializers."
    )
155
+
156
+
157
# Legacy lowercase alias of ``RecordDeserializer``, wrapped with ``deprecated_class`` under
# the name "record_deserializer"; prefer ``RecordDeserializer`` in new code.
+ record_deserializer = deprecated_class(RecordDeserializer, "record_deserializer")
@@ -0,0 +1,106 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """This file contains code related to drift check baselines"""
14
+ from __future__ import absolute_import
15
+
16
+ from typing import Optional
17
+
18
+ from sagemaker.core.model_metrics import MetricsSource, FileSource
19
+
20
+
21
class DriftCheckBaselines(object):
    """Accepts drift check baselines parameters for conversion to request dict."""

    def __init__(
        self,
        model_statistics: Optional[MetricsSource] = None,
        model_constraints: Optional[MetricsSource] = None,
        model_data_statistics: Optional[MetricsSource] = None,
        model_data_constraints: Optional[MetricsSource] = None,
        bias_config_file: Optional[FileSource] = None,
        bias_pre_training_constraints: Optional[MetricsSource] = None,
        bias_post_training_constraints: Optional[MetricsSource] = None,
        explainability_constraints: Optional[MetricsSource] = None,
        explainability_config_file: Optional[FileSource] = None,
    ):
        """Initialize a ``DriftCheckBaselines`` instance.

        The sources are stored unchanged; ``_to_request_dict`` converts them into the
        request format.

        Args:
            model_statistics (MetricsSource): A metric source object that represents
                model statistics (default: None).
            model_constraints (MetricsSource): A metric source object that represents
                model constraints (default: None).
            model_data_statistics (MetricsSource): A metric source object that represents
                model data statistics (default: None).
            model_data_constraints (MetricsSource): A metric source object that represents
                model data constraints (default: None).
            bias_config_file (FileSource): A file source object that represents bias config
                (default: None).
            bias_pre_training_constraints (MetricsSource):
                A metric source object that represents Pre-training constraints (default: None).
            bias_post_training_constraints (MetricsSource):
                A metric source object that represents Post-training constraints (default: None).
            explainability_constraints (MetricsSource):
                A metric source object that represents explainability constraints (default: None).
            explainability_config_file (FileSource): A file source object that represents
                explainability config (default: None).
        """
        self.model_statistics = model_statistics
        self.model_constraints = model_constraints
        self.model_data_statistics = model_data_statistics
        self.model_data_constraints = model_data_constraints
        self.bias_config_file = bias_config_file
        self.bias_pre_training_constraints = bias_pre_training_constraints
        self.bias_post_training_constraints = bias_post_training_constraints
        self.explainability_constraints = explainability_constraints
        self.explainability_config_file = explainability_config_file

    @staticmethod
    def _build_section(entries):
        """Map each key to ``source._to_request_dict()`` for every ``(key, source)`` that is set.

        Args:
            entries (iterable of (str, object)): Request keys paired with their source objects;
                sources that are None are skipped. Input order is preserved.

        Returns:
            dict: The section dictionary, empty when every source is None.
        """
        return {key: source._to_request_dict() for key, source in entries if source is not None}

    def _to_request_dict(self):
        """Generates a request dictionary using the parameters provided to the class.

        Each top-level section ("ModelQuality", "ModelDataQuality", "Bias", "Explainability")
        is included only when at least one of its sources is set.

        Returns:
            dict: The drift check baselines request dictionary.
        """
        # (section name, ((request key, source), ...)) in the order the request is built.
        section_layout = (
            (
                "ModelQuality",
                (
                    ("Statistics", self.model_statistics),
                    ("Constraints", self.model_constraints),
                ),
            ),
            (
                "ModelDataQuality",
                (
                    ("Statistics", self.model_data_statistics),
                    ("Constraints", self.model_data_constraints),
                ),
            ),
            (
                "Bias",
                (
                    ("ConfigFile", self.bias_config_file),
                    ("PreTrainingConstraints", self.bias_pre_training_constraints),
                    ("PostTrainingConstraints", self.bias_post_training_constraints),
                ),
            ),
            (
                "Explainability",
                (
                    ("Constraints", self.explainability_constraints),
                    ("ConfigFile", self.explainability_config_file),
                ),
            ),
        )

        drift_check_baselines_request = {}
        for section_name, entries in section_layout:
            section = self._build_section(entries)
            if section:
                drift_check_baselines_request[section_name] = section
        return drift_check_baselines_request