megadetector 5.0.23__py3-none-any.whl → 5.0.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of megadetector might be problematic. Click here for more details.
- megadetector/api/synchronous/api_core/animal_detection_api/api_backend.py +2 -3
- megadetector/classification/merge_classification_detection_output.py +2 -2
- megadetector/data_management/coco_to_labelme.py +2 -1
- megadetector/data_management/databases/integrity_check_json_db.py +15 -14
- megadetector/data_management/databases/subset_json_db.py +49 -21
- megadetector/data_management/lila/add_locations_to_island_camera_traps.py +73 -69
- megadetector/data_management/lila/add_locations_to_nacti.py +114 -110
- megadetector/data_management/mewc_to_md.py +340 -0
- megadetector/data_management/speciesnet_to_md.py +41 -0
- megadetector/data_management/yolo_output_to_md_output.py +15 -8
- megadetector/detection/process_video.py +24 -7
- megadetector/detection/pytorch_detector.py +841 -160
- megadetector/detection/run_detector.py +341 -146
- megadetector/detection/run_detector_batch.py +307 -70
- megadetector/detection/run_inference_with_yolov5_val.py +61 -4
- megadetector/detection/tf_detector.py +6 -1
- megadetector/postprocessing/{combine_api_outputs.py → combine_batch_outputs.py} +10 -13
- megadetector/postprocessing/compare_batch_results.py +236 -7
- megadetector/postprocessing/create_crop_folder.py +358 -0
- megadetector/postprocessing/md_to_labelme.py +7 -7
- megadetector/postprocessing/md_to_wi.py +40 -0
- megadetector/postprocessing/merge_detections.py +1 -1
- megadetector/postprocessing/postprocess_batch_results.py +12 -5
- megadetector/postprocessing/separate_detections_into_folders.py +32 -4
- megadetector/postprocessing/validate_batch_results.py +9 -4
- megadetector/utils/ct_utils.py +236 -45
- megadetector/utils/directory_listing.py +3 -3
- megadetector/utils/gpu_test.py +125 -0
- megadetector/utils/md_tests.py +455 -116
- megadetector/utils/path_utils.py +43 -2
- megadetector/utils/wi_utils.py +2691 -0
- megadetector/visualization/visualization_utils.py +95 -18
- megadetector/visualization/visualize_db.py +25 -7
- megadetector/visualization/visualize_detector_output.py +60 -13
- {megadetector-5.0.23.dist-info → megadetector-5.0.25.dist-info}/METADATA +11 -23
- {megadetector-5.0.23.dist-info → megadetector-5.0.25.dist-info}/RECORD +39 -36
- {megadetector-5.0.23.dist-info → megadetector-5.0.25.dist-info}/WHEEL +1 -1
- megadetector/detection/detector_training/__init__.py +0 -0
- megadetector/detection/detector_training/model_main_tf2.py +0 -114
- megadetector/utils/torch_test.py +0 -32
- {megadetector-5.0.23.dist-info → megadetector-5.0.25.dist-info}/LICENSE +0 -0
- {megadetector-5.0.23.dist-info → megadetector-5.0.25.dist-info}/top_level.txt +0 -0
megadetector/utils/ct_utils.py
CHANGED
|
@@ -28,8 +28,8 @@ image_extensions = ['.jpg', '.jpeg', '.gif', '.png']
|
|
|
28
28
|
|
|
29
29
|
def truncate_float_array(xs, precision=3):
|
|
30
30
|
"""
|
|
31
|
-
|
|
32
|
-
|
|
31
|
+
Truncates the fractional portion of each floating-point value in the array [xs]
|
|
32
|
+
to a specific number of floating-point digits.
|
|
33
33
|
|
|
34
34
|
Args:
|
|
35
35
|
xs (list): list of floats to truncate
|
|
@@ -42,6 +42,37 @@ def truncate_float_array(xs, precision=3):
|
|
|
42
42
|
return [truncate_float(x, precision=precision) for x in xs]
|
|
43
43
|
|
|
44
44
|
|
|
45
|
+
def round_float_array(xs, precision=3):
    """
    Rounds each floating-point value in the array [xs] to a specific number of
    digits after the decimal point, using Python's built-in round() (via
    round_float()).

    Args:
        xs (list): list of floats to round
        precision (int, optional): the number of digits to keep after the decimal
            point (passed to round() as its ndigits argument)

    Returns:
        list: list of rounded floats
    """

    return [round_float(x, precision) for x in xs]


def round_float(x, precision=3):
    """
    Convenience wrapper for the native Python round().

    Note that this rounds (it does not truncate), and [precision] counts digits after
    the decimal point, not significant digits.

    Args:
        x (float): number to round
        precision (int, optional): the number of digits to keep after the decimal
            point (passed to round() as its ndigits argument)

    Returns:
        float: rounded value
    """

    return round(x, precision)
|
|
74
|
+
|
|
75
|
+
|
|
45
76
|
def truncate_float(x, precision=3):
|
|
46
77
|
"""
|
|
47
78
|
Truncates the fractional portion of a floating-point value to a specific number of
|
|
@@ -63,26 +94,7 @@ def truncate_float(x, precision=3):
|
|
|
63
94
|
float: truncated version of [x]
|
|
64
95
|
"""
|
|
65
96
|
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
if np.isclose(x, 0):
|
|
69
|
-
|
|
70
|
-
return 0
|
|
71
|
-
|
|
72
|
-
elif (x > 1):
|
|
73
|
-
|
|
74
|
-
fractional_component = x - 1.0
|
|
75
|
-
return 1 + truncate_float(fractional_component)
|
|
76
|
-
|
|
77
|
-
else:
|
|
78
|
-
|
|
79
|
-
# Determine the factor, which shifts the decimal point of x
|
|
80
|
-
# just behind the last significant digit.
|
|
81
|
-
factor = math.pow(10, precision - 1 - math.floor(math.log10(abs(x))))
|
|
82
|
-
|
|
83
|
-
# Shift decimal point by multiplication with factor, flooring, and
|
|
84
|
-
# division by factor.
|
|
85
|
-
return math.floor(x * factor)/factor
|
|
97
|
+
return math.floor(x * (10 ** precision)) / (10 ** precision)
|
|
86
98
|
|
|
87
99
|
|
|
88
100
|
def args_to_object(args, obj):
|
|
@@ -187,7 +199,8 @@ def write_json(path, content, indent=1):
|
|
|
187
199
|
|
|
188
200
|
def convert_yolo_to_xywh(yolo_box):
|
|
189
201
|
"""
|
|
190
|
-
Converts a YOLO format bounding box
|
|
202
|
+
Converts a YOLO format bounding box [x_center, y_center, w, h] to
|
|
203
|
+
[x_min, y_min, width_of_box, height_of_box].
|
|
191
204
|
|
|
192
205
|
Args:
|
|
193
206
|
yolo_box (list): bounding box of format [x_center, y_center, width_of_box, height_of_box]
|
|
@@ -202,37 +215,21 @@ def convert_yolo_to_xywh(yolo_box):
|
|
|
202
215
|
return [x_min, y_min, width_of_box, height_of_box]
|
|
203
216
|
|
|
204
217
|
|
|
205
|
-
def
|
|
218
|
+
def convert_xywh_to_xyxy(api_box):
    """
    Convert a box from the MD output format [x_min, y_min, width, height] to an
    xyxy box [x_min, y_min, x_max, y_max] (the format produced by TF-based MD models).

    Args:
        api_box (list): bbox formatted as [x_min, y_min, width_of_box, height_of_box]

    Returns:
        list: bbox formatted as [x_min, y_min, x_max, y_max]
    """

    left, top, box_width, box_height = api_box
    right = left + box_width
    bottom = top + box_height
    return [left, top, right, bottom]
