corva-worker-python 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. corva_worker_python-2.0.0.dist-info/METADATA +30 -0
  2. corva_worker_python-2.0.0.dist-info/RECORD +63 -0
  3. corva_worker_python-2.0.0.dist-info/WHEEL +5 -0
  4. corva_worker_python-2.0.0.dist-info/top_level.txt +1 -0
  5. worker/__init__.py +5 -0
  6. worker/app/__init__.py +291 -0
  7. worker/app/modules/__init__.py +265 -0
  8. worker/app/modules/activity_module.py +141 -0
  9. worker/app/modules/connection_module.py +21 -0
  10. worker/app/modules/depth_activity_module.py +21 -0
  11. worker/app/modules/scheduler.py +44 -0
  12. worker/app/modules/time_activity_module.py +21 -0
  13. worker/app/modules/trigger.py +43 -0
  14. worker/constants.py +51 -0
  15. worker/data/__init__.py +0 -0
  16. worker/data/activity/__init__.py +132 -0
  17. worker/data/activity/activity_grouping.py +242 -0
  18. worker/data/alert.py +89 -0
  19. worker/data/api.py +155 -0
  20. worker/data/enums.py +141 -0
  21. worker/data/json_encoder.py +18 -0
  22. worker/data/math.py +104 -0
  23. worker/data/operations.py +477 -0
  24. worker/data/serialization.py +110 -0
  25. worker/data/task_handler.py +82 -0
  26. worker/data/two_way_dict.py +17 -0
  27. worker/data/unit_conversions.py +5 -0
  28. worker/data/wits.py +323 -0
  29. worker/event/__init__.py +53 -0
  30. worker/event/event_handler.py +90 -0
  31. worker/event/scheduled.py +64 -0
  32. worker/event/stream.py +48 -0
  33. worker/exceptions.py +26 -0
  34. worker/mixins/__init__.py +0 -0
  35. worker/mixins/logging.py +119 -0
  36. worker/mixins/rollbar.py +87 -0
  37. worker/partial_rerun_merge/__init__.py +0 -0
  38. worker/partial_rerun_merge/merge.py +500 -0
  39. worker/partial_rerun_merge/models.py +91 -0
  40. worker/partial_rerun_merge/progress.py +241 -0
  41. worker/state/__init__.py +96 -0
  42. worker/state/mixins.py +111 -0
  43. worker/state/state.py +46 -0
  44. worker/test/__init__.py +3 -0
  45. worker/test/lambda_function_test_run.py +196 -0
  46. worker/test/local_testing/__init__.py +0 -0
  47. worker/test/local_testing/to_local_transfer.py +360 -0
  48. worker/test/utils.py +51 -0
  49. worker/wellbore/__init__.py +0 -0
  50. worker/wellbore/factory.py +496 -0
  51. worker/wellbore/measured_depth_finder.py +12 -0
  52. worker/wellbore/model/__init__.py +0 -0
  53. worker/wellbore/model/ann.py +103 -0
  54. worker/wellbore/model/annulus.py +113 -0
  55. worker/wellbore/model/drillstring.py +196 -0
  56. worker/wellbore/model/drillstring_components.py +439 -0
  57. worker/wellbore/model/element.py +102 -0
  58. worker/wellbore/model/enums.py +92 -0
  59. worker/wellbore/model/hole.py +297 -0
  60. worker/wellbore/model/hole_section.py +51 -0
  61. worker/wellbore/model/riser.py +22 -0
  62. worker/wellbore/sections_mixin.py +64 -0
  63. worker/wellbore/wellbore.py +289 -0
worker/data/math.py ADDED
@@ -0,0 +1,104 @@
1
+ import numpy as np
2
+
3
+ from worker.data import operations
4
+
5
+
6
def percentile(ls, percent):
    """
    Calculate a percentile while ignoring np.nan and None values.

    None entries of a list are converted to np.nan (via operations.none_to_nan)
    and are then skipped by np.nanpercentile.

    :param ls: a list (or numpy array) of numbers; may contain None / np.nan
    :param percent: the percentile to compute, in the [0, 100] range
    :return: the percentile value, or None when all values are np.nan/None
             (nothing left to rank) or when the input holds non-numeric entries
    """
    try:
        p = np.nanpercentile(operations.none_to_nan(ls), percent)
    except TypeError:
        # non-numeric entries (e.g. strings) cannot be ranked
        return None

    # nanpercentile yields nan when every value was nan/None
    if np.isnan(p):
        return None

    return p
