megadetector 10.0.1__py3-none-any.whl → 10.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of megadetector might be problematic. Click here for more details.

@@ -186,7 +186,8 @@ def merge_detections(source_files,target_file,output_file,options=None):
186
186
 
187
187
  image_filename = source_im['file']
188
188
 
189
- assert image_filename in fn_to_image, 'Image {} not in target image set'.format(image_filename)
189
+ assert image_filename in fn_to_image, \
190
+ 'Image {} not in target image set'.format(image_filename)
190
191
  target_im = fn_to_image[image_filename]
191
192
 
192
193
  if 'detections' not in source_im or source_im['detections'] is None:
@@ -294,10 +295,15 @@ def merge_detections(source_files,target_file,output_file,options=None):
294
295
 
295
296
  print('Saved merged results to {}'.format(output_file))
296
297
 
298
+ # ...def merge_detections(...)
299
+
297
300
 
298
301
  #%% Command-line driver
299
302
 
300
- def main(): # noqa
303
+ def main():
304
+ """
305
+ Command-line driver for merge_detections.py
306
+ """
301
307
 
302
308
  default_options = MergeDetectionsOptions()
303
309
 
@@ -305,7 +311,7 @@ def main(): # noqa
305
311
  description='Merge detections from one or more MegaDetector results files into an existing reuslts file')
306
312
  parser.add_argument(
307
313
  'source_files',
308
- nargs="+",
314
+ nargs='+',
309
315
  help='Path to source .json file(s) to merge from')
310
316
  parser.add_argument(
311
317
  'target_file',
@@ -325,11 +331,11 @@ def main(): # noqa
325
331
  default=default_options.min_detection_size,
326
332
  type=float,
327
333
  help='Ignore detections with an area smaller than this (as a fraction of ' + \
328
- 'image size) (default {})'.format(
329
- default_options.min_detection_size))
334
+ 'image size) (default {})'.format(
335
+ default_options.min_detection_size))
330
336
  parser.add_argument(
331
337
  '--source_confidence_thresholds',
332
- nargs="+",
338
+ nargs='+',
333
339
  type=float,
334
340
  default=default_options.source_confidence_thresholds,
335
341
  help='List of thresholds for each source file (default {}). '.format(
@@ -340,19 +346,18 @@ def main(): # noqa
340
346
  '--target_confidence_threshold',
341
347
  type=float,
342
348
  default=default_options.target_confidence_threshold,
343
- help='Don\'t merge if target file\'s detection confidence is already higher ' + \
344
- 'than this (default {}). '.format(
345
- default_options.target_confidence_threshold))
349
+ help="Do not merge if target file detection confidence is already higher " + \
350
+ "than this (default {})".format(default_options.target_confidence_threshold))
346
351
  parser.add_argument(
347
352
  '--categories_to_include',
348
353
  type=int,
349
- nargs="+",
354
+ nargs='+',
350
355
  default=None,
351
356
  help='List of numeric detection category IDs to include')
352
357
  parser.add_argument(
353
358
  '--categories_to_exclude',
354
359
  type=int,
355
- nargs="+",
360
+ nargs='+',
356
361
  default=None,
357
362
  help='List of numeric detection categories to include')