|
|
237
234
|
|
|
238
235
|
|
|
@@ -520,6 +517,52 @@ def invert_dictionary(d):
|
|
|
520
517
|
return {v: k for k, v in d.items()}
|
|
521
518
|
|
|
522
519
|
|
|
520
|
+
def round_floats_in_nested_dict(obj, decimal_places=5):
    """
    Recursively rounds all floating point values in a nested structure to the
    specified number of decimal places. Handles dictionaries, lists, tuples
    (including namedtuples), sets, and other iterables. Dicts and lists are modified
    in place; tuples, sets, and other iterables are rebuilt.

    Iterables whose type cannot be rebuilt from an iterable of items (e.g.
    generators, ranges, map objects, dict views) are returned unchanged instead of
    raising.

    Args:
        obj: the object to process (can be a dict, list, set, tuple, or primitive value)
        decimal_places (int, optional): number of decimal places to round to

    Returns:
        the processed object (the same object, for dicts and lists)
    """

    if isinstance(obj, dict):
        # Re-assigning existing keys doesn't change the dict's size, so this is safe
        # to do while iterating
        for key in obj:
            obj[key] = round_floats_in_nested_dict(obj[key], decimal_places)
        return obj

    elif isinstance(obj, list):
        for i, item in enumerate(obj):
            obj[i] = round_floats_in_nested_dict(item, decimal_places)
        return obj

    elif isinstance(obj, tuple):
        # Tuples are immutable, so we create a new one
        items = [round_floats_in_nested_dict(item, decimal_places) for item in obj]
        if hasattr(obj, '_fields'):
            # namedtuples take their fields as positional arguments, not as a single
            # iterable; rebuild them so the caller keeps the namedtuple type
            return type(obj)(*items)
        return tuple(items)

    elif isinstance(obj, set):
        # Set elements can't be modified in place, so build a new set
        return {round_floats_in_nested_dict(item, decimal_places) for item in obj}

    elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)):
        # Other iterable types: try to rebuild an object of the same type from the
        # processed items
        try:
            return type(obj)(round_floats_in_nested_dict(item, decimal_places)
                             for item in obj)
        except TypeError:
            # Some iterable types can't be constructed from an iterable of items
            # (generators, ranges, map objects, ...); leave them as-is rather than
            # failing on the whole structure
            return obj

    elif isinstance(obj, float):
        return round(obj, decimal_places)

    else:
        # For other types (int, str, bool, None, etc.), return as is
        return obj