21
+
22
+
23
def mean_angles(ls):
    """
    Compute the circular (vector) mean of a list of angles.

    Each angle is projected onto the unit circle, the x/y components are
    averaged and the mean direction is converted back to degrees. This handles
    the wrap-around at 0/360, e.g. the mean of 350 and 10 is ~0, not 180.

    None and np.nan entries are ignored.

    :param ls: a list of angles in degrees
    :return: the mean angle in degrees, normalized with ``% 360``; nan when
             there is no valid angle
    """
    # dtype=float turns None into nan so that nanmean skips it
    radians = np.deg2rad(np.asarray(ls, dtype=float))
    x_mean = np.nanmean(np.cos(radians))
    y_mean = np.nanmean(np.sin(radians))
    mean_deg = np.rad2deg(np.arctan2(y_mean, x_mean))

    return mean_deg % 360
34
+
35
+
36
def angle_difference(ang1, ang2):
    """
    Signed smallest difference going from bearing ang1 to bearing ang2.

    Adapted from:
    https://rosettacode.org/wiki/Angle_difference_between_two_bearings#Python

    :param ang1: first angle in degrees (number or numeric string)
    :param ang2: second angle in degrees (number or numeric string)
    :return: ang2 - ang1 wrapped into the [-180, 180) range as a float, or None
             when either input is not a number
    """
    if not operations.is_number(ang1) or not operations.is_number(ang2):
        return None

    # is_number() accepts numeric strings, so cast before doing arithmetic
    r = (float(ang2) - float(ang1)) % 360.0
    if r >= 180.0:
        r -= 360.0
    return r
50
+
51
+
52
def abs_angle_difference(ang1, ang2):
    """
    Unsigned angular distance between two bearings.

    :param ang1: first angle in degrees
    :param ang2: second angle in degrees
    :return: the absolute value of angle_difference(ang1, ang2), or None when
             that difference is not a valid number
    """
    signed_diff = angle_difference(ang1, ang2)
    if not operations.is_number(signed_diff):
        return None
    return abs(signed_diff)
64
+
65
+
66
def split_zip_edges(arr, separation_length=1, min_segment_length=1):
    """
    Group ascending, discrete values into contiguous segments.

    Consecutive values stay in the same segment while the step between them is
    at most ``separation_length``; a bigger jump starts a new segment.
    Example: [1, 2, 3, 7, 8, 20] -> [(1, 3), (7, 8), (20, 20)]

    :param arr: ascending list/tuple/array of (typically integer) values
    :param separation_length: largest step that still counts as contiguous
    :param min_segment_length: minimum length of a segment to keep, measured
        inclusively as ``stop + 1 - start``
    :return: a list of (start, stop) tuples, one per kept segment; an empty
             list for empty input
    """
    arr = np.asarray(arr)
    if arr.size == 0:
        # without this guard the value lookups below raise IndexError
        return []

    # True marks a segment boundary; the sentinels at both ends close the
    # first and the last segment
    breaks = np.concatenate(([True], arr[1:] > arr[:-1] + separation_length, [True]))
    idx = np.flatnonzero(breaks)
    values = arr.tolist()

    return [
        (values[i], values[j - 1])
        for i, j in zip(idx[:-1], idx[1:])
        if values[j - 1] + 1 - values[i] >= min_segment_length
    ]
81
+
82
+
83
def start_stop(arr, trigger_val, min_len_thresh=1):
    """
    Find the runs of consecutive elements that are equal to ``trigger_val``.

    Example: start_stop([0, 1, 1, 0, 1, 1, 1], 1) -> [(1, 2), (4, 6)]

    :param arr: a continuous stream of data (list or 1-D array)
    :param trigger_val: the value whose runs are searched for
    :param min_len_thresh: minimum number of elements a run must contain to be
        reported
    :return: a list of (start_index, stop_index) tuples; both indices are
             inclusive Python ints
    """
    # Enclose the mask with False sentinels so that runs touching either end
    # still produce both a rising and a falling edge.
    mask = np.r_[False, np.equal(arr, trigger_val), False]

    # Indices where the mask flips: even positions are run starts, odd
    # positions are the (exclusive) run ends.
    idx = np.flatnonzero(mask[1:] != mask[:-1])

    lens = idx[1::2] - idx[::2]

    # subtract 1 from each exclusive end to make it inclusive
    runs = idx.reshape(-1, 2)[lens >= min_len_thresh] - [0, 1]

    # plain ints keep the result JSON-serializable
    return [(int(start), int(stop)) for start, stop in runs]
