mlrun 1.7.0rc14__py3-none-any.whl → 1.7.0rc22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mlrun might be problematic. See the package's registry page for more details.

Files changed (160)
  1. mlrun/__init__.py +10 -1
  2. mlrun/__main__.py +23 -111
  3. mlrun/alerts/__init__.py +15 -0
  4. mlrun/alerts/alert.py +169 -0
  5. mlrun/api/schemas/__init__.py +4 -3
  6. mlrun/artifacts/__init__.py +8 -3
  7. mlrun/artifacts/base.py +36 -253
  8. mlrun/artifacts/dataset.py +9 -190
  9. mlrun/artifacts/manager.py +46 -42
  10. mlrun/artifacts/model.py +9 -141
  11. mlrun/artifacts/plots.py +14 -375
  12. mlrun/common/constants.py +65 -3
  13. mlrun/common/formatters/__init__.py +19 -0
  14. mlrun/{runtimes/mpijob/v1alpha1.py → common/formatters/artifact.py} +6 -14
  15. mlrun/common/formatters/base.py +113 -0
  16. mlrun/common/formatters/function.py +46 -0
  17. mlrun/common/formatters/pipeline.py +53 -0
  18. mlrun/common/formatters/project.py +51 -0
  19. mlrun/{runtimes → common/runtimes}/constants.py +32 -4
  20. mlrun/common/schemas/__init__.py +10 -5
  21. mlrun/common/schemas/alert.py +92 -11
  22. mlrun/common/schemas/api_gateway.py +56 -0
  23. mlrun/common/schemas/artifact.py +15 -5
  24. mlrun/common/schemas/auth.py +2 -0
  25. mlrun/common/schemas/client_spec.py +1 -0
  26. mlrun/common/schemas/frontend_spec.py +1 -0
  27. mlrun/common/schemas/function.py +4 -0
  28. mlrun/common/schemas/model_monitoring/__init__.py +15 -3
  29. mlrun/common/schemas/model_monitoring/constants.py +58 -7
  30. mlrun/common/schemas/model_monitoring/grafana.py +9 -5
  31. mlrun/common/schemas/model_monitoring/model_endpoints.py +86 -2
  32. mlrun/common/schemas/pipeline.py +0 -9
  33. mlrun/common/schemas/project.py +5 -11
  34. mlrun/common/types.py +1 -0
  35. mlrun/config.py +30 -9
  36. mlrun/data_types/to_pandas.py +9 -9
  37. mlrun/datastore/base.py +41 -9
  38. mlrun/datastore/datastore.py +6 -2
  39. mlrun/datastore/datastore_profile.py +56 -4
  40. mlrun/datastore/inmem.py +2 -2
  41. mlrun/datastore/redis.py +2 -2
  42. mlrun/datastore/s3.py +5 -0
  43. mlrun/datastore/sources.py +147 -7
  44. mlrun/datastore/store_resources.py +7 -7
  45. mlrun/datastore/targets.py +110 -42
  46. mlrun/datastore/utils.py +42 -0
  47. mlrun/db/base.py +54 -10
  48. mlrun/db/httpdb.py +282 -79
  49. mlrun/db/nopdb.py +52 -10
  50. mlrun/errors.py +11 -0
  51. mlrun/execution.py +26 -9
  52. mlrun/feature_store/__init__.py +0 -2
  53. mlrun/feature_store/api.py +12 -47
  54. mlrun/feature_store/feature_set.py +9 -0
  55. mlrun/feature_store/feature_vector.py +8 -0
  56. mlrun/feature_store/ingestion.py +7 -6
  57. mlrun/feature_store/retrieval/base.py +9 -4
  58. mlrun/feature_store/retrieval/conversion.py +9 -9
  59. mlrun/feature_store/retrieval/dask_merger.py +2 -0
  60. mlrun/feature_store/retrieval/job.py +9 -3
  61. mlrun/feature_store/retrieval/local_merger.py +2 -0
  62. mlrun/feature_store/retrieval/spark_merger.py +16 -0
  63. mlrun/frameworks/__init__.py +6 -0
  64. mlrun/frameworks/_dl_common/loggers/tensorboard_logger.py +7 -12
  65. mlrun/frameworks/parallel_coordinates.py +2 -1
  66. mlrun/frameworks/tf_keras/__init__.py +4 -1
  67. mlrun/k8s_utils.py +10 -11
  68. mlrun/launcher/base.py +4 -3
  69. mlrun/launcher/client.py +5 -3
  70. mlrun/launcher/local.py +12 -2
  71. mlrun/launcher/remote.py +9 -2
  72. mlrun/lists.py +6 -2
  73. mlrun/model.py +47 -21
  74. mlrun/model_monitoring/__init__.py +1 -1
  75. mlrun/model_monitoring/api.py +42 -18
  76. mlrun/model_monitoring/application.py +5 -305
  77. mlrun/model_monitoring/applications/__init__.py +11 -0
  78. mlrun/model_monitoring/applications/_application_steps.py +157 -0
  79. mlrun/model_monitoring/applications/base.py +280 -0
  80. mlrun/model_monitoring/applications/context.py +214 -0
  81. mlrun/model_monitoring/applications/evidently_base.py +211 -0
  82. mlrun/model_monitoring/applications/histogram_data_drift.py +132 -91
  83. mlrun/model_monitoring/applications/results.py +99 -0
  84. mlrun/model_monitoring/controller.py +3 -1
  85. mlrun/model_monitoring/db/__init__.py +2 -0
  86. mlrun/model_monitoring/db/stores/__init__.py +0 -2
  87. mlrun/model_monitoring/db/stores/base/store.py +22 -37
  88. mlrun/model_monitoring/db/stores/sqldb/models/__init__.py +43 -21
  89. mlrun/model_monitoring/db/stores/sqldb/models/base.py +39 -8
  90. mlrun/model_monitoring/db/stores/sqldb/models/mysql.py +27 -7
  91. mlrun/model_monitoring/db/stores/sqldb/models/sqlite.py +5 -0
  92. mlrun/model_monitoring/db/stores/sqldb/sql_store.py +246 -224
  93. mlrun/model_monitoring/db/stores/v3io_kv/kv_store.py +232 -216
  94. mlrun/model_monitoring/db/tsdb/__init__.py +100 -0
  95. mlrun/model_monitoring/db/tsdb/base.py +316 -0
  96. mlrun/model_monitoring/db/tsdb/helpers.py +30 -0
  97. mlrun/model_monitoring/db/tsdb/tdengine/__init__.py +15 -0
  98. mlrun/model_monitoring/db/tsdb/tdengine/schemas.py +240 -0
  99. mlrun/model_monitoring/db/tsdb/tdengine/stream_graph_steps.py +45 -0
  100. mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +401 -0
  101. mlrun/model_monitoring/db/tsdb/v3io/__init__.py +15 -0
  102. mlrun/model_monitoring/db/tsdb/v3io/stream_graph_steps.py +117 -0
  103. mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +658 -0
  104. mlrun/model_monitoring/evidently_application.py +6 -118
  105. mlrun/model_monitoring/helpers.py +63 -1
  106. mlrun/model_monitoring/model_endpoint.py +3 -2
  107. mlrun/model_monitoring/stream_processing.py +57 -216
  108. mlrun/model_monitoring/writer.py +134 -124
  109. mlrun/package/__init__.py +13 -1
  110. mlrun/package/packagers/__init__.py +6 -1
  111. mlrun/package/utils/_formatter.py +2 -2
  112. mlrun/platforms/__init__.py +10 -9
  113. mlrun/platforms/iguazio.py +21 -202
  114. mlrun/projects/operations.py +24 -12
  115. mlrun/projects/pipelines.py +79 -102
  116. mlrun/projects/project.py +271 -103
  117. mlrun/render.py +15 -14
  118. mlrun/run.py +16 -46
  119. mlrun/runtimes/__init__.py +6 -3
  120. mlrun/runtimes/base.py +14 -7
  121. mlrun/runtimes/daskjob.py +1 -0
  122. mlrun/runtimes/databricks_job/databricks_runtime.py +1 -0
  123. mlrun/runtimes/databricks_job/databricks_wrapper.py +1 -1
  124. mlrun/runtimes/funcdoc.py +0 -28
  125. mlrun/runtimes/kubejob.py +2 -1
  126. mlrun/runtimes/local.py +12 -3
  127. mlrun/runtimes/mpijob/__init__.py +0 -20
  128. mlrun/runtimes/mpijob/v1.py +1 -1
  129. mlrun/runtimes/nuclio/api_gateway.py +194 -84
  130. mlrun/runtimes/nuclio/application/application.py +170 -8
  131. mlrun/runtimes/nuclio/function.py +39 -49
  132. mlrun/runtimes/pod.py +16 -36
  133. mlrun/runtimes/remotesparkjob.py +9 -3
  134. mlrun/runtimes/sparkjob/spark3job.py +1 -1
  135. mlrun/runtimes/utils.py +6 -45
  136. mlrun/serving/__init__.py +8 -1
  137. mlrun/serving/server.py +2 -1
  138. mlrun/serving/states.py +51 -8
  139. mlrun/serving/utils.py +19 -11
  140. mlrun/serving/v2_serving.py +5 -1
  141. mlrun/track/tracker.py +2 -1
  142. mlrun/utils/async_http.py +25 -5
  143. mlrun/utils/helpers.py +157 -83
  144. mlrun/utils/logger.py +39 -7
  145. mlrun/utils/notifications/notification/__init__.py +14 -9
  146. mlrun/utils/notifications/notification/base.py +1 -1
  147. mlrun/utils/notifications/notification/slack.py +34 -7
  148. mlrun/utils/notifications/notification/webhook.py +1 -1
  149. mlrun/utils/notifications/notification_pusher.py +147 -16
  150. mlrun/utils/regex.py +9 -0
  151. mlrun/utils/v3io_clients.py +0 -1
  152. mlrun/utils/version/version.json +2 -2
  153. {mlrun-1.7.0rc14.dist-info → mlrun-1.7.0rc22.dist-info}/METADATA +14 -6
  154. {mlrun-1.7.0rc14.dist-info → mlrun-1.7.0rc22.dist-info}/RECORD +158 -138
  155. mlrun/kfpops.py +0 -865
  156. mlrun/platforms/other.py +0 -305
  157. {mlrun-1.7.0rc14.dist-info → mlrun-1.7.0rc22.dist-info}/LICENSE +0 -0
  158. {mlrun-1.7.0rc14.dist-info → mlrun-1.7.0rc22.dist-info}/WHEEL +0 -0
  159. {mlrun-1.7.0rc14.dist-info → mlrun-1.7.0rc22.dist-info}/entry_points.txt +0 -0
  160. {mlrun-1.7.0rc14.dist-info → mlrun-1.7.0rc22.dist-info}/top_level.txt +0 -0