# ...def round_floats_in_nested_dict(...)
|
|
564
|
+
|
|
565
|
+
|
|
523
566
|
def image_file_to_camera_folder(image_fn):
|
|
524
567
|
r"""
|
|
525
568
|
Removes common overflow folders (e.g. RECNX101, RECNX102) from paths, i.e. turn:
|
|
@@ -706,7 +749,133 @@ def is_function_name(s,calling_namespace):
|
|
|
706
749
|
callable(locals().get(s)) or \
|
|
707
750
|
callable(calling_namespace.get(s)) or \
|
|
708
751
|
callable(getattr(builtins, s, None))
|
|
752
|
+
|
|
753
|
+
|
|
754
|
+
# From https://gist.github.com/fralau/061a4f6c13251367ef1d9a9a99fb3e8d
|
|
755
|
+
def parse_kvp(s,kv_separator='='):
    """
    Parse a key/value pair, separated by [kv_separator]. Only the first occurrence of
    [kv_separator] splits the string, so values may themselves contain the separator
    (e.g. "a=b=c" parses to ('a','b=c')). Leading/trailing whitespace is stripped
    from the key, but not from the value.

    Args:
        s (str): the string to parse
        kv_separator (str, optional): the string separating keys from values.

    Returns:
        tuple: a 2-tuple formatted as (key,value)

    Raises:
        ValueError: if [s] does not contain [kv_separator]
    """

    key, separator_found, value = s.partition(kv_separator)

    # Explicit check rather than assert, so invalid input still fails cleanly when
    # Python runs with -O (asserts stripped)
    if not separator_found:
        raise ValueError('Illegal key-value pair')

    return (key.strip(), value)


def parse_kvp_list(items,kv_separator='=',d=None):
    """
    Parse a list of key-value pairs into a dictionary. If items is None or [],
    returns [d] unchanged ({} if [d] is None).

    Args:
        items (list): the list of KVPs to parse
        kv_separator (str, optional): the string separating keys from values.
        d (dict, optional): the initial dictionary, defaults to {}; if supplied, it
            is updated in place

    Returns:
        dict: a dict mapping keys to values
    """

    if d is None:
        d = {}

    if items is None or len(items) == 0:
        return d

    for item in items:
        # Pass the separator through so non-default separators are honored
        key, value = parse_kvp(item, kv_separator=kv_separator)
        d[key] = value

    return d
|
|
801
|
+
|
|
802
|
+
|
|
803
|
+
def dict_to_kvp_list(d,
                     item_separator=' ',
                     kv_separator='=',
                     non_string_value_handling='error'):
    """
    Convert a str --> str dict into a string containing a list of key-value pairs,
    i.e., converts {'a':'dog','b':'cat'} to 'a=dog b=cat'. If d is None, returns
    None. If d is empty (or every value is omitted), returns ''.

    Args:
        d (dict): the dictionary to convert; keys must be strings
        item_separator (str, optional): the delimiter between KV pairs
        kv_separator (str, optional): the separator between a key and its value
        non_string_value_handling (str, optional): what to do with non-string values,
            can be "omit" (skip that pair), "error" (raise ValueError), or "convert"
            (use str(value))

    Returns:
        str: the string representation of [d]

    Raises:
        AssertionError: if a key is not a string
        ValueError: if a non-string value is found and non_string_value_handling is
            "error", or if non_string_value_handling is not a recognized option
    """

    if d is None:
        return None

    # Collect the pairs and join once, rather than growing a string with +=
    pairs = []
    for k, v in d.items():
        assert isinstance(k,str), 'Input {} is not a str <--> str dict'.format(str(d))
        if not isinstance(v,str):
            if non_string_value_handling == 'error':
                raise ValueError('Input {} is not a str <--> str dict'.format(str(d)))
            elif non_string_value_handling == 'omit':
                continue
            elif non_string_value_handling == 'convert':
                v = str(v)
            else:
                raise ValueError('Unrecognized non_string_value_handling value: {}'.format(
                    non_string_value_handling))
        pairs.append(k + kv_separator + v)

    # An empty dict (or one whose values were all omitted) yields ''
    return item_separator.join(pairs)
|
|
853
|
+
|
|
854
|
+
|
|
855
|
+
def parse_bool_string(s):
    """
    Parse "true"/"false" into a bool, ignoring case and leading/trailing whitespace.
    A value that is already a bool is passed through unchanged.

    Args:
        s (str or bool): the string to parse, or the bool to return

    Returns:
        bool: the parsed value

    Raises:
        ValueError: if [s] is a string other than "true" or "false"
    """

    if isinstance(s,bool):
        return s

    normalized = s.lower().strip()
    if normalized in ('true', 'false'):
        return normalized == 'true'

    raise ValueError('Cannot parse bool from string {}'.format(str(normalized)))
|
|
876
|
+
|
|
877
|
+
|
|
878
|
+
#%% Test driver
|
|
710
879
|
|
|
711
880
|
def __module_test__():
|
|
712
881
|
"""
|
|
@@ -736,3 +905,25 @@ def __module_test__():
|
|
|
736
905
|
L = [{'a':5},{'a':0},{'a':10}]
|
|
737
906
|
k = 'a'
|
|
738
907
|
sort_list_of_dicts_by_key(L, k, reverse=True)
|
|
908
|
+
|
|
909
|
+
|
|
910
|
+
##%% Test float rounding
|
|
911
|
+
|
|
912
|
+
# Example with mixed collection types
|
|
913
|
+
data = {
|
|
914
|
+
"name": "Project X",
|
|
915
|
+
"values": [1.23456789, 2.3456789],
|
|
916
|
+
"tuple_values": (3.45678901, 4.56789012),
|
|
917
|
+
"set_values": {5.67890123, 6.78901234},
|
|
918
|
+
"metrics": {
|
|
919
|
+
"score": 98.7654321,
|
|
920
|
+
"components": [5.6789012, 6.7890123]
|
|
921
|
+
}
|
|
922
|
+
}
|
|
923
|
+
|
|
924
|
+
result = round_floats_in_nested_dict(data)
|
|
925
|
+
assert result['values'][0] == 1.23457
|
|
926
|
+
assert result['tuple_values'][0] == 3.45679
|
|
927
|
+
assert min(list(result['set_values'])) == 5.6789
|
|
928
|
+
|
|
929
|
+
|
|
@@ -17,9 +17,6 @@ import sys
|
|
|
17
17
|
import argparse
|
|
18
18
|
import re
|
|
19
19
|
|
|
20
|
-
import azure.common
|
|
21
|
-
from azure.storage.blob import BlobServiceClient, ContentSettings
|
|
22
|
-
|
|
23
20
|
from megadetector.utils.path_utils import is_image_file
|
|
24
21
|
|
|
25
22
|
|
|
@@ -139,6 +136,8 @@ def traverse_and_create_index(dir, sas_url=None, overwrite_files=False,
|
|
|
139
136
|
# If we want to set the content type in blob storage using a SAS URL
|
|
140
137
|
if sas_url:
|
|
141
138
|
|
|
139
|
+
from azure.storage.blob import BlobServiceClient, ContentSettings
|
|
140
|
+
|
|
142
141
|
# Example: sas_url = 'https://accname.blob.core.windows.net/bname/path/to/folder?st=...&se=...&sp=...&...'
|
|
143
142
|
if '?' in sas_url:
|
|
144
143
|
# 'https://accname.blob.core.windows.net/bname/path/to/folder' and 'st=...&se=...&sp=...&...'
|
|
@@ -196,6 +195,7 @@ def traverse_and_create_index(dir, sas_url=None, overwrite_files=False,
|
|
|
196
195
|
|
|
197
196
|
# Set content type in blob storage
|
|
198
197
|
if sas_url:
|
|
198
|
+
import azure.common
|
|
199
199
|
if container_folder:
|
|
200
200
|
output_blob_path = container_folder + '/' + output_file[len(dir) + 1:]
|
|
201
201
|
else:
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
"""
|
|
2
|
+
|
|
3
|
+
gpu_test.py
|
|
4
|
+
|
|
5
|
+
Simple script to verify CUDA availability, used to verify a CUDA environment
|
|
6
|
+
for TF or PyTorch
|
|
7
|
+
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
# Reduce TensorFlow console output. Done at import time, presumably so the settings
# are in place before tf_test() imports TensorFlow.
import os
os.environ.update(TF_CPP_MIN_LOG_LEVEL='1')

try:
    import logging
    logging.getLogger(name='tensorflow').setLevel(level=logging.ERROR)
except Exception:
    pass
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
#%% Torch/TF test functions
|
|
22
|
+
|
|
23
|
+
def torch_test():
    """
    Print diagnostic information about Torch/CUDA status, including Torch/CUDA
    versions, all available CUDA device names, and whether PyTorch reports that
    Metal Performance Shaders (MPS) are available.

    Returns:
        int: the number of CUDA devices reported by PyTorch, or None if PyTorch
        could not be imported.
    """

    try:
        import torch
    except Exception as e: #noqa
        print('PyTorch unavailable, not running PyTorch tests. PyTorch import error was:\n{}'.format(
            str(e)))
        return

    print('Torch version: {}'.format(str(torch.__version__)))
    print('CUDA available (according to PyTorch): {}'.format(torch.cuda.is_available()))
    if torch.cuda.is_available():
        print('CUDA version (according to PyTorch): {}'.format(torch.version.cuda))
        print('CuDNN version (according to PyTorch): {}'.format(torch.backends.cudnn.version()))

    device_ids = list(range(torch.cuda.device_count()))

    if len(device_ids) > 0:
        cuda_str = 'Found {} CUDA devices:'.format(len(device_ids))
        print(cuda_str)

        for device_id in device_ids:
            device_name = 'unknown'
            try:
                device_name = torch.cuda.get_device_name(device=device_id)
            except Exception:
                # Best-effort: keep 'unknown' if the name can't be queried
                pass
            print('{}: {}'.format(device_id,device_name))
    else:
        print('No GPUs reported by PyTorch')

    # is_built() must be *called*; referencing the function object without calling
    # it is always truthy, which silently skipped this half of the check. The
    # try/except keeps this best-effort (e.g. PyTorch builds without
    # torch.backends.mps).
    try:
        if torch.backends.mps.is_built() and torch.backends.mps.is_available():
            print('PyTorch reports that Metal Performance Shaders are available')
    except Exception:
        pass

    return len(device_ids)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def tf_test():
    """
    Print diagnostic information about TF/CUDA status: TF version, the CUDA/CuDNN
    versions TF was built with (if any), the linked TensorRT version (if it can be
    probed), and the GPUs TF can see.

    Returns:
        int: the number of GPUs reported by TensorFlow, or None if TensorFlow could
        not be imported.
    """

    try:
        import tensorflow as tf
    except Exception as e: #noqa
        print('TensorFlow unavailable, not running TF tests. TF import error was:\n{}'.format(
            str(e)))
        return

    from tensorflow.python.platform import build_info as build
    print(f"TF version: {tf.__version__}")

    if 'cuda_version' not in build.build_info:
        print('TF does not appear to be built with CUDA')
    else:
        print(f"CUDA build version reported by TensorFlow: {build.build_info['cuda_version']}")
    if 'cudnn_version' not in build.build_info:
        print('TF does not appear to be built with CuDNN')
    else:
        print(f"CuDNN build version reported by TensorFlow: {build.build_info['cudnn_version']}")

    # TensorRT probing goes through private TF internals, so treat it as best-effort
    try:
        from tensorflow.python.compiler.tensorrt import trt_convert as trt
        print("Linked TensorRT version: {}".format(trt.trt_utils._pywrap_py_utils.get_linked_tensorrt_version()))
    except Exception:
        print('Could not probe TensorRT version')

    gpus = tf.config.list_physical_devices('GPU')
    if gpus is None:
        gpus = []

    if len(gpus) > 0:
        print('TensorFlow found the following GPUs:')
        for gpu in gpus:
            print(gpu.name)

    else:
        print('No GPUs reported by TensorFlow')

    return len(gpus)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
#%% Command-line driver
|
|
118
|
+
|
|
119
|
+
if __name__ == '__main__':

    # Run each diagnostic in turn, preceded by its banner
    for banner, test_fn in (('*** Running Torch tests ***\n', torch_test),
                            ('\n*** Running TF tests ***\n', tf_test)):
        print(banner)
        test_fn()
|