@@ -0,0 +1,477 @@
1
+ import os
2
+ from typing import List, Literal, Tuple, Union
3
+
4
+ import numpy as np
5
+ import simplejson as json
6
+
7
+ from worker.data import math
8
+ from worker.data.api import API
9
+ from worker.data.enums import Environment, EventType
10
+ from worker.data.wits import WITS
11
+ from worker.exceptions import NotFound
12
+
13
+
14
def gather_data_for_period(
    asset_id: int, start: int, end: int, limit: int = 1800, collection: str = "wits", fields: str = None
) -> list:
    """
    Fetch the records of a collection for one asset within [start, end].

    :param asset_id: asset id
    :param start: inclusive start timestamp
    :param end: inclusive end timestamp
    :param limit: maximum number of records to request
    :param collection: name of the collection to read (defaults to "wits")
    :param fields: optional fields filter passed through to the API
    :return: the records (requested sorted by ascending timestamp), or an empty
             list when the interval is empty/inverted or nothing was returned
    """
    # a zero-length or inverted interval never matches anything
    if end <= start:
        return []

    response = API().get(
        path="/v1/data/corva",
        collection=collection,
        asset_id=asset_id,
        sort="{timestamp: 1}",
        limit=limit,
        query="{timestamp#gte#%s}AND{timestamp#lte#%s}" % (start, end),
        fields=fields,
    )
    return response.data or []
47
+
48
+
49
def get_one_data_record(asset_id: int, timestamp_sort: int = -1, collection: str = "wits") -> dict:
    """
    Fetch a single record of an asset, ordered by timestamp.

    :param asset_id: asset id
    :param timestamp_sort: sort direction on timestamp passed to the API
        (presumably -1 yields the last record and 1 the first)
    :param collection: collection to read (defaults to "wits")
    :return: the record as a dict, or an empty dict when none was returned
    """
    records = API().get(
        path="/v1/data/corva/",
        collection=collection,
        asset_id=asset_id,
        sort="{timestamp:%s}" % timestamp_sort,
        limit=1,
    ).data

    return records[0] if records else {}
70
+
71
+
72
def delete_collection_data_of_asset_id(asset_id: int, collections: Union[str, list]):
    """
    Delete all the data of one or more collections for an asset id.

    Records are deleted in batches of 3600. For each collection the loop stops
    as soon as a batch comes back short, or after 2000 batches (at most
    7.2 million records per collection).

    :param asset_id: asset id whose records are removed
    :param collections: a collection name or a list of collection names
    :return: None
    """
    batch_size = 3600
    max_batches = 2000

    worker = API()

    if isinstance(collections, str):
        collections = [collections]

    query = "{asset_id#eq#%s}" % asset_id

    for collection in collections:
        path = "/v1/data/corva/%s" % collection

        for _ in range(max_batches):
            # the API answers with a payload such as {"deleted_count": 3600}
            res = worker.delete(path=path, query=query, limit=batch_size, retry_count=1).data
            # tolerate an empty/None body instead of failing with AttributeError
            records_deleted = (res or {}).get("deleted_count") or 0

            # a short batch means there is nothing left to delete
            if records_deleted < batch_size:
                break
97
+
98
+
99
def point_main_envs(env: Literal["qa", "staging", "production", "local"]):
    """
    Point the main environment variables at the given environment.

    For each NAME in API_ROOT_URL, API_KEY and CACHE_URL, the value of
    NAME_<env> is used when set; otherwise the current NAME is kept. The
    resolved values are then written back to os.environ under NAME.

    :param env: the environment to point to; a falsy value makes this a no-op
    :raises ValueError: if any of the three variables cannot be resolved
        (an invalid env is presumably rejected by Environment(env) as well)
    :return: None
    """
    if not env:
        return

    # validating the environment
    Environment(env)

    resolved = {
        name: os.getenv(f"{name}_{env}") or os.getenv(name)
        for name in ("API_ROOT_URL", "API_KEY", "CACHE_URL")
    }

    if not all(resolved.values()):
        raise ValueError("Missing environment variables!")

    os.environ.update(resolved)