mlrun/runtimes/utils.py CHANGED
@@ -20,17 +20,17 @@ from io import StringIO
20
20
  from sys import stderr
21
21
 
22
22
  import pandas as pd
23
- from kubernetes import client
24
23
 
25
24
  import mlrun
26
25
  import mlrun.common.constants
26
+ import mlrun.common.constants as mlrun_constants
27
27
  import mlrun.common.schemas
28
28
  import mlrun.utils.regex
29
29
  from mlrun.artifacts import TableArtifact
30
+ from mlrun.common.runtimes.constants import RunLabels
30
31
  from mlrun.config import config
31
32
  from mlrun.errors import err_to_str
32
33
  from mlrun.frameworks.parallel_coordinates import gen_pcp_plot
33
- from mlrun.runtimes.constants import RunLabels
34
34
  from mlrun.runtimes.generators import selector
35
35
  from mlrun.utils import get_in, helpers, logger, verify_field_regex
36
36
 
@@ -39,9 +39,6 @@ class RunError(Exception):
39
39
  pass
40
40
 
41
41
 
42
- mlrun_key = "mlrun/"
43
-
44
-
45
42
  class _ContextStore:
46
43
  def __init__(self):
47
44
  self._context = None
@@ -280,43 +277,6 @@ def get_item_name(item, attr="name"):
280
277
  return getattr(item, attr, None)
