bplusplus 0.1.1__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bplusplus might be problematic.

bplusplus/__init__.py CHANGED
@@ -1,3 +1,7 @@
-from .build_model import build_model
-from .collect_images import Group, collect_images
-from .train_validate import train_validate
+from .collect import Group, collect
+from .train_validate import train, validate
+from .prepare import prepare
+from .resnet.train import train_resnet
+from .resnet.test import test_resnet
+from .hierarchical.train import train_multitask
+from .hierarchical.test import test_multitask
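
For orientation, here is a minimal sketch of how an import site might migrate across this change; the lines below are illustrative assumptions built only from the exports shown in the hunk above, not code taken from the package:

    # bplusplus 0.1.1 (old exports, removed above)
    # from bplusplus import build_model, collect_images, train_validate

    # bplusplus 1.2.0 (new exports, added above)
    from bplusplus import Group, collect, prepare, train, validate
    from bplusplus import train_resnet, test_resnet, train_multitask, test_multitask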
@@ -1,22 +1,41 @@
 import os
 import random
+import threading
 from enum import Enum
-from typing import Any, Optional
-
+from typing import Any, Optional, List, Dict
+from tqdm import tqdm
+import random
 import pygbif
 import requests
 import validators
 
-
 #this lists currently supported groupings, more can be added with proper testing
 class Group(str, Enum):
     scientificName="scientificName"
 
 #TODO add back support for fetching from dataset (or csvs)
-def collect_images(group_by_key: Group, search_parameters: dict[str, Any], images_per_group: int, output_directory: str):
+def collect(group_by_key: Group, search_parameters: dict[str, Any], images_per_group: int, output_directory: str, num_threads: int):
 
     groups: list[str] = search_parameters[group_by_key.value]
 
+    # check if user wants to parallelize the process
+    if num_threads > 1:
+        __threaded_collect(
+            images_per_group=images_per_group,
+            output_directory=output_directory,
+            num_threads=num_threads,
+            groups=groups)
+    else:
+        __single_collect(
+            search_parameters=search_parameters,
+            images_per_group=images_per_group,
+            output_directory=output_directory,
+            group_by_key=group_by_key,
+            groups=groups,
+        )
+
+def __single_collect(group_by_key: Group, search_parameters: dict[str, Any], images_per_group: int, output_directory: str, groups: list[str]):
+
     #TODO throw error if groups is not a str list
 
     __create_folders(
@@ -26,18 +45,18 @@ def collect_images(group_by_key: Group, search_parameters: dict[str, Any], image
 
     print("Beginning to collect images from GBIF...")
     for group in groups:
-        print(f"Collecting images for {group}...")
+        # print(f"Collecting images for {group}...")
         occurrences_json = _fetch_occurrences(group_key=group_by_key, group_value=group, parameters=search_parameters, totalLimit=10000)
         optional_occurrences = map(lambda x: __parse_occurrence(x), occurrences_json)
         occurrences = list(filter(None, optional_occurrences))
 
-        print(f"{group} : {len(occurrences)} parseable occurrences fetched, will sample for {images_per_group}")
+        # print(f"{group} : {len(occurrences)} parseable occurrences fetched, will sample for {images_per_group}")
 
         random.seed(42) # for reproducibility
         sampled_occurrences = random.sample(occurrences, min(images_per_group, len(occurrences)))
 
         print(f"Downloading {len(sampled_occurrences)} images into the {group} folder...")
-        for occurrence in sampled_occurrences:
+        for occurrence in tqdm(sampled_occurrences, desc=f"Downloading images for {group}", unit="image"):
             # image_url = occurrence.image_url.replace("original", "large") # hack to get max 1024px image
 
             __down_image(
@@ -49,6 +68,34 @@ def collect_images(group_by_key: Group, search_parameters: dict[str, Any], image
 
     print("Finished collecting images.")
 
+# threaded_collect: paralellize the collection of images
+def __threaded_collect(images_per_group: int, output_directory: str, num_threads: int, groups: list[str]):
+    # Divide the species list into num_threads parts
+    chunk_size = len(groups) // num_threads
+    species_chunks = [
+        groups[i:i + chunk_size] for i in range(0, len(groups), chunk_size)
+    ]
+
+    # Ensure we have exactly num_threads chunks (the last chunk might be larger if len(species_list) % num_threads != 0)
+    while len(species_chunks) < num_threads:
+        species_chunks.append([])
+
+    threads = []
+    for i, chunk in enumerate(species_chunks):
+        thread = threading.Thread(
+            target=__collect_subset,
+            args=(chunk, images_per_group, output_directory, i)
+        )
+        threads.append(thread)
+        thread.start()
+
+    # Wait for all threads to complete
+    for thread in threads:
+        thread.join()
+
+    print("All collection threads have finished.")
+
+
 def _fetch_occurrences(group_key: str, group_value: str, parameters: dict[str, Any], totalLimit: int) -> list[dict[str, Any]]:
     parameters[group_key] = group_value
     return __next_batch(
@@ -98,6 +145,23 @@ def __create_folders(names: list[str], directory: str):
         # Create a folder using the group name
         os.makedirs(folder_name, exist_ok=True)
 
+def __collect_subset(species_subset: List[str], images_per_group: int, output_directory: str, thread_id: int):
+    search_subset: Dict[str, Any] = {
+        "scientificName": species_subset
+    }
+
+    print(f"Thread {thread_id} starting collection for {len(species_subset)} species.")
+
+    __single_collect(
+        search_parameters=search_subset,
+        images_per_group=images_per_group,
+        output_directory=output_directory,
+        group_by_key=Group.scientificName,
+        groups=species_subset
+    )
+
+    print(f"Thread {thread_id} finished collection.")
+
 
 
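
As a quick check of the chunk-splitting arithmetic that __threaded_collect introduces above, the standalone snippet below reproduces just those lines with made-up inputs (ten hypothetical species names and num_threads=3); it is a sketch for illustration, not code shipped in the package:

    # Standalone sketch of the chunking logic from __threaded_collect above.
    # The species names and thread count are illustrative assumptions.
    groups = [f"species_{i}" for i in range(10)]
    num_threads = 3

    chunk_size = len(groups) // num_threads  # 10 // 3 = 3
    species_chunks = [
        groups[i:i + chunk_size] for i in range(0, len(groups), chunk_size)
    ]
    # Pad with empty chunks only if fewer chunks than requested threads were produced
    while len(species_chunks) < num_threads:
        species_chunks.append([])

    print([len(chunk) for chunk in species_chunks])  # [3, 3, 3, 1]

When len(groups) is not an exact multiple of num_threads, this slicing yields one extra chunk (and therefore one extra worker thread) rather than a larger final chunk; and if there are fewer groups than threads, chunk_size becomes 0 and range() raises a ValueError.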