127
+
128
+
129
def setup_api_worker(env: Literal["qa", "staging", "production", "local"], app_name: str) -> API:
    """
    Build a Corva API client for the given environment.

    The URL and key come from API_ROOT_URL_<env> / API_KEY_<env>, falling back
    to API_ROOT_URL / API_KEY when the environment-specific value is unset.

    :param env: one of 'local', 'qa', 'staging', 'production'
    :param app_name: application name passed to the API client
    :return: the API client
    """
    # validate the environment name before reading any variable
    Environment(env)

    return API(
        api_url=os.getenv(f"API_ROOT_URL_{env}") or os.getenv("API_ROOT_URL"),
        api_key=os.getenv(f"API_KEY_{env}") or os.getenv("API_KEY"),
        app_name=app_name,
    )
146
+
147
+
148
def setup_redis_worker(env: Literal["qa", "staging", "production", "local"]):
    """
    Build a Redis client for the given environment.

    The connection URL comes from CACHE_URL_<env>, falling back to CACHE_URL.
    The client is created with decode_responses=True.

    :param env: one of 'local', 'qa', 'staging', 'production'
    :raises ValueError: when no cache URL is configured for the environment
    :return: a redis.Redis client
    """
    # validating the environment
    Environment(env)

    cache_url = os.getenv(f"CACHE_URL_{env}") or os.getenv("CACHE_URL")
    if not cache_url:
        # fail with a clear message instead of an obscure error from redis' URL parsing
        raise ValueError(f"Missing environment variable: CACHE_URL_{env} or CACHE_URL")

    # imported here rather than at module level, so importing this module does not require redis
    import redis

    return redis.Redis.from_url(cache_url, decode_responses=True)
164
+
165
+
166
def get_config_by_id(string_id: str, collection: str) -> Union[dict, None]:
    """
    Fetch a single document (e.g. a drillstring or casing string) by its _id.

    :param string_id: the mongodb _id of the document
    :param collection: the collection that holds the document
    :return: the document data, or None when the API raises NotFound
    """
    try:
        return API().get_by_id(path="/v1/data/corva/", collection=collection, id=string_id).data
    except NotFound:
        return None
180
+
181
+
182
def is_number(data):
    """
    Check whether data can be interpreted as a (non-nan) number.

    :param data: input can be a string, a number, nan, None, ...
    :return: True if float(data) succeeds and the result is not nan (inf counts
             as a number); False otherwise, including ints too large to be
             represented as a float
    """
    try:
        data_cast = float(data)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: ints beyond float range, e.g. 10 ** 400
        return False

    # reject nan; +/-inf is still accepted as a number
    return not np.isnan(data_cast)
198
+
199
+
200
def is_finite(data):
    """
    Check whether data is a finite number (neither nan nor +/-inf).

    Note that the string representation of a number is not considered finite:
    np.isfinite rejects strings and that error is turned into False.

    :param data: value to check
    :return: True or False (a plain bool, never numpy.bool_)
    """
    try:
        # bool() so callers get True/False rather than numpy.bool_
        return bool(is_number(data) and np.isfinite(data))
    except (TypeError, ValueError):
        return False
211
+
212
+
213
def is_int(s: str) -> bool:
    """
    Check whether the given value can be converted with int().

    Mostly used with strings: "12" and " -3 " give True; "1.5" and "abc" give False.
    Note that floats are truncated by int(), so is_int(1.5) is True.

    :param s: the value to check
    :return: True if int(s) succeeds, False otherwise
    """
    try:
        int(s)
    except (TypeError, ValueError, OverflowError):
        # TypeError: e.g. None; OverflowError: e.g. float("inf")
        return False
    return True
224
+
225
+
226
def is_null(value) -> bool:
    """
    Detect missing values.

    A value is null when it cannot be converted to a number (None, nan,
    non-numeric strings, ...). It is also null when it is negative and, after
    dropping every '-', '.' and '0' from its float representation, the
    remaining digits are exactly "99925". That covers -999.25, -99925.0,
    -9.9925, etc.

    :param value: any value
    :return: True if the value is null, False otherwise
    """
    number = to_number(value)
    if number is None:
        return True

    # only negative values can match the -999.25 family
    if number >= 0:
        return False

    digits = str(number).translate(str.maketrans("", "", "-.0"))
    return digits == "99925"