281
278
 
282
279
 
283
- def apply_kfp(modify, cop, runtime):
284
- modify(cop)
285
-
286
- # Have to do it here to avoid circular dependencies
287
- from .pod import AutoMountType
288
-
289
- if AutoMountType.is_auto_modifier(modify):
290
- runtime.spec.disable_auto_mount = True
291
-
292
- api = client.ApiClient()
293
- for k, v in cop.pod_labels.items():
294
- runtime.metadata.labels[k] = v
295
- for k, v in cop.pod_annotations.items():
296
- runtime.metadata.annotations[k] = v
297
- if cop.container.env:
298
- env_names = [
299
- e.name if hasattr(e, "name") else e["name"] for e in runtime.spec.env
300
- ]
301
- for e in api.sanitize_for_serialization(cop.container.env):
302
- name = e["name"]
303
- if name in env_names:
304
- runtime.spec.env[env_names.index(name)] = e
305
- else:
306
- runtime.spec.env.append(e)
307
- env_names.append(name)
308
- cop.container.env.clear()
309
-
310
- if cop.volumes and cop.container.volume_mounts:
311
- vols = api.sanitize_for_serialization(cop.volumes)
312
- mounts = api.sanitize_for_serialization(cop.container.volume_mounts)
313
- runtime.spec.update_vols_and_mounts(vols, mounts)
314
- cop.volumes.clear()
315
- cop.container.volume_mounts.clear()
316
-
317
- return runtime
318
-
319
-
320
280
  def verify_limits(
321
281
  resources_field_name,
322
282
  mem=None,
@@ -410,10 +370,10 @@ def generate_resources(mem=None, cpu=None, gpus=None, gpu_type="nvidia.com/gpu")
410
370
 
411
371
 
412
372
  def get_func_selector(project, name=None, tag=None):
413
- s = [f"{mlrun_key}project={project}"]
373
+ s = [f"{mlrun_constants.MLRunInternalLabels.project}={project}"]
414
374
  if name:
415
- s.append(f"{mlrun_key}function={name}")
416
- s.append(f"{mlrun_key}tag={tag or 'latest'}")
375
+ s.append(f"{mlrun_constants.MLRunInternalLabels.function}={name}")
376
+ s.append(f"{mlrun_constants.MLRunInternalLabels.tag}={tag or 'latest'}")
417
377
  return s
418
378
 
419
379
 
@@ -476,6 +436,7 @@ def enrich_run_labels(
476
436
  ):
477
437
  labels_enrichment = {
478
438
  RunLabels.owner: os.environ.get("V3IO_USERNAME") or getpass.getuser(),
439
+ # TODO: remove this in 1.9.0
479
440
  RunLabels.v3io_user: os.environ.get("V3IO_USERNAME"),
480
441
  }
481
442
  labels_to_enrich = labels_to_enrich or RunLabels.all()
mlrun/serving/__init__.py CHANGED
@@ -22,10 +22,17 @@ __all__ = [
22
22
  "RouterStep",
23
23
  "QueueStep",
24
24
  "ErrorStep",
25
+ "MonitoringApplicationStep",
25
26
  ]
26
27
 
27
28
  from .routers import ModelRouter, VotingEnsemble # noqa
28
29
  from .server import GraphContext, GraphServer, create_graph_server # noqa
29
- from .states import ErrorStep, QueueStep, RouterStep, TaskStep # noqa
30
+ from .states import (
31
+ ErrorStep,
32
+ QueueStep,
33
+ RouterStep,
34
+ TaskStep,
35
+ MonitoringApplicationStep,
36
+ ) # noqa
30
37
  from .v1_serving import MLModelServer, new_v1_model_server # noqa
31
38
  from .v2_serving import V2ModelServer # noqa
mlrun/serving/server.py CHANGED
@@ -387,7 +387,7 @@ def v2_serving_handler(context, event, get_body=False):
387
387
 
388
388
 
389
389
  def create_graph_server(
390
- parameters={},
390
+ parameters=None,
391
391
  load_mode=None,
392
392
  graph=None,
393
393
  verbose=False,
@@ -403,6 +403,7 @@ def create_graph_server(
403
403
  server.graph.add_route("my", class_name=MyModelClass, model_path="{path}", z=100)
404
404
  print(server.test("/v2/models/my/infer", testdata))
405
405
  """
406
+ parameters = parameters or {}
406
407
  server = GraphServer(graph, parameters, load_mode, verbose=verbose, **kwargs)
407
408
  server.set_current_function(
408
409
  current_function or os.environ.get("SERVING_CURRENT_FUNCTION", "")
mlrun/serving/states.py CHANGED
@@ -12,7 +12,13 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- __all__ = ["TaskStep", "RouterStep", "RootFlowStep", "ErrorStep"]
15
+ __all__ = [
16
+ "TaskStep",
17
+ "RouterStep",
18
+ "RootFlowStep",
19
+ "ErrorStep",
20
+ "MonitoringApplicationStep",
21
+ ]
16
22
 
17
23
  import os
18
24
  import pathlib
@@ -55,6 +61,7 @@ class StepKinds:
55
61
  choice = "choice"
56
62
  root = "root"
57
63
  error_step = "error_step"
64
+ monitoring_application = "monitoring_application"
58
65
 
59
66
 
60
67
  _task_step_fields = [
@@ -485,13 +492,15 @@ class TaskStep(BaseStep):
485
492
  class_args[key] = arg
486
493
  class_args.update(extra_kwargs)
487
494
 
488
- # add common args (name, context, ..) only if target class can accept them
489
- argspec = getfullargspec(class_object)
490
- for key in ["name", "context", "input_path", "result_path", "full_event"]:
491
- if argspec.varkw or key in argspec.args:
492
- class_args[key] = getattr(self, key)
493
- if argspec.varkw or "graph_step" in argspec.args:
494
- class_args["graph_step"] = self
495
+ if not isinstance(self, MonitoringApplicationStep):
496
+ # add common args (name, context, ..) only if target class can accept them
497
+ argspec = getfullargspec(class_object)
498
+
499
+ for key in ["name", "context", "input_path", "result_path", "full_event"]:
500
+ if argspec.varkw or key in argspec.args:
501
+ class_args[key] = getattr(self, key)
502
+ if argspec.varkw or "graph_step" in argspec.args:
503
+ class_args["graph_step"] = self
495
504
  return class_args
496
505
 
497
506
  def get_step_class_object(self, namespace):
@@ -582,6 +591,39 @@ class TaskStep(BaseStep):
582
591
  return event
583
592
 
584
593
 
594
+ class MonitoringApplicationStep(TaskStep):
595
+ """monitoring application execution step, runs users class code"""
596
+
597
+ kind = "monitoring_application"
598
+ _default_class = ""
599
+
600
+ def __init__(
601
+ self,
602
+ class_name: Union[str, type] = None,
603
+ class_args: dict = None,
604
+ handler: str = None,
605
+ name: str = None,
606
+ after: list = None,
607
+ full_event: bool = None,
608
+ function: str = None,
609
+ responder: bool = None,
610
+ input_path: str = None,
611
+ result_path: str = None,
612
+ ):
613
+ super().__init__(
614
+ class_name=class_name,
615
+ class_args=class_args,
616
+ handler=handler,
617
+ name=name,
618
+ after=after,
619
+ full_event=full_event,
620
+ function=function,
621
+ responder=responder,
622
+ input_path=input_path,
623
+ result_path=result_path,
624
+ )
625
+
626
+
585
627
  class ErrorStep(TaskStep):
586
628
  """error execution step, runs a class or handler"""
587
629
 
@@ -1323,6 +1365,7 @@ classes_map = {
1323
1365
  "flow": FlowStep,
1324
1366
  "queue": QueueStep,
1325
1367
  "error_step": ErrorStep,
1368
+ "monitoring_application": MonitoringApplicationStep,
1326
1369
  }
1327
1370
 
1328
1371
 
mlrun/serving/utils.py CHANGED
@@ -46,6 +46,15 @@ def _update_result_body(result_path, event_body, result):
46
46
  class StepToDict:
47
47
  """auto serialization of graph steps to a python dictionary"""
48
48
 
49
+ meta_keys = [
50
+ "context",
51
+ "name",
52
+ "input_path",
53
+ "result_path",
54
+ "full_event",
55
+ "kwargs",
56
+ ]
57
+
49
58
  def to_dict(self, fields: list = None, exclude: list = None, strip: bool = False):
50
59
  """convert the step object to a python dictionary"""
51
60
  fields = fields or getattr(self, "_dict_fields", None)
@@ -54,24 +63,16 @@ class StepToDict:
54
63
  if exclude:
55
64
  fields = [field for field in fields if field not in exclude]
56
65
 
57
- meta_keys = [
58
- "context",
59
- "name",
60
- "input_path",
61
- "result_path",
62
- "full_event",
63
- "kwargs",
64
- ]
65
66
  args = {
66
67
  key: getattr(self, key)
67
68
  for key in fields
68
- if getattr(self, key, None) is not None and key not in meta_keys
69
+ if getattr(self, key, None) is not None and key not in self.meta_keys
69
70
  }
70
71
  # add storey kwargs or extra kwargs
71
72
  if "kwargs" in fields and (hasattr(self, "kwargs") or hasattr(self, "_kwargs")):
72
73
  kwargs = getattr(self, "kwargs", {}) or getattr(self, "_kwargs", {})
73
74
  for key, value in kwargs.items():
74
- if key not in meta_keys:
75
+ if key not in self.meta_keys:
75
76
  args[key] = value
76
77
 
77
78
  mod_name = self.__class__.__module__
@@ -80,7 +81,9 @@ class StepToDict:
80
81
  class_path = f"{mod_name}.{class_path}"
81
82
  struct = {
82
83
  "class_name": class_path,
83
- "name": self.name or self.__class__.__name__,
84
+ "name": self.name
85
+ if hasattr(self, "name") and self.name
86
+ else self.__class__.__name__,
84
87
  "class_args": args,
85
88
  }
86
89
  if hasattr(self, "_STEP_KIND"):
@@ -94,6 +97,11 @@ class StepToDict:
94
97
  return struct
95
98
 
96
99
 
100
+ class MonitoringApplicationToDict(StepToDict):
101
+ _STEP_KIND = "monitoring_application"
102
+ meta_keys = []
103
+
104
+
97
105
  class RouterToDict(StepToDict):
98
106
  _STEP_KIND = "router"
99
107
 
@@ -528,7 +528,11 @@ def _init_endpoint_record(
528
528
  return None
529
529
 
530
530
  # Generating version model value based on the model name and model version
531
- if model.version:
531
+ if model.model_path and model.model_path.startswith("store://"):
532
+ # Enrich the model server with the model artifact metadata
533
+ model.get_model()
534
+ model.version = model.model_spec.tag
535
+ model.labels = model.model_spec.labels
532
536
  versioned_model_name = f"{model.name}:{model.version}"
533
537
  else:
534
538
  versioned_model_name = f"{model.name}:latest"
mlrun/track/tracker.py CHANGED
@@ -31,8 +31,9 @@ class Tracker(ABC):
31
31
  * Offline: Manually importing models and artifacts into an MLRun project using the `import_x` methods.
32
32
  """
33
33
 
34
+ @staticmethod
34
35
  @abstractmethod
35
- def is_enabled(self) -> bool:
36
+ def is_enabled() -> bool:
36
37
  """
37
38
  Checks if tracker is enabled.
38
39
 
mlrun/utils/async_http.py CHANGED
@@ -24,7 +24,7 @@ from aiohttp_retry import ExponentialRetry, RequestParams, RetryClient, RetryOpt
24
24
  from aiohttp_retry.client import _RequestContext
25
25
 
26
26
  from mlrun.config import config
27
- from mlrun.errors import err_to_str
27
+ from mlrun.errors import err_to_str, raise_for_status
28
28
 
29
29
  from .helpers import logger as mlrun_logger
30
30
 
@@ -46,12 +46,21 @@ class AsyncClientWithRetry(RetryClient):
46
46
  *args,
47
47
  **kwargs,
48
48
  ):
49
+ # do not retry on PUT / PATCH as they might have side effects (not truly idempotent)
50
+ blacklisted_methods = (
51
+ blacklisted_methods
52
+ if blacklisted_methods is not None
53
+ else [
54
+ "POST",
55
+ "PUT",
56
+ "PATCH",
57
+ ]
58
+ )
49
59
  super().__init__(
50
60
  *args,
51
61
  retry_options=ExponentialRetryOverride(
52
62
  retry_on_exception=retry_on_exception,
53
- # do not retry on PUT / PATCH as they might have side effects (not truly idempotent)
54
- blacklisted_methods=blacklisted_methods or ["POST", "PUT", "PATCH"],
63
+ blacklisted_methods=blacklisted_methods,
55
64
  attempts=max_retries,
56
65
  statuses=retry_on_status_codes,
57
66
  factor=retry_backoff_factor,
@@ -63,6 +72,12 @@ class AsyncClientWithRetry(RetryClient):
63
72
  **kwargs,
64
73
  )
65
74
 
75
+ def methods_blacklist_update_required(self, new_blacklist: str):
76
+ self._retry_options: ExponentialRetryOverride
77
+ return set(self._retry_options.blacklisted_methods).difference(
78
+ set(new_blacklist)
79
+ )
80
+
66
81
  def _make_requests(
67
82
  self,
68
83
  params_list: list[RequestParams],
@@ -173,7 +188,7 @@ class _CustomRequestContext(_RequestContext):
173
188
  last_attempt = current_attempt == self._retry_options.attempts
174
189
  if self._is_status_code_ok(response.status) or last_attempt:
175
190
  if self._raise_for_status:
176
- response.raise_for_status()
191
+ raise_for_status(response)
177
192
 
178
193
  self._response = response
179
194
  return response
@@ -275,6 +290,11 @@ class _CustomRequestContext(_RequestContext):
275
290
  if isinstance(exc.os_error, exc_type):
276
291
  return
277
292
  if exc.__cause__:
278
- return self.verify_exception_type(exc.__cause__)
293
+ # If the cause exception is retriable, return, otherwise, raise the original exception
294
+ try:
295
+ self.verify_exception_type(exc.__cause__)
296
+ except Exception:
297
+ raise exc
298
+ return
279
299
  else:
280
300
  raise exc