358
363
  parser.add_argument(
@@ -387,4 +392,3 @@ def main(): # noqa
387
392
 
388
393
  if __name__ == '__main__':
389
394
  main()
390
-
@@ -172,7 +172,9 @@ def validate_batch_results(json_filename,options=None):
172
172
  file = im['file']
173
173
 
174
174
  if 'detections' in im and im['detections'] is not None:
175
+
175
176
  for det in im['detections']:
177
+
176
178
  assert 'category' in det, 'Image {} has a detection with no category'.format(file)
177
179
  assert 'conf' in det, 'Image {} has a detection with no confidence'.format(file)
178
180
  assert isinstance(det['conf'],float), \
@@ -182,6 +184,21 @@ def validate_batch_results(json_filename,options=None):
182
184
  'Image {} has a detection with an unmapped category {}'.format(
183
185
  file,det['category'])
184
186
 
187
+ if 'classifications' in det and det['classifications'] is not None:
188
+ for c in det['classifications']:
189
+ assert isinstance(c[0],str), \
190
+ 'Image {} has an illegal classification category: {}'.format(file,c[0])
191
+ try:
192
+ _ = int(c[0])
193
+ except Exception:
194
+ raise ValueError('Image {} has an illegal classification category: {}'.format(
195
+ file,c[0]))
196
+ assert isinstance(c[1],float) or isinstance(c[1], int)
197
+
198
+ # ...for each detection
199
+
200
+ # ...if this image has a detections field
201
+
185
202
  if options.check_image_existence:
186
203
 
187
204
  if options.relative_path_base is None:
@@ -207,13 +224,19 @@ def validate_batch_results(json_filename,options=None):
207
224
  if not isinstance(im['detections'],list):
208
225
  raise ValueError('Invalid detections list for image {}'.format(im['file']))
209
226
 
227
+ if is_video_file(im['file']) and (format_version >= 1.5):
228
+
229
+ if 'frames_processed' not in im:
230
+ raise ValueError('Video without frames_processed field: {}'.format(im['file']))
231
+
210
232
  if is_video_file(im['file']) and (format_version >= 1.4):
211
233
 
212
234
  if 'frame_rate' not in im:
213
235
  raise ValueError('Video without frame rate: {}'.format(im['file']))
214
236
  if im['frame_rate'] < 0:
215
- raise ValueError('Video with illegal frame rate {}: {}'.format(
216
- str(im['frame_rate']),im['file']))
237
+ if 'failure' not in im:
238
+ raise ValueError('Video with illegal frame rate {}: {}'.format(
239
+ str(im['frame_rate']),im['file']))
217
240
  if 'detections' in im and im['detections'] is not None:
218
241
  for det in im['detections']:
219
242
  if 'frame_number' not in det:
File without changes
@@ -0,0 +1,335 @@
1
+ """
2
+
3
+ Test script for validating NMS functionality with synthetic data.
4
+
5
+ This script creates synthetic detection scenarios where we know exactly which
6
+ boxes should be suppressed by NMS, allowing us to verify the correctness of
7
+ the NMS implementation.
8
+
9
+ This is an AI-generated test module.
10
+
11
+ """
12
+
13
+
14
+ #%% Imports
15
+
16
+ import torch
17
+
18
+ from megadetector.detection.pytorch_detector import nms
19
+
20
+
21
+ #%% Support functions
22
+
23
def calculate_iou_boxes(box1, box2):
    """
    Compute the intersection-over-union of two axis-aligned boxes.

    Boxes use [x1, y1, x2, y2] corner coordinates. Plain Python lists are turned
    into float tensors first; any other input (e.g. a tensor row) is indexed as-is.

    Args:
        box1: torch.Tensor or list of [x1, y1, x2, y2]
        box2: torch.Tensor or list of [x1, y1, x2, y2]

    Returns:
        float: IoU value between 0 and 1; 0.0 when the boxes do not overlap (or only
        touch) or when the union area is not positive
    """

    a = torch.tensor(box1, dtype=torch.float) if isinstance(box1, list) else box1
    b = torch.tensor(box2, dtype=torch.float) if isinstance(box2, list) else box2

    # Corners of the overlap rectangle
    left = max(a[0], b[0])
    top = max(a[1], b[1])
    right = min(a[2], b[2])
    bottom = min(a[3], b[3])

    # Empty (or zero-width) overlap
    if right <= left or bottom <= top:
        return 0.0

    overlap = (right - left) * (bottom - top)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - overlap

    if union > 0:
        return float(overlap / union)
    return 0.0
57
+
58
+
59
def create_synthetic_predictions():
    """
    Create synthetic model predictions for testing NMS.

    Returns:
        torch.Tensor: synthetic predictions of shape [1, 20, 8] (batch_size=1,
        20 anchor rows, 3 classes + 5). Each row is [x_center, y_center, width,
        height, objectness, class0_conf, class1_conf, class2_conf]; rows after the
        synthetic boxes are all zeros.

    Test scenarios (IoU values below are computed from the boxes defined here; the
    expectations assume an IoU threshold of 0.5):
    1. Two highly overlapping boxes (IoU ~0.78) with different confidences - the
       higher-confidence box should win
    1b. Two nearly identical boxes (IoU ~0.88) - the higher-confidence box should win
    2. Two non-overlapping boxes (IoU 0) - both should be kept
    3. Two identical boxes of different classes - both should be kept, since boxes
       of different classes are expected not to suppress each other (per-class NMS)
    4. Three diagonally shifted boxes with descending confidences - the
       highest-confidence box should win and suppress its neighbor; the third box
       overlaps the winner with IoU ~0.39 only, so (assuming standard greedy NMS)
       it is expected to survive
    Plus one box with very low objectness that should be removed by the confidence
    threshold before NMS runs.
    """

    num_classes = 3
    num_anchors = 20  # number of rows the predictions are padded out to

    # Each row: [x_center, y_center, width, height, objectness,
    #            class0_conf, class1_conf, class2_conf], for a 640x640 image
    synthetic_boxes = [

        # Scenario 1: two 80x80 boxes offset by (5, 5); IoU = 5625/7175 ~0.78.
        # Expected: Box A kept, Box B suppressed.
        [100, 100, 80, 80, 0.9, 0.8, 0.1, 0.1],    # Box A - should be kept
        [105, 105, 80, 80, 0.9, 0.5, 0.1, 0.1],    # Box B - should be suppressed

        # Scenario 1b: two 60x60 boxes offset by (2, 2); IoU = 3364/3836 ~0.88.
        # Expected: Box A2 kept, Box B2 suppressed.
        [200, 100, 60, 60, 0.9, 0.9, 0.05, 0.05],  # Box A2 - should be kept
        [202, 102, 60, 60, 0.9, 0.7, 0.1, 0.1],    # Box B2 - should be suppressed

        # Scenario 2: two 60x60 boxes whose centers are 80 pixels apart, so they
        # do not overlap at all (IoU 0). Expected: both kept.
        [300, 100, 60, 60, 0.9, 0.7, 0.1, 0.1],    # Box C - should be kept
        [380, 100, 60, 60, 0.9, 0.6, 0.1, 0.1],    # Box D - should be kept

        # Scenario 3: identical 70x70 boxes, one scored as class 0 and one as
        # class 1. Expected: both kept (per-class NMS).
        [100, 300, 70, 70, 0.9, 0.7, 0.1, 0.1],    # Box E - class 0, should be kept
        [100, 300, 70, 70, 0.9, 0.1, 0.7, 0.1],    # Box F - class 1, should be kept

        # Scenario 4: three 80x80 boxes, each shifted by (10, 10) from the previous
        # one. IoU(G,H) = IoU(H,I) ~0.62, IoU(G,I) ~0.39. Expected: G kept and H
        # suppressed by G; I only overlapped H above the threshold, so once H is
        # removed, greedy NMS keeps I.
        [500, 300, 80, 80, 0.95, 0.9, 0.05, 0.05], # Box G - highest conf, should be kept
        [510, 310, 80, 80, 0.9, 0.7, 0.1, 0.1],    # Box H - should be suppressed
        [520, 320, 80, 80, 0.85, 0.6, 0.15, 0.15], # Box I - overlaps G below threshold

        # Very low objectness/confidence; should be filtered out before NMS
        [200, 500, 50, 50, 0.1, 0.05, 0.02, 0.03],
    ]

    # Fail loudly instead of silently dropping boxes if more are added than fit
    if len(synthetic_boxes) > num_anchors:
        raise ValueError('{} synthetic boxes do not fit in {} anchor rows'.format(
            len(synthetic_boxes), num_anchors))

    # batch_size=1; unused anchor rows stay all-zero (objectness 0)
    predictions = torch.zeros(1, num_anchors, num_classes + 5)
    predictions[0, :len(synthetic_boxes), :] = torch.tensor(synthetic_boxes)

    return predictions
132
+
133
+
134
+ #%% Main test function
135
+
136
def _detection_center(det):
    """
    Return the (x, y) center of a detection row in [x1, y1, x2, y2, conf, cls] format.
    """

    x1, y1, x2, y2 = (float(v) for v in det[:4])
    return (x1 + x2) / 2, (y1 + y2) / 2


def _detections_in_region(detections, x_range, y_range, class_id=None):
    """
    Find detections whose box center falls inside an axis-aligned region.

    Args:
        detections: sequence of detection rows in [x1, y1, x2, y2, conf, cls] format
        x_range (tuple): inclusive (min, max) bounds on the center x coordinate
        y_range (tuple): inclusive (min, max) bounds on the center y coordinate
        class_id (int, optional): if not None, only match detections of this class

    Returns:
        list: (index, center_x, center_y, conf, class_id) tuples, one per match, in
        the order the detections appear
    """

    matches = []
    for i, det in enumerate(detections):
        center_x, center_y = _detection_center(det)
        cls = int(det[5])
        in_x = x_range[0] <= center_x <= x_range[1]
        in_y = y_range[0] <= center_y <= y_range[1]
        if in_x and in_y and (class_id is None or cls == class_id):
            matches.append((i, center_x, center_y, float(det[4]), cls))
    return matches


def _fail(message):
    """
    Print [message] and raise an AssertionError carrying the same message, so the
    failure reason shows up both in console output and in the test report.
    """

    print(message)
    raise AssertionError(message)


def test_nms_functionality():
    """
    Test the NMS function with synthetic data to verify correct behavior.

    Runs nms() with conf_thres=0.3 and iou_thres=0.5 on the boxes produced by
    create_synthetic_predictions(), checks each scenario independently, then runs
    a focused test on two identical boxes.

    Raises:
        AssertionError: if any scenario produces an unexpected result
    """

    iou_threshold = 0.5

    print("Testing NMS functionality with synthetic data...")

    # Generate synthetic predictions
    predictions = create_synthetic_predictions()
    print(f"Created synthetic predictions with shape: {predictions.shape}")

    results = nms(predictions, conf_thres=0.3, iou_thres=iou_threshold, max_det=300)

    print(f"NMS returned {len(results)} batch results")
    detections = results[0]  # results for the first (and only) image
    print(f"Number of detections after NMS: {detections.shape[0]}")

    assert detections.shape[0] != 0

    print("\nDetections after NMS:")
    print("Format: [x1, y1, x2, y2, confidence, class_id]")
    for i, det in enumerate(detections):
        x1, y1, x2, y2, conf, cls = det
        center_x = (x1 + x2) / 2
        center_y = (y1 + y2) / 2
        width = x2 - x1
        height = y2 - y1
        print(f"Detection {i}: center=({center_x:.1f}, {center_y:.1f}), "
              f"size={width:.1f}x{height:.1f}, conf={conf:.3f}, class={int(cls)}")

    ## Scenarios 1 and 1b: each heavily-overlapping same-class pair must collapse to
    # its higher-confidence box. Each scenario is checked on its own, because a
    # combined count of 2 would also accept a 2 + 0 split (no suppression in one
    # scenario, nothing detected in the other).
    scenario1_boxes = _detections_in_region(detections, (80, 130), (80, 130), class_id=0)
    scenario1b_boxes = _detections_in_region(detections, (180, 220), (80, 120), class_id=0)

    if len(scenario1_boxes) != 1 or len(scenario1b_boxes) != 1:
        _fail("Error: expected exactly 1 detection in each high-overlap scenario, "
              "got {} (scenario 1) and {} (scenario 1b)".format(
                  len(scenario1_boxes), len(scenario1b_boxes)))

    # Should be the high-confidence box (0.8 * 0.9 = 0.72)
    if scenario1_boxes[0][3] < 0.7:
        _fail("Error: wrong box kept in scenario 1. Expected conf > 0.7, got {}".format(
            scenario1_boxes[0][3]))

    # Should be the high-confidence box (0.9 * 0.9 = 0.81)
    if scenario1b_boxes[0][3] < 0.8:
        _fail("Error: wrong box kept in scenario 1b. Expected conf > 0.8, got {}".format(
            scenario1b_boxes[0][3]))

    print("Scenarios 1 & 1b passed: High-confidence boxes kept, low-confidence overlapping boxes suppressed")

    # Sanity-check the test setup: the suppressed pairs must really exceed the IoU
    # threshold, otherwise the checks above would not prove that suppression works.
    # Scenario 1: Box A (100,100,80x80) vs Box B (105,105,80x80), as corners
    box_a = [100-40, 100-40, 100+40, 100+40]
    box_b = [105-40, 105-40, 105+40, 105+40]
    iou_1 = calculate_iou_boxes(box_a, box_b)

    # Scenario 1b: Box A2 (200,100,60x60) vs Box B2 (202,102,60x60), as corners
    box_a2 = [200-30, 100-30, 200+30, 100+30]
    box_b2 = [202-30, 102-30, 202+30, 102+30]
    iou_1b = calculate_iou_boxes(box_a2, box_b2)

    print(f"  Theoretical IoU for scenario 1 boxes: {iou_1:.3f}")
    print(f"  Theoretical IoU for scenario 1b boxes: {iou_1b:.3f}")

    if iou_1 <= iou_threshold:
        _fail(f"Error: scenario 1 IoU {iou_1:.3f} is too low - test setup is invalid!")
    if iou_1b <= iou_threshold:
        _fail(f"Error: scenario 1b IoU {iou_1b:.3f} is too low - test setup is invalid!")
    print("  High IoU confirmed - suppression was correct")

    ## Scenario 2: both non-overlapping boxes should be kept
    scenario2_boxes = _detections_in_region(detections, (270, 410), (70, 130), class_id=0)
    if len(scenario2_boxes) != 2:
        _fail(f"Error: expected 2 detections in scenario 2 area, got {len(scenario2_boxes)}")
    print("Scenario 2 passed: Both non-overlapping boxes kept")

    ## Scenario 3: same location, different classes; both should be kept
    scenario3_boxes = _detections_in_region(detections, (65, 135), (265, 335))
    classes_found = {box[4] for box in scenario3_boxes}
    if len(scenario3_boxes) != 2 or len(classes_found) != 2:
        _fail("Error: expected 2 detections of different classes, got {} detections of classes {}".format(
            len(scenario3_boxes), classes_found))
    print("Scenario 3 passed: Both different-class boxes kept")

    ## Scenario 4: cascading overlapping boxes
    scenario4_boxes = _detections_in_region(detections, (460, 560), (260, 360), class_id=0)

    print(f"\nScenario 4 analysis: Found {len(scenario4_boxes)} boxes in cascading area:")
    for i, (_, cx, cy, conf, _) in enumerate(scenario4_boxes):
        print(f"  Box {i}: center=({cx:.1f}, {cy:.1f}), conf={conf:.3f}")

    # At least the highest-confidence box must survive; an empty result is a failure
    # (and would otherwise leave nothing to compare below)
    if not scenario4_boxes:
        _fail("ERROR: Scenario 4 failed - no detections found in the cascading area")

    # Every pair of surviving boxes must overlap by less than the IoU threshold
    max_iou = 0.0
    for i, box_i in enumerate(scenario4_boxes):
        for j in range(i + 1, len(scenario4_boxes)):
            box_j = scenario4_boxes[j]
            iou = calculate_iou_boxes(detections[box_i[0]][:4], detections[box_j[0]][:4])
            print(f"  IoU between box {i} and box {j}: {iou:.3f}")
            max_iou = max(max_iou, iou)

    if len(scenario4_boxes) == 1:
        print("Scenario 4 passed: Only highest confidence box kept")
    elif max_iou < iou_threshold:
        print("Scenario 4 passed: Multiple boxes kept due to low IoU (< 0.5)")
    else:
        _fail(f"ERROR: Scenario 4 failed - boxes with IoU {max_iou:.3f} >= 0.5 were not suppressed!")

    ## Focused test: two identical same-class boxes must collapse to one
    print("\n=== COMPREHENSIVE IoU VALIDATION TEST ===")

    identical_box_a = [100, 100, 50, 50, 0.9, 0.9, 0.05, 0.05]  # High confidence
    identical_box_b = [100, 100, 50, 50, 0.9, 0.7, 0.1, 0.1]    # Lower confidence

    test_predictions = torch.zeros(1, 5, 8)  # small batch for a focused test
    test_predictions[0, 0, :] = torch.tensor(identical_box_a)
    test_predictions[0, 1, :] = torch.tensor(identical_box_b)

    test_results = nms(test_predictions, conf_thres=0.3, iou_thres=iou_threshold, max_det=300)
    test_detections = test_results[0]

    print(f"Identical boxes test: Input 2 identical boxes, got {test_detections.shape[0]} detections")

    if test_detections.shape[0] != 1:
        _fail(f"Error: two identical boxes should result in 1 detection, got {test_detections.shape[0]}")

    # Verify that the higher-confidence box was the one kept
    kept_conf = test_detections[0, 4].item()
    expected_conf = 0.9 * 0.9  # objectness * class_conf
    if abs(kept_conf - expected_conf) > 0.01:
        _fail(f"ERROR: Wrong box kept. Expected conf ≈ {expected_conf:.3f}, got {kept_conf:.3f}")
    print("Identical boxes test passed: Higher confidence box kept")

    print("\nNMS tests passed")
@@ -983,13 +983,15 @@ def dict_to_kvp_list(d,
983
983
  return s
984
984
 
985
985
 
986
def parse_bool_string(s, strict=False):
    """
    Convert a string representation of a boolean to a bool. Case-insensitive, discards
    leading and trailing whitespace. If s is already a bool, returns s.

    Args:
        s (str or bool): the string to parse, or the bool to return; values that are
            neither str nor bool are converted with str() before parsing
        strict (bool, optional): only allow "true" or "false"; otherwise also accepts
            "t"/"f", "yes"/"no", "y"/"n", and "1"/"0"

    Returns:
        bool: the parsed value

    Raises:
        ValueError: if s does not match any accepted true/false string
    """

    if isinstance(s,bool):
        return s
    s = str(s).lower().strip()

    # These must be tuples (note the trailing commas in the strict case): a
    # parenthesized bare string would turn the membership tests below into
    # substring tests, so e.g. "" or "ru" would parse as True.
    if strict:
        false_strings = ('false',)
        true_strings = ('true',)
    else:
        false_strings = ('no', 'false', 'f', 'n', '0')
        true_strings = ('yes', 'true', 't', 'y', '1')

    if s in true_strings:
        return True
    elif s in false_strings:
        return False
    else:
        raise ValueError('Cannot parse bool from string {}'.format(str(s)))
@@ -1649,6 +1658,8 @@ def test_string_parsing():
1649
1658
  assert not parse_bool_string("false")
1650
1659
  assert not parse_bool_string("False")
1651
1660
  assert not parse_bool_string(" FALSE ")
1661
+ assert parse_bool_string("1", strict=False)
1662
+ assert not parse_bool_string("0", strict=False)
1652
1663
  assert parse_bool_string(True) is True # Test with existing bool
1653
1664
  assert parse_bool_string(False) is False
1654
1665
  try:
@@ -1657,7 +1668,7 @@ def test_string_parsing():
1657
1668
  except ValueError:
1658
1669
  pass
1659
1670
  try:
1660
- parse_bool_string("1") # Should not parse to True
1671
+ parse_bool_string("1",strict=True)
1661
1672
  raise AssertionError("ValueError not raised for '1'")
1662
1673
  except ValueError:
1663
1674
  pass