248
+
249
+
250
def to_number(data):
    """
    Convert data to a float when possible.

    :param data: input can be a string, a number, nan, None, ...
    :return: float(data) when is_number(data) is True, otherwise None
    """
    return float(data) if is_number(data) else None
260
+
261
+
262
def none_to_nan(data):
    """
    Replace None with np.nan.

    - list: a new list in which every None element becomes np.nan
    - None: np.nan
    - anything else (tuples and arrays included): returned unchanged

    :param data: a list, None, or any other value
    :return: see above
    """
    if data is None:
        return np.nan

    if not isinstance(data, list):
        return data

    return [np.nan if item is None else item for item in data]
276
+
277
+
278
def get_data_by_path(data: dict, path: str, func=lambda x: x, **kwargs):
    """
    Get the value stored under a dotted key path in a nested dictionary.

    None of the intermediate values may be a list. A path step that reaches a
    non-mapping value (int, str, list, ...) is treated as "not found".

    :param data: the (nested) dictionary
    :param path: dotted path to the final key, e.g. 'data.X.Y' or 'data.bit_depth'
    :param func: converter applied to the found value (int, str, float, ...)
    :param kwargs: pass ``default=...`` to have it returned instead of raising
        when the path is not found; note that None is an acceptable default
    :raises KeyError: when the path is empty or not found and no default is given
    :return: func(value); or None when a value along the path is None (func is
        not applied in that case); or the default
    """
    from collections.abc import Mapping

    has_default = "default" in kwargs
    default = kwargs.pop("default", None)

    if not path:
        if has_default:
            return default
        raise KeyError("No key provided")

    for current_key in path.split("."):
        # a non-mapping value cannot contain the next key; treat it as missing
        # instead of failing with TypeError/AttributeError
        if not isinstance(data, Mapping) or current_key not in data:
            if has_default:
                return default
            raise KeyError(f"{current_key} not found in path")

        data = data.get(current_key)

        # a None anywhere on the path short-circuits to None
        if data is None:
            return None

    return func(data)
317
+
318
+
319
def is_in_and_not_none(d: dict, key: str):
    """
    Check that a key exists in a dict and that its value is not None.

    :param d: the dictionary
    :param key: the key to look up
    :return: True if d has the key and d[key] is not None, otherwise False
    """
    if key not in d.keys():
        return False
    return d[key] is not None
330
+
331
+
332
def nanround(value, decimal_places=2):
    """
    Like the built-in round(), but returns None for values that are not numbers.

    :param value: value to round; numbers and numeric strings are accepted
    :param decimal_places: number of decimal places to keep
    :return: the rounded value, or None when is_number(value) is False
    """
    if not is_number(value):
        return None

    try:
        return round(value, decimal_places)
    except TypeError:
        # is_number() accepts numeric strings, which round() cannot handle directly
        return round(float(value), decimal_places)
343
+
344
+
345
def merge_dicts(d1: dict, d2: dict) -> dict:
    """
    Return a new dict that holds the items of both inputs.

    When a key appears in both, the value from d2 wins. Neither input is
    modified.

    :param d1: base dictionary
    :param d2: dictionary whose values take precedence
    :return: the merged dictionary
    """
    return {**d1, **d2}
355
+
356
+
357
def equal(obj1: object, obj2: object, params: List[str]) -> bool:
    """
    Compare two objects attribute by attribute.

    :param obj1: first object
    :param obj2: second object
    :param params: names of the attributes to compare
    :return: True when both objects have exactly the same type and every listed
             attribute compares equal; False otherwise
    """
    same_type = type(obj1) is type(obj2)
    return same_type and all(getattr(obj1, name) == getattr(obj2, name) for name in params)
369
+
370
+
371
def get_cleaned_event_and_type(event) -> Tuple[Union[list, dict], EventType]:
    """
    Decode an incoming event and determine its type from its shape.

    Task and generic events format is: dict => {}
    Scheduler events format is: list of list of dict => [[{}]]
    Kafka stream events format is: list of dict => [{}]

    :param event: the raw event; a JSON str/bytes/bytearray is decoded first
    :return: (event, event_type)
    :raises ValueError: for an empty event, or for a scheduler event whose first
        item is not a dict with a "schedule_start" key
    :raises TypeError: when the event shape matches none of the formats above
    """
    if not event:
        raise ValueError("Empty events")

    if isinstance(event, (str, bytes, bytearray)):
        event = json.loads(event)

    if isinstance(event, dict):
        if event.get("event_type") == "partial-well-rerun-merge":
            return event, EventType.PARTIAL_RERUN

        if "task_id" in event:
            return event, EventType.TASK

        if "asset_id" in event:
            return event, EventType.GENERIC

        raise TypeError("Missing task_id or asset_id keys in event")

    if not isinstance(event, list):
        raise TypeError("Event is not a list or a dict")

    # a decoded payload such as "[]" is empty even though the raw string was not
    if not event:
        raise ValueError("Empty events")

    first_event = event[0]
    if isinstance(first_event, list):
        # guard against [[]] (IndexError) and non-dict items (substring matching)
        first_record = first_event[0] if first_event else None
        if isinstance(first_record, dict) and "schedule_start" in first_record:
            return event, EventType.SCHEDULER

        # ValueError subclasses Exception, so existing `except Exception` handlers still apply
        raise ValueError("Missing schedule_start key in scheduler event")

    if isinstance(first_event, dict):
        # new kafka stream format: list of json objects, each with metadata and records
        # event = [{"metadata": { ... }, "records": [ ... ]}, {"metadata": { ... }, "records": [ ... ]}]
        return event, EventType.STREAM

    raise TypeError("Event is not either a scheduler or kafka consumer")
418
+
419
+
420
def compute_time_step(records: Union[List[WITS], List[dict], List[float]], percent=50) -> Union[int, None]:
    """
    Estimate the sampling interval (time step) of a series of records.

    The time step is the given percentile of the gaps between consecutive
    finite timestamps. To split data by activity, a high percent (such as 99)
    is recommended; otherwise the median (50) usually works.

    :param records: WITS objects, dicts with a "timestamp" key, or plain
        timestamps; the type of the first element decides how all are read
    :param percent: percentile in [0, 100] used to pick the time step
    :raises ValueError: if percent is outside the [0, 100] range
    :return: the time step truncated to int, or None when there are 3 or fewer
        records or no non-zero step can be computed
    """
    if len(records) <= 3:
        return None

    if not 0 <= percent <= 100:
        raise ValueError(f"percent ({percent}) is out of [0, 100] range.")

    sample = records[0]
    if isinstance(sample, WITS):
        raw_timestamps = (rec.timestamp for rec in records)
    elif isinstance(sample, dict):
        raw_timestamps = (rec.get("timestamp") for rec in records)
    else:
        raw_timestamps = records

    finite_timestamps = [ts for ts in raw_timestamps if is_finite(ts)]

    # `math` here is worker.data.math (imported at the top of this module), not the stdlib
    step = math.percentile(np.diff(finite_timestamps), percent)
    return int(step) if step else None
453
+
454
+
455
def compare_float(value1, value2, tolerance):
    """
    Three-way comparison of two numbers with a tolerance band.

    :param value1: first number
    :param value2: second number
    :param tolerance: differences up to this amount count as equal
    :return: 1 when value1 > value2 + tolerance,
             -1 when value2 > value1 + tolerance,
             0 otherwise
    """
    if value1 > value2 + tolerance:
        result = 1
    elif value2 > value1 + tolerance:
        result = -1
    else:
        result = 0
    return result
470
+
471
+
472
def is_stream_app():
    """
    Tell whether the code runs in a non-task lambda, judging by its name.

    :return: True when AWS_LAMBDA_FUNCTION_NAME is set, non-empty, and does not
             contain "task" (case-insensitive); False otherwise
    """
    function_name = os.getenv("AWS_LAMBDA_FUNCTION_NAME") or ""
    return bool(function_name) and "task" not in function_name.lower()
@@ -0,0 +1,110 @@
1
+ # The helpers in this file are used for (de)serialization of any object
2
+ #
3
+ # The only thing you need to do is to add the '@serialization' decorator to
4
+ # the class. If the class is an enum that's enough. For other class types,
5
+ # you need to specify the variables that need to be serialized as a dict:
6
+ # variable name as key -> variable type as value
7
+ # and that dict must be named 'SERIALIZED_VARIABLES'
8
+ #
9
+ # At the end you can use:
10
+ # 'obj2json' to serialize, and
11
+ # 'json2obj' to deserialize
12
+ #
13
+
14
+
15
+ import json
16
+ from enum import Enum
17
+ from typing import Union
18
+
19
+ from worker.data.two_way_dict import TwoWayDict
20
+
21
# Registry of the classes that opted into serialization via @serialization.
# It presumably maps both ways (class -> tag and tag -> class): _json_to_obj_hook
# below looks classes up by their tag.
registry = TwoWayDict()


def serialization(cls):
    """
    Class decorator that registers cls for (de)serialization.

    The class is tagged as "_<lowercased class name>_", e.g. MyClass -> "_myclass_".
    The decorated class itself is returned unchanged.
    """
    tag = "_" + cls.__name__.lower() + "_"
    registry[cls] = tag
    return cls
29
+
30
+
31
class Encoder(json.JSONEncoder):
    """
    JSON encoder for registered classes (object --> json string).

    An object whose class was decorated with @serialization is written as a
    dict that starts with {"<class tag>": True}, and then:
      - enums add their member name under "name";
      - objects with a serialize() method add the items of its result;
      - any other registered class adds the attributes listed in its
        SERIALIZED_VARIABLES dict.
    Every other object is delegated to json.JSONEncoder.default.
    """

    def default(self, obj):
        klass = obj.__class__

        if klass not in registry:
            return json.JSONEncoder.default(self, obj)

        payload = {registry[klass]: True}

        if issubclass(klass, Enum):
            payload["name"] = obj.name
        elif hasattr(obj, "serialize"):
            payload.update(obj.serialize())
        else:
            for attr in klass.SERIALIZED_VARIABLES.keys():
                payload[attr] = getattr(obj, attr)

        return payload
61
+
62
+
63
def obj2json(obj: object, output_format=str) -> Union[str, dict]:
    """
    Serialize an object whose class was decorated with @serialization.

    :param obj: the object to serialize (plain JSON types work as well)
    :param output_format: str to get the JSON text, or dict to get that text
        decoded back into Python structures
    :raises TypeError: if output_format is neither str nor dict
    :return: a JSON string or the decoded structure, depending on output_format
    """
    encoded = json.dumps(obj, cls=Encoder)

    if output_format == str:
        return encoded
    if output_format != dict:
        raise TypeError("output_format should be str or dict.")
    return json.loads(encoded)
77
+
78
+
79
def _set_from_json(cls, data: Union[str, dict]):
    """
    Rebuild an instance of a registered class from its JSON representation.

    :param cls: the target class (an Enum or a regular class)
    :param data: JSON text or an already-decoded dict
    :return: for enums, the member named data["name"]; otherwise cls(**data)
    """
    if isinstance(data, str):
        data = json.loads(data)

    # looking an enum member up by name is the most reliable way to de-serialize it
    if issubclass(cls, Enum):
        return cls[data["name"]]

    # NOTE(review): when called from the json object hook, data still contains
    # the class tag key (e.g. "_myclass_": True), so cls.__init__ presumably has
    # to accept it (e.g. via **kwargs) — confirm before relying on this.
    return cls(**data)
88
+
89
+
90
def _json_to_obj_hook(j_str):
    """
    object_hook for json.loads: turn tagged dicts back into objects.

    :param j_str: a decoded JSON object (json passes a dict, despite the name)
    :return: the rebuilt object when the dict has a registered class tag as a
             key; otherwise the dict unchanged
    """
    # the registry also holds the classes themselves; only the str tags matter here
    known_tags = [value for value in registry.values() if isinstance(value, str)]

    matched = next((tag for tag in known_tags if tag in j_str), None)
    if matched is None:
        return j_str

    return _set_from_json(registry.get(matched), j_str)
98
+
99
+
100
def json2obj(json_str: Union[str, dict]):
    """
    Deserialize JSON produced by obj2json back into objects.

    Tagged dicts, at any nesting level, are rebuilt into instances of the
    registered classes.

    :param json_str: JSON text or a dict; None is passed through
    :raises TypeError: for any other input type
    :return: the deserialized value, or None
    """
    if json_str is None:
        return None

    if isinstance(json_str, dict):
        # round-trip through text so the object hook is applied to every nested dict
        json_str = json.dumps(json_str)
    elif not isinstance(json_str, str):
        raise TypeError(f"Wrong json format: {type(json_str)}. Acceptable formats are str and dict.")

    return json.loads(json_str, object_hook=_json_to_obj_hook)