mpcontribs-client 5.7.0__py3-none-any.whl → 5.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -74,7 +74,7 @@ DEFAULT_HOST = "contribs-api.materialsproject.org"
74
74
  BULMA = "is-narrow is-fullwidth has-background-light"
75
75
  PROVIDERS = {"github", "google", "facebook", "microsoft", "amazon"}
76
76
  COMPONENTS = ["structures", "tables", "attachments"] # using list to maintain order
77
- SUBDOMAINS = ["contribs", "lightsources", "ml", "micro"]
77
+ SUBDOMAINS = ["contribs", "ml", "micro"]
78
78
  PORTS = [5000, 5002, 5003, 5005, 10000, 10002, 10003, 10005, 20000]
79
79
  HOSTS = ["localhost", "contribs-apis"]
80
80
  HOSTS += [f"192.168.0.{i}" for i in range(36, 47)] # PrivateSubnetOne
@@ -82,7 +82,8 @@ HOSTS += [f"192.168.0.{i}" for i in range(52, 63)] # PrivateSubnetTwo
82
82
  VALID_URLS = {f"http://{h}:{p}" for p in PORTS for h in HOSTS}
83
83
  VALID_URLS |= {
84
84
  f"https://{n}-api{m}.materialsproject.org"
85
- for n in SUBDOMAINS for m in ["", "-preview"]
85
+ for n in SUBDOMAINS
86
+ for m in ["", "-preview"]
86
87
  }
87
88
  VALID_URLS |= {f"http://localhost.{n}-api.materialsproject.org" for n in SUBDOMAINS}
88
89
  SUPPORTED_FILETYPES = (Gz, Jpeg, Png, Gif, Tiff)
@@ -113,7 +114,7 @@ ureg.define("electron_mass = 9.1093837015e-31 kg = mₑ = m_e")
113
114
  LOG_LEVEL = os.environ.get("MPCONTRIBS_CLIENT_LOG_LEVEL", "INFO")
114
115
  log_level = getattr(logging, LOG_LEVEL.upper())
115
116
  _session = requests.Session()
116
- _ipython = sys.modules['IPython'].get_ipython()
117
+ _ipython = sys.modules["IPython"].get_ipython()
117
118
 
118
119
 
119
120
  class LogFilter(logging.Filter):
@@ -127,14 +128,14 @@ class LogFilter(logging.Filter):
127
128
 
128
129
  class CustomLoggerAdapter(logging.LoggerAdapter):
129
130
  def process(self, msg, kwargs):
130
- prefix = self.extra.get('prefix')
131
+ prefix = self.extra.get("prefix")
131
132
  return f"[{prefix}] {msg}" if prefix else msg, kwargs
132
133
 
133
134
 
134
135
  class TqdmToLogger(io.StringIO):
135
136
  logger = None
136
137
  level = None
137
- buf = ''
138
+ buf = ""
138
139
 
139
140
  def __init__(self, logger, level=None):
140
141
  super(TqdmToLogger, self).__init__()
@@ -142,7 +143,7 @@ class TqdmToLogger(io.StringIO):
142
143
  self.level = level or logging.INFO
143
144
 
144
145
  def write(self, buf):
145
- self.buf = buf.strip('\r\n\t ')
146
+ self.buf = buf.strip("\r\n\t ")
146
147
 
147
148
  def flush(self):
148
149
  self.logger.log(self.level, self.buf)
@@ -177,7 +178,9 @@ def get_md5(d):
177
178
 
178
179
  def validate_email(email_string):
179
180
  if email_string.count(":") != 1:
180
- raise SwaggerValidationError(f"{email_string} not of format <provider>:<email>.")
181
+ raise SwaggerValidationError(
182
+ f"{email_string} not of format <provider>:<email>."
183
+ )
181
184
 
182
185
  provider, email = email_string.split(":", 1)
183
186
  if provider not in PROVIDERS:
@@ -204,7 +207,11 @@ def validate_url(url_string, qualifying=("scheme", "netloc")):
204
207
 
205
208
 
206
209
  url_format = SwaggerFormat(
207
- format="url", to_wire=str, to_python=str, validate=validate_url, description="URL",
210
+ format="url",
211
+ to_wire=str,
212
+ to_python=str,
213
+ validate=validate_url,
214
+ description="URL",
208
215
  )
209
216
  bravado_config_dict = {
210
217
  "validate_responses": False,
@@ -235,23 +242,26 @@ def _compress(data):
235
242
 
236
243
 
237
244
  def get_session(session=None):
238
- adapter_kwargs = dict(max_retries=Retry(
239
- total=RETRIES,
240
- read=RETRIES,
241
- connect=RETRIES,
242
- respect_retry_after_header=True,
243
- status_forcelist=[429, 502], # rate limit
244
- allowed_methods={'DELETE', 'GET', 'PUT', 'POST'},
245
- backoff_factor=2
246
- ))
245
+ adapter_kwargs = dict(
246
+ max_retries=Retry(
247
+ total=RETRIES,
248
+ read=RETRIES,
249
+ connect=RETRIES,
250
+ respect_retry_after_header=True,
251
+ status_forcelist=[429, 502], # rate limit
252
+ allowed_methods={"DELETE", "GET", "PUT", "POST"},
253
+ backoff_factor=2,
254
+ )
255
+ )
247
256
  return FuturesSession(
248
257
  session=session if session else _session,
249
- max_workers=MAX_WORKERS, adapter_kwargs=adapter_kwargs
258
+ max_workers=MAX_WORKERS,
259
+ adapter_kwargs=adapter_kwargs,
250
260
  )
251
261
 
252
262
 
253
263
  def _response_hook(resp, *args, **kwargs):
254
- content_type = resp.headers['content-type']
264
+ content_type = resp.headers["content-type"]
255
265
  if content_type == "application/json":
256
266
  result = resp.json()
257
267
 
@@ -278,7 +288,7 @@ def _response_hook(resp, *args, **kwargs):
278
288
  resp.count = 0
279
289
 
280
290
 
281
- def _chunk_by_size(items, max_size=0.95*MAX_BYTES):
291
+ def _chunk_by_size(items, max_size=0.95 * MAX_BYTES):
282
292
  buffer, buffer_size = [], 0
283
293
 
284
294
  for idx, item in enumerate(items):
@@ -303,17 +313,23 @@ def visit(path, key, value):
303
313
 
304
314
 
305
315
  def _in_ipython():
306
- return _ipython is not None and 'IPKernelApp' in _ipython.config
316
+ return _ipython is not None and "IPKernelApp" in _ipython.config
307
317
 
308
318
 
309
319
  if _in_ipython():
320
+
310
321
  def _hide_traceback(
311
- exc_tuple=None, filename=None, tb_offset=None,
312
- exception_only=False, running_compiled_code=False
322
+ exc_tuple=None,
323
+ filename=None,
324
+ tb_offset=None,
325
+ exception_only=False,
326
+ running_compiled_code=False,
313
327
  ):
314
328
  etype, value, tb = sys.exc_info()
315
329
 
316
- if issubclass(etype, (MPContribsClientError, SwaggerValidationError, ValidationError)):
330
+ if issubclass(
331
+ etype, (MPContribsClientError, SwaggerValidationError, ValidationError)
332
+ ):
317
333
  return _ipython._showtraceback(
318
334
  etype, value, _ipython.InteractiveTB.get_exception_only(etype, value)
319
335
  )
@@ -327,6 +343,7 @@ if _in_ipython():
327
343
 
328
344
  class Dict(dict):
329
345
  """Custom dictionary to display itself as HTML table with Bulma CSS"""
346
+
330
347
  def display(self, attrs: str = f'class="table {BULMA}"'):
331
348
  """Nice table display of dictionary
332
349
 
@@ -342,6 +359,7 @@ class Dict(dict):
342
359
 
343
360
  class Table(pd.DataFrame):
344
361
  """Wrapper class around pandas.DataFrame to provide display() and info()"""
362
+
345
363
  def display(self):
346
364
  """Display a plotly graph for the table if in IPython/Jupyter"""
347
365
  if _in_ipython():
@@ -386,7 +404,7 @@ class Table(pd.DataFrame):
386
404
  def _clean(self):
387
405
  """clean the dataframe"""
388
406
  self.replace([np.inf, -np.inf], np.nan, inplace=True)
389
- self.fillna('', inplace=True)
407
+ self.fillna("", inplace=True)
390
408
  self.index = self.index.astype(str)
391
409
  for col in self.columns:
392
410
  self[col] = self[col].astype(str)
@@ -415,6 +433,7 @@ class Table(pd.DataFrame):
415
433
 
416
434
  class Structure(PmgStructure):
417
435
  """Wrapper class around pymatgen.Structure to provide display() and info()"""
436
+
418
437
  def display(self):
419
438
  return self # TODO use static image from crystal toolkit?
420
439
 
@@ -440,6 +459,7 @@ class Structure(PmgStructure):
440
459
 
441
460
  class Attachment(dict):
442
461
  """Wrapper class around dict to handle attachments"""
462
+
443
463
  def decode(self) -> str:
444
464
  """Decode base64-encoded content of attachment"""
445
465
  return b64decode(self["content"], validate=True)
@@ -509,7 +529,7 @@ class Attachment(dict):
509
529
  return cls(
510
530
  name=filename,
511
531
  mime="application/gzip",
512
- content=b64encode(content).decode("utf-8")
532
+ content=b64encode(content).decode("utf-8"),
513
533
  )
514
534
 
515
535
  @classmethod
@@ -545,7 +565,7 @@ class Attachment(dict):
545
565
  return cls(
546
566
  name=path.name,
547
567
  mime=kind.mime if supported else "application/gzip",
548
- content=b64encode(content).decode("utf-8")
568
+ content=b64encode(content).decode("utf-8"),
549
569
  )
550
570
 
551
571
  @classmethod
@@ -561,6 +581,7 @@ class Attachment(dict):
561
581
 
562
582
  class Attachments(list):
563
583
  """Wrapper class to handle attachments automatically"""
584
+
564
585
  # TODO implement "plural" versions for Attachment methods
565
586
 
566
587
  @classmethod
@@ -663,7 +684,9 @@ def _load(protocol, host, headers_json, project, version):
663
684
  origin_url = f"{url}/apispec.json"
664
685
  http_client = RequestsClient()
665
686
  http_client.session.headers.update(headers)
666
- swagger_spec = Spec.from_dict(spec_dict, origin_url, http_client, bravado_config_dict)
687
+ swagger_spec = Spec.from_dict(
688
+ spec_dict, origin_url, http_client, bravado_config_dict
689
+ )
667
690
  http_client.session.close()
668
691
  return swagger_spec
669
692
 
@@ -682,7 +705,9 @@ def _load(protocol, host, headers_json, project, version):
682
705
  projects = sorted(d["name"] for d in resp["data"])
683
706
  projects_json = ujson.dumps(projects)
684
707
  # expand regex-based query parameters for `data` columns
685
- spec = _expand_params(protocol, host, version, projects_json, apikey=headers.get("x-api-key"))
708
+ spec = _expand_params(
709
+ protocol, host, version, projects_json, apikey=headers.get("x-api-key")
710
+ )
686
711
  spec.http_client.session.headers.update(headers)
687
712
  return spec
688
713
 
@@ -692,8 +717,8 @@ def _raw_specs(protocol, host, version):
692
717
  http_client = RequestsClient()
693
718
  url = f"{protocol}://{host}"
694
719
  origin_url = f"{url}/apispec.json"
695
- url4fn = origin_url.replace("apispec", f"apispec-{version}").encode('utf-8')
696
- fn = urlsafe_b64encode(url4fn).decode('utf-8')
720
+ url4fn = origin_url.replace("apispec", f"apispec-{version}").encode("utf-8")
721
+ fn = urlsafe_b64encode(url4fn).decode("utf-8")
697
722
  apispec = Path(gettempdir()) / fn
698
723
  spec_dict = None
699
724
 
@@ -710,7 +735,9 @@ def _raw_specs(protocol, host, version):
710
735
  logger.debug(f"Specs for {origin_url} and {version} saved as {apispec}.")
711
736
 
712
737
  if not spec_dict:
713
- raise MPContribsClientError(f"Couldn't load specs from {url} for {version}!") # not cached
738
+ raise MPContribsClientError(
739
+ f"Couldn't load specs from {url} for {version}!"
740
+ ) # not cached
714
741
 
715
742
  spec_dict["host"] = host
716
743
  spec_dict["schemes"] = [protocol]
@@ -722,7 +749,7 @@ def _raw_specs(protocol, host, version):
722
749
  cache=LRUCache(maxsize=100),
723
750
  key=lambda protocol, host, version, projects_json, **kwargs: hashkey(
724
751
  protocol, host, version, projects_json
725
- )
752
+ ),
726
753
  )
727
754
  def _expand_params(protocol, host, version, projects_json, apikey=None):
728
755
  columns = {"string": [], "number": []}
@@ -753,7 +780,7 @@ def _expand_params(protocol, host, version, projects_json, apikey=None):
753
780
 
754
781
  for param in raw_params:
755
782
  if param["name"].startswith("^data__"):
756
- op = param["name"].rsplit('$__', 1)[-1]
783
+ op = param["name"].rsplit("$__", 1)[-1]
757
784
  typ = param["type"]
758
785
  key = "number" if typ == "number" else "string"
759
786
 
@@ -761,7 +788,8 @@ def _expand_params(protocol, host, version, projects_json, apikey=None):
761
788
  param_name = f"{column}__{op}"
762
789
  if param_name not in params:
763
790
  param_spec = {
764
- k: v for k, v in param.items()
791
+ k: v
792
+ for k, v in param.items()
765
793
  if k not in ["name", "description"]
766
794
  }
767
795
  param_spec["name"] = param_name
@@ -775,18 +803,18 @@ def _expand_params(protocol, host, version, projects_json, apikey=None):
775
803
  spec = Spec(spec_dict, origin_url, http_client, bravado_config_dict)
776
804
  model_discovery(spec)
777
805
 
778
- if spec.config['internally_dereference_refs']:
806
+ if spec.config["internally_dereference_refs"]:
779
807
  spec.deref = _identity
780
808
  spec._internal_spec_dict = spec.deref_flattened_spec
781
809
 
782
- for user_defined_format in spec.config['formats']:
810
+ for user_defined_format in spec.config["formats"]:
783
811
  spec.register_format(user_defined_format)
784
812
 
785
813
  spec.resources = build_resources(spec)
786
814
  spec.api_url = build_api_serving_url(
787
815
  spec_dict=spec.spec_dict,
788
816
  origin_url=spec.origin_url,
789
- use_spec_url_for_base_path=spec.config['use_spec_url_for_base_path'],
817
+ use_spec_url_for_base_path=spec.config["use_spec_url_for_base_path"],
790
818
  )
791
819
  http_client.session.close()
792
820
  return spec
@@ -796,13 +824,15 @@ def _expand_params(protocol, host, version, projects_json, apikey=None):
796
824
  def _version(url):
797
825
  retries, max_retries = 0, 3
798
826
  protocol = urlparse(url).scheme
799
- is_mock_test = 'pytest' in sys.modules and protocol == "http"
827
+ is_mock_test = "pytest" in sys.modules and protocol == "http"
800
828
 
801
829
  if is_mock_test:
802
830
  now = datetime.datetime.now()
803
831
  return Version(
804
- major=now.year, minor=now.month, patch=now.day,
805
- prerelease=(str(now.hour), str(now.minute))
832
+ major=now.year,
833
+ minor=now.month,
834
+ patch=now.day,
835
+ prerelease=(str(now.hour), str(now.minute)),
806
836
  )
807
837
  else:
808
838
  while retries < max_retries:
@@ -831,6 +861,7 @@ class Client(SwaggerClient):
831
861
  >>> from mpcontribs.client import Client
832
862
  >>> client = Client()
833
863
  """
864
+
834
865
  def __init__(
835
866
  self,
836
867
  apikey: str = None,
@@ -867,13 +898,17 @@ class Client(SwaggerClient):
867
898
  self.headers["Content-Type"] = "application/json"
868
899
  self.headers_json = ujson.dumps(self.headers, sort_keys=True)
869
900
  self.host = host
870
- ssl = host.endswith(".materialsproject.org") and not host.startswith("localhost.")
901
+ ssl = host.endswith(".materialsproject.org") and not host.startswith(
902
+ "localhost."
903
+ )
871
904
  self.protocol = "https" if ssl else "http"
872
905
  self.url = f"{self.protocol}://{self.host}"
873
906
  self.project = project
874
907
 
875
908
  if self.url not in VALID_URLS:
876
- raise MPContribsClientError(f"{self.url} not a valid URL (one of {VALID_URLS})")
909
+ raise MPContribsClientError(
910
+ f"{self.url} not a valid URL (one of {VALID_URLS})"
911
+ )
877
912
 
878
913
  self.version = _version(self.url) # includes healthcheck
879
914
  self.session = get_session(session=session)
@@ -887,7 +922,9 @@ class Client(SwaggerClient):
887
922
 
888
923
  @property
889
924
  def cached_swagger_spec(self):
890
- return _load(self.protocol, self.host, self.headers_json, self.project, self.version)
925
+ return _load(
926
+ self.protocol, self.host, self.headers_json, self.project, self.version
927
+ )
891
928
 
892
929
  def __dir__(self):
893
930
  members = set(self.swagger_spec.resources.keys())
@@ -902,7 +939,7 @@ class Client(SwaggerClient):
902
939
  def _is_valid_payload(self, model: str, data: dict):
903
940
  model_spec = deepcopy(self.get_model(f"{model}sSchema")._model_spec)
904
941
  model_spec.pop("required")
905
- model_spec['additionalProperties'] = False
942
+ model_spec["additionalProperties"] = False
906
943
 
907
944
  try:
908
945
  validate_object(self.swagger_spec, model_spec, data)
@@ -919,7 +956,9 @@ class Client(SwaggerClient):
919
956
 
920
957
  return True, None
921
958
 
922
- def _get_per_page_default_max(self, op: str = "query", resource: str = "contributions") -> int:
959
+ def _get_per_page_default_max(
960
+ self, op: str = "query", resource: str = "contributions"
961
+ ) -> int:
923
962
  attr = f"{op}{resource.capitalize()}"
924
963
  resource = self.swagger_spec.resources[resource]
925
964
  param_spec = getattr(resource, attr).params["per_page"].param_spec
@@ -928,7 +967,9 @@ class Client(SwaggerClient):
928
967
  def _get_per_page(
929
968
  self, per_page: int = -1, op: str = "query", resource: str = "contributions"
930
969
  ) -> int:
931
- per_page_default, per_page_max = self._get_per_page_default_max(op=op, resource=resource)
970
+ per_page_default, per_page_max = self._get_per_page_default_max(
971
+ op=op, resource=resource
972
+ )
932
973
  if per_page < 0:
933
974
  per_page = per_page_default
934
975
  return min(per_page_max, per_page)
@@ -942,7 +983,9 @@ class Client(SwaggerClient):
942
983
  ) -> List[dict]:
943
984
  """Avoid URI too long errors"""
944
985
  pp_default, pp_max = self._get_per_page_default_max(op=op, resource=resource)
945
- per_page = pp_default if any(k.endswith("__in") for k in query.keys()) else pp_max
986
+ per_page = (
987
+ pp_default if any(k.endswith("__in") for k in query.keys()) else pp_max
988
+ )
946
989
  nr_params_to_split = sum(
947
990
  len(v) > per_page for v in query.values() if isinstance(v, list)
948
991
  )
@@ -973,7 +1016,7 @@ class Client(SwaggerClient):
973
1016
 
974
1017
  if len(queries) == 1 and pages and pages > 0:
975
1018
  queries = []
976
- for page in range(1, pages+1):
1019
+ for page in range(1, pages + 1):
977
1020
  queries.append(deepcopy(query))
978
1021
  queries[-1]["page"] = page
979
1022
 
@@ -996,29 +1039,25 @@ class Client(SwaggerClient):
996
1039
  params: dict,
997
1040
  rel_url: str = "contributions",
998
1041
  op: str = "query",
999
- data: dict = None
1042
+ data: dict = None,
1000
1043
  ):
1001
1044
  rname = rel_url.split("/", 1)[0]
1002
1045
  resource = self.swagger_spec.resources[rname]
1003
1046
  attr = f"{op}{rname.capitalize()}"
1004
1047
  method = getattr(resource, attr).http_method
1005
1048
  kwargs = dict(
1006
- headers=self.headers, params=params, hooks={'response': _response_hook}
1049
+ headers=self.headers, params=params, hooks={"response": _response_hook}
1007
1050
  )
1008
1051
 
1009
1052
  if method == "put" and data:
1010
1053
  kwargs["data"] = ujson.dumps(data).encode("utf-8")
1011
1054
 
1012
- future = getattr(self.session, method)(
1013
- f"{self.url}/{rel_url}/", **kwargs
1014
- )
1055
+ future = getattr(self.session, method)(f"{self.url}/{rel_url}/", **kwargs)
1015
1056
  setattr(future, "track_id", track_id)
1016
1057
  return future
1017
1058
 
1018
1059
  def available_query_params(
1019
- self,
1020
- startswith: tuple = None,
1021
- resource: str = "contributions"
1060
+ self, startswith: tuple = None, resource: str = "contributions"
1022
1061
  ) -> list:
1023
1062
  resources = self.swagger_spec.resources
1024
1063
  resource_obj = resources.get(resource)
@@ -1032,10 +1071,7 @@ class Client(SwaggerClient):
1032
1071
  if not startswith:
1033
1072
  return params
1034
1073
 
1035
- return [
1036
- param for param in params
1037
- if param.startswith(startswith)
1038
- ]
1074
+ return [param for param in params if param.startswith(startswith)]
1039
1075
 
1040
1076
  def get_project(self, name: str = None, fields: list = None) -> Type[Dict]:
1041
1077
  """Retrieve a project entry
@@ -1046,7 +1082,9 @@ class Client(SwaggerClient):
1046
1082
  """
1047
1083
  name = self.project or name
1048
1084
  if not name:
1049
- raise MPContribsClientError("initialize client with project or set `name` argument!")
1085
+ raise MPContribsClientError(
1086
+ "initialize client with project or set `name` argument!"
1087
+ )
1050
1088
 
1051
1089
  fields = fields or ["_all"] # retrieve all fields by default
1052
1090
  return Dict(self.projects.getProjectByName(pk=name, _fields=fields).result())
@@ -1057,7 +1095,7 @@ class Client(SwaggerClient):
1057
1095
  term: str = None,
1058
1096
  fields: list = None,
1059
1097
  sort: str = None,
1060
- timeout: int = -1
1098
+ timeout: int = -1,
1061
1099
  ) -> List[dict]:
1062
1100
  """Query projects by query and/or term (Atlas Search)
1063
1101
 
@@ -1080,17 +1118,20 @@ class Client(SwaggerClient):
1080
1118
  return [self.get_project(name=query.get("name"), fields=fields)]
1081
1119
 
1082
1120
  if term:
1121
+
1083
1122
  def search_future(search_term):
1084
1123
  future = self.session.get(
1085
1124
  f"{self.url}/projects/search",
1086
1125
  headers=self.headers,
1087
- hooks={'response': _response_hook},
1126
+ hooks={"response": _response_hook},
1088
1127
  params={"term": search_term},
1089
1128
  )
1090
1129
  setattr(future, "track_id", "search")
1091
1130
  return future
1092
1131
 
1093
- responses = _run_futures([search_future(term)], timeout=timeout, disable=True)
1132
+ responses = _run_futures(
1133
+ [search_future(term)], timeout=timeout, disable=True
1134
+ )
1094
1135
  query["name__in"] = responses["search"].get("result", [])
1095
1136
 
1096
1137
  if fields:
@@ -1110,11 +1151,13 @@ class Client(SwaggerClient):
1110
1151
 
1111
1152
  queries = []
1112
1153
 
1113
- for page in range(2, total_pages+1):
1154
+ for page in range(2, total_pages + 1):
1114
1155
  queries.append(deepcopy(query))
1115
1156
  queries[-1]["page"] = page
1116
1157
 
1117
- futures = [self._get_future(i, q, rel_url="projects") for i, q in enumerate(queries)]
1158
+ futures = [
1159
+ self._get_future(i, q, rel_url="projects") for i, q in enumerate(queries)
1160
+ ]
1118
1161
  responses = _run_futures(futures, total=total_count, timeout=timeout)
1119
1162
 
1120
1163
  for resp in responses.values():
@@ -1122,7 +1165,9 @@ class Client(SwaggerClient):
1122
1165
 
1123
1166
  return ret["data"]
1124
1167
 
1125
- def create_project(self, name: str, title: str, authors: str, description: str, url: str):
1168
+ def create_project(
1169
+ self, name: str, title: str, authors: str, description: str, url: str
1170
+ ):
1126
1171
  """Create a project
1127
1172
 
1128
1173
  Args:
@@ -1138,8 +1183,11 @@ class Client(SwaggerClient):
1138
1183
  raise MPContribsClientError(f"Project with {query} already exists!")
1139
1184
 
1140
1185
  project = {
1141
- "name": name, "title": title, "authors": authors, "description": description,
1142
- "references": [{"label": "REF", "url": url}]
1186
+ "name": name,
1187
+ "title": title,
1188
+ "authors": authors,
1189
+ "description": description,
1190
+ "references": [{"label": "REF", "url": url}],
1143
1191
  }
1144
1192
  resp = self.projects.createProject(project=project).result()
1145
1193
  owner = resp.get("owner")
@@ -1163,7 +1211,9 @@ class Client(SwaggerClient):
1163
1211
 
1164
1212
  name = self.project or name
1165
1213
  if not name:
1166
- raise MPContribsClientError("initialize client with project or set `name` argument!")
1214
+ raise MPContribsClientError(
1215
+ "initialize client with project or set `name` argument!"
1216
+ )
1167
1217
 
1168
1218
  disallowed = ["is_approved", "stats", "columns", "is_public", "owner"]
1169
1219
  for k in list(update.keys()):
@@ -1171,9 +1221,13 @@ class Client(SwaggerClient):
1171
1221
  logger.warning(f"removing `{k}` from update - not allowed.")
1172
1222
  update.pop(k)
1173
1223
  if k == "columns":
1174
- logger.info("use `client.init_columns()` to update project columns.")
1224
+ logger.info(
1225
+ "use `client.init_columns()` to update project columns."
1226
+ )
1175
1227
  elif k == "is_public":
1176
- logger.info("use `client.make_public/private()` to set `is_public`.")
1228
+ logger.info(
1229
+ "use `client.make_public/private()` to set `is_public`."
1230
+ )
1177
1231
  elif not isinstance(update[k], bool) and not update[k]:
1178
1232
  logger.warning(f"removing `{k}` from update - no update requested.")
1179
1233
  update.pop(k)
@@ -1196,8 +1250,7 @@ class Client(SwaggerClient):
1196
1250
  logger.error("cannot change project name after contributions submitted.")
1197
1251
 
1198
1252
  payload = {
1199
- k: v for k, v in update.items()
1200
- if k in fields and project.get(k, None) != v
1253
+ k: v for k, v in update.items() if k in fields and project.get(k, None) != v
1201
1254
  }
1202
1255
  if not payload:
1203
1256
  logger.warning("nothing to update")
@@ -1219,7 +1272,9 @@ class Client(SwaggerClient):
1219
1272
  """
1220
1273
  name = self.project or name
1221
1274
  if not name:
1222
- raise MPContribsClientError("initialize client with project or set `name` argument!")
1275
+ raise MPContribsClientError(
1276
+ "initialize client with project or set `name` argument!"
1277
+ )
1223
1278
 
1224
1279
  if not self.get_totals(query={"name": name}, resource="projects")[0]:
1225
1280
  raise MPContribsClientError(f"Project `{name}` doesn't exist!")
@@ -1238,7 +1293,9 @@ class Client(SwaggerClient):
1238
1293
  if not fields:
1239
1294
  fields = list(self.get_model("ContributionsSchema")._properties.keys())
1240
1295
  fields.remove("needs_build") # internal field
1241
- return Dict(self.contributions.getContributionById(pk=cid, _fields=fields).result())
1296
+ return Dict(
1297
+ self.contributions.getContributionById(pk=cid, _fields=fields).result()
1298
+ )
1242
1299
 
1243
1300
  def get_table(self, tid_or_md5: str) -> Type[Table]:
1244
1301
  """Retrieve full Pandas DataFrame for a table
@@ -1248,7 +1305,9 @@ class Client(SwaggerClient):
1248
1305
  """
1249
1306
  str_len = len(tid_or_md5)
1250
1307
  if str_len not in {24, 32}:
1251
- raise MPContribsClientError(f"'{tid_or_md5}' is not a valid table id or md5 hash!")
1308
+ raise MPContribsClientError(
1309
+ f"'{tid_or_md5}' is not a valid table id or md5 hash!"
1310
+ )
1252
1311
 
1253
1312
  if str_len == 32:
1254
1313
  tables = self.tables.queryTables(md5=tid_or_md5, _fields=["id"]).result()
@@ -1288,12 +1347,18 @@ class Client(SwaggerClient):
1288
1347
  """
1289
1348
  str_len = len(sid_or_md5)
1290
1349
  if str_len not in {24, 32}:
1291
- raise MPContribsClientError(f"'{sid_or_md5}' is not a valid structure id or md5 hash!")
1350
+ raise MPContribsClientError(
1351
+ f"'{sid_or_md5}' is not a valid structure id or md5 hash!"
1352
+ )
1292
1353
 
1293
1354
  if str_len == 32:
1294
- structures = self.structures.queryStructures(md5=sid_or_md5, _fields=["id"]).result()
1355
+ structures = self.structures.queryStructures(
1356
+ md5=sid_or_md5, _fields=["id"]
1357
+ ).result()
1295
1358
  if not structures:
1296
- raise MPContribsClientError(f"structure for md5 '{sid_or_md5}' not found!")
1359
+ raise MPContribsClientError(
1360
+ f"structure for md5 '{sid_or_md5}' not found!"
1361
+ )
1297
1362
  sid = structures["data"][0]["id"]
1298
1363
  else:
1299
1364
  sid = sid_or_md5
@@ -1310,19 +1375,25 @@ class Client(SwaggerClient):
1310
1375
  """
1311
1376
  str_len = len(aid_or_md5)
1312
1377
  if str_len not in {24, 32}:
1313
- raise MPContribsClientError(f"'{aid_or_md5}' is not a valid attachment id or md5 hash!")
1378
+ raise MPContribsClientError(
1379
+ f"'{aid_or_md5}' is not a valid attachment id or md5 hash!"
1380
+ )
1314
1381
 
1315
1382
  if str_len == 32:
1316
1383
  attachments = self.attachments.queryAttachments(
1317
1384
  md5=aid_or_md5, _fields=["id"]
1318
1385
  ).result()
1319
1386
  if not attachments:
1320
- raise MPContribsClientError(f"attachment for md5 '{aid_or_md5}' not found!")
1387
+ raise MPContribsClientError(
1388
+ f"attachment for md5 '{aid_or_md5}' not found!"
1389
+ )
1321
1390
  aid = attachments["data"][0]["id"]
1322
1391
  else:
1323
1392
  aid = aid_or_md5
1324
1393
 
1325
- return Attachment(self.attachments.getAttachmentById(pk=aid, _fields=["_all"]).result())
1394
+ return Attachment(
1395
+ self.attachments.getAttachmentById(pk=aid, _fields=["_all"]).result()
1396
+ )
1326
1397
 
1327
1398
  def init_columns(self, columns: dict = None) -> dict:
1328
1399
  """initialize columns for a project to set their order and desired units
@@ -1366,7 +1437,9 @@ class Client(SwaggerClient):
1366
1437
  raise MPContribsClientError(f"Number of columns larger than {MAX_COLUMNS}!")
1367
1438
 
1368
1439
  if not all(isinstance(v, str) for v in columns.values() if v is not None):
1369
- raise MPContribsClientError("All values in `columns` need to be None or of type str!")
1440
+ raise MPContribsClientError(
1441
+ "All values in `columns` need to be None or of type str!"
1442
+ )
1370
1443
 
1371
1444
  new_columns = []
1372
1445
 
@@ -1381,13 +1454,17 @@ class Client(SwaggerClient):
1381
1454
 
1382
1455
  nesting = k.count(".")
1383
1456
  if nesting > MAX_NESTING:
1384
- raise MPContribsClientError(f"Nesting depth larger than {MAX_NESTING} for {k}!")
1457
+ raise MPContribsClientError(
1458
+ f"Nesting depth larger than {MAX_NESTING} for {k}!"
1459
+ )
1385
1460
 
1386
1461
  for col in scanned_columns:
1387
1462
  if nesting and col.startswith(k):
1388
- raise MPContribsClientError(f"Duplicate definition of {k} in {col}!")
1463
+ raise MPContribsClientError(
1464
+ f"Duplicate definition of {k} in {col}!"
1465
+ )
1389
1466
 
1390
- for n in range(1, nesting+1):
1467
+ for n in range(1, nesting + 1):
1391
1468
  if k.rsplit(".", n)[0] == col:
1392
1469
  raise MPContribsClientError(
1393
1470
  f"Ancestor of {k} already defined in {col}!"
@@ -1400,7 +1477,7 @@ class Client(SwaggerClient):
1400
1477
  )
1401
1478
 
1402
1479
  if v != "" and v is not None and v not in ureg:
1403
- raise MPContribsClientError(f"Unit '{v}' for {k} invalid!")
1480
+ raise MPContribsClientError(f"Unit '{v}' for {k} not supported!")
1404
1481
 
1405
1482
  scanned_columns.add(k)
1406
1483
 
@@ -1411,8 +1488,11 @@ class Client(SwaggerClient):
1411
1488
  sorted(sorted_columns.items(), key=lambda item: item[0].count("."))
1412
1489
  )
1413
1490
 
1491
+ # TODO catch unsupported column renaming or implement solution
1414
1492
  # reconcile with existing columns
1415
- resp = self.projects.getProjectByName(pk=self.project, _fields=["columns"]).result()
1493
+ resp = self.projects.getProjectByName(
1494
+ pk=self.project, _fields=["columns"]
1495
+ ).result()
1416
1496
  existing_columns = {}
1417
1497
 
1418
1498
  for col in resp["columns"]:
@@ -1437,15 +1517,25 @@ class Client(SwaggerClient):
1437
1517
  new_unit = new_column.get("unit", "NaN")
1438
1518
  existing_unit = existing_column.get("unit")
1439
1519
  if existing_unit != new_unit:
1520
+ conv_args = []
1521
+ for u in [existing_unit, new_unit]:
1522
+ try:
1523
+ conv_args.append(ureg.Unit(u))
1524
+ except ValueError:
1525
+ raise MPContribsClientError(
1526
+ f"Can't convert {existing_unit} to {new_unit} for {path}"
1527
+ )
1440
1528
  try:
1441
- factor = ureg.convert(1, ureg.Unit(existing_unit), ureg.Unit(new_unit))
1529
+ factor = ureg.convert(1, *conv_args)
1442
1530
  except DimensionalityError:
1443
1531
  raise MPContribsClientError(
1444
1532
  f"Can't convert {existing_unit} to {new_unit} for {path}"
1445
1533
  )
1446
1534
 
1447
1535
  if not isclose(factor, 1):
1448
- logger.info(f"Changing {existing_unit} to {new_unit} for {path} ...")
1536
+ logger.info(
1537
+ f"Changing {existing_unit} to {new_unit} for {path} ..."
1538
+ )
1449
1539
  # TODO scale contributions to new unit
1450
1540
  raise MPContribsClientError(
1451
1541
  "Changing units not supported yet. Please resubmit"
@@ -1459,7 +1549,9 @@ class Client(SwaggerClient):
1459
1549
  if not valid:
1460
1550
  raise MPContribsClientError(error)
1461
1551
 
1462
- return self.projects.updateProjectByName(pk=self.project, project=payload).result()
1552
+ return self.projects.updateProjectByName(
1553
+ pk=self.project, project=payload
1554
+ ).result()
1463
1555
 
1464
1556
  def delete_contributions(self, query: dict = None, timeout: int = -1):
1465
1557
  """Remove all contributions for a query
@@ -1482,7 +1574,9 @@ class Client(SwaggerClient):
1482
1574
  cids = list(self.get_all_ids(query).get(query["project"], {}).get("ids", set()))
1483
1575
 
1484
1576
  if not cids:
1485
- logger.info(f"There aren't any contributions to delete for {query['project']}")
1577
+ logger.info(
1578
+ f"There aren't any contributions to delete for {query['project']}"
1579
+ )
1486
1580
  return
1487
1581
 
1488
1582
  total = len(cids)
@@ -1509,7 +1603,7 @@ class Client(SwaggerClient):
1509
1603
  query: dict = None,
1510
1604
  timeout: int = -1,
1511
1605
  resource: str = "contributions",
1512
- op: str = "query"
1606
+ op: str = "query",
1513
1607
  ) -> tuple:
1514
1608
  """Retrieve total count and pages for resource entries matching query
1515
1609
 
@@ -1536,7 +1630,9 @@ class Client(SwaggerClient):
1536
1630
  query["_fields"] = [] # only need totals -> explicitly request no fields
1537
1631
  queries = self._split_query(query, resource=resource, op=op) # don't paginate
1538
1632
  result = {"total_count": 0, "total_pages": 0}
1539
- futures = [self._get_future(i, q, rel_url=resource) for i, q in enumerate(queries)]
1633
+ futures = [
1634
+ self._get_future(i, q, rel_url=resource) for i, q in enumerate(queries)
1635
+ ]
1540
1636
  responses = _run_futures(futures, timeout=timeout, desc="Totals")
1541
1637
 
1542
1638
  for resp in responses.values():
@@ -1562,7 +1658,9 @@ class Client(SwaggerClient):
1562
1658
  """
1563
1659
  return {
1564
1660
  p["name"]: p["unique_identifiers"]
1565
- for p in self.query_projects(query=query, fields=["name", "unique_identifiers"])
1661
+ for p in self.query_projects(
1662
+ query=query, fields=["name", "unique_identifiers"]
1663
+ )
1566
1664
  }
1567
1665
 
1568
1666
  def get_all_ids(
@@ -1634,10 +1732,10 @@ class Client(SwaggerClient):
1634
1732
  raise MPContribsClientError(f"`op` has to be one of {ops}")
1635
1733
 
1636
1734
  unique_identifiers = self.get_unique_identifiers_flags()
1637
- data_id_fields = {
1638
- k: v for k, v in data_id_fields.items()
1639
- if k in unique_identifiers and isinstance(v, str)
1640
- } if data_id_fields else {}
1735
+ data_id_fields = data_id_fields or {}
1736
+ for k, v in data_id_fields.items():
1737
+ if k in unique_identifiers and isinstance(v, str):
1738
+ data_id_fields[k] = v
1641
1739
 
1642
1740
  ret = {}
1643
1741
  query = query or {}
@@ -1649,8 +1747,7 @@ class Client(SwaggerClient):
1649
1747
 
1650
1748
  if data_id_fields:
1651
1749
  id_fields.update(
1652
- f"data.{data_id_field}"
1653
- for data_id_field in data_id_fields.values()
1750
+ f"data.{data_id_field}" for data_id_field in data_id_fields.values()
1654
1751
  )
1655
1752
 
1656
1753
  query["_fields"] = list(id_fields | components)
@@ -1715,7 +1812,9 @@ class Client(SwaggerClient):
1715
1812
 
1716
1813
  for component in components:
1717
1814
  if component in contrib:
1718
- ret[project][identifier][data_id_field_val][component] = {
1815
+ ret[project][identifier][data_id_field_val][
1816
+ component
1817
+ ] = {
1719
1818
  d["name"]: {"id": d["id"], "md5": d["md5"]}
1720
1819
  for d in contrib[component]
1721
1820
  }
@@ -1728,7 +1827,7 @@ class Client(SwaggerClient):
1728
1827
  fields: list = None,
1729
1828
  sort: str = None,
1730
1829
  paginate: bool = False,
1731
- timeout: int = -1
1830
+ timeout: int = -1,
1732
1831
  ) -> List[dict]:
1733
1832
  """Query contributions
1734
1833
 
@@ -1780,10 +1879,7 @@ class Client(SwaggerClient):
1780
1879
  return ret
1781
1880
 
1782
1881
  def update_contributions(
1783
- self,
1784
- data: dict,
1785
- query: dict = None,
1786
- timeout: int = -1
1882
+ self, data: dict, query: dict = None, timeout: int = -1
1787
1883
  ) -> dict:
1788
1884
  """Apply the same update to all contributions in a project (matching query)
1789
1885
 
@@ -1827,7 +1923,9 @@ class Client(SwaggerClient):
1827
1923
  return
1828
1924
 
1829
1925
  # get current list of data columns to decide if swagger reload is needed
1830
- resp = self.projects.getProjectByName(pk=self.project, _fields=["columns"]).result()
1926
+ resp = self.projects.getProjectByName(
1927
+ pk=self.project, _fields=["columns"]
1928
+ ).result()
1831
1929
  old_paths = set(c["path"] for c in resp["columns"])
1832
1930
 
1833
1931
  total = len(cids)
@@ -1842,7 +1940,9 @@ class Client(SwaggerClient):
1842
1940
  updated = sum(resp["count"] for _, resp in responses.items())
1843
1941
 
1844
1942
  if updated:
1845
- resp = self.projects.getProjectByName(pk=self.project, _fields=["columns"]).result()
1943
+ resp = self.projects.getProjectByName(
1944
+ pk=self.project, _fields=["columns"]
1945
+ ).result()
1846
1946
  new_paths = set(c["path"] for c in resp["columns"])
1847
1947
 
1848
1948
  if new_paths != old_paths:
@@ -1853,10 +1953,7 @@ class Client(SwaggerClient):
1853
1953
  return {"updated": updated, "total": total, "seconds_elapsed": toc - tic}
1854
1954
 
1855
1955
  def make_public(
1856
- self,
1857
- query: dict = None,
1858
- recursive: bool = False,
1859
- timeout: int = -1
1956
+ self, query: dict = None, recursive: bool = False, timeout: int = -1
1860
1957
  ) -> dict:
1861
1958
  """Publish a project and optionally its contributions
1862
1959
 
@@ -1869,10 +1966,7 @@ class Client(SwaggerClient):
1869
1966
  )
1870
1967
 
1871
1968
  def make_private(
1872
- self,
1873
- query: dict = None,
1874
- recursive: bool = False,
1875
- timeout: int = -1
1969
+ self, query: dict = None, recursive: bool = False, timeout: int = -1
1876
1970
  ) -> dict:
1877
1971
  """Make a project and optionally its contributions private
1878
1972
 
@@ -1889,7 +1983,7 @@ class Client(SwaggerClient):
1889
1983
  is_public: bool,
1890
1984
  query: dict = None,
1891
1985
  recursive: bool = False,
1892
- timeout: int = -1
1986
+ timeout: int = -1,
1893
1987
  ) -> dict:
1894
1988
  """Set the `is_public` flag for a project and optionally its contributions
1895
1989
 
@@ -1914,16 +2008,22 @@ class Client(SwaggerClient):
1914
2008
  pk=query["project"], _fields=["is_public", "is_approved"]
1915
2009
  ).result()
1916
2010
  except HTTPNotFound:
1917
- raise MPContribsClientError(f"project `{query['project']}` not found or access denied!")
2011
+ raise MPContribsClientError(
2012
+ f"project `{query['project']}` not found or access denied!"
2013
+ )
1918
2014
 
1919
2015
  if not recursive and resp["is_public"] == is_public:
1920
- return {"warning": f"`is_public` already set to {is_public} for `{query['project']}`."}
2016
+ return {
2017
+ "warning": f"`is_public` already set to {is_public} for `{query['project']}`."
2018
+ }
1921
2019
 
1922
2020
  ret = {}
1923
2021
 
1924
2022
  if resp["is_public"] != is_public:
1925
2023
  if is_public and not resp["is_approved"]:
1926
- raise MPContribsClientError(f"project `{query['project']}` is not approved yet!")
2024
+ raise MPContribsClientError(
2025
+ f"project `{query['project']}` is not approved yet!"
2026
+ )
1927
2027
 
1928
2028
  resp = self.projects.updateProjectByName(
1929
2029
  pk=query["project"], project={"is_public": is_public}
@@ -1944,7 +2044,7 @@ class Client(SwaggerClient):
1944
2044
  contributions: List[dict],
1945
2045
  ignore_dupes: bool = False,
1946
2046
  timeout: int = -1,
1947
- skip_dupe_check: bool = False
2047
+ skip_dupe_check: bool = False,
1948
2048
  ):
1949
2049
  """Submit a list of contributions
1950
2050
 
@@ -1974,7 +2074,9 @@ class Client(SwaggerClient):
1974
2074
  skip_dupe_check (bool): skip duplicate check for contribution identifiers
1975
2075
  """
1976
2076
  if not contributions or not isinstance(contributions, list):
1977
- raise MPContribsClientError("Please provide list of contributions to submit.")
2077
+ raise MPContribsClientError(
2078
+ "Please provide list of contributions to submit."
2079
+ )
1978
2080
 
1979
2081
  # get existing contributions
1980
2082
  tic = time.perf_counter()
@@ -1985,11 +2087,15 @@ class Client(SwaggerClient):
1985
2087
  for idx, c in enumerate(contributions):
1986
2088
  has_keys = require_one_of & c.keys()
1987
2089
  if not has_keys:
1988
- raise MPContribsClientError(f"Nothing to submit for contribution #{idx}!")
2090
+ raise MPContribsClientError(
2091
+ f"Nothing to submit for contribution #{idx}!"
2092
+ )
1989
2093
  elif not all(c[k] for k in has_keys):
1990
2094
  for k in has_keys:
1991
2095
  if not c[k]:
1992
- raise MPContribsClientError(f"Empty `{k}` for contribution #{idx}!")
2096
+ raise MPContribsClientError(
2097
+ f"Empty `{k}` for contribution #{idx}!"
2098
+ )
1993
2099
  elif "id" in c:
1994
2100
  collect_ids.append(c["id"])
1995
2101
  elif "project" in c and "identifier" in c:
@@ -2017,12 +2123,18 @@ class Client(SwaggerClient):
2017
2123
 
2018
2124
  if not skip_dupe_check and len(collect_ids) != len(contributions):
2019
2125
  nproj = len(project_names)
2020
- query = {"name__in": project_names} if nproj > 1 else {"name": project_names[0]}
2126
+ query = (
2127
+ {"name__in": project_names} if nproj > 1 else {"name": project_names[0]}
2128
+ )
2021
2129
  unique_identifiers = self.get_unique_identifiers_flags(query)
2022
- query = {"project__in": project_names} if nproj > 1 else {"project": project_names[0]}
2023
- existing = defaultdict(dict, self.get_all_ids(
2024
- query, include=COMPONENTS, timeout=timeout
2025
- ))
2130
+ query = (
2131
+ {"project__in": project_names}
2132
+ if nproj > 1
2133
+ else {"project": project_names[0]}
2134
+ )
2135
+ existing = defaultdict(
2136
+ dict, self.get_all_ids(query, include=COMPONENTS, timeout=timeout)
2137
+ )
2026
2138
 
2027
2139
  # prepare contributions
2028
2140
  contribs = defaultdict(list)
@@ -2044,22 +2156,35 @@ class Client(SwaggerClient):
2044
2156
  update = "id" in contrib
2045
2157
  project_name = id2project[contrib["id"]] if update else contrib["project"]
2046
2158
  if (
2047
- not update and unique_identifiers.get(project_name)
2048
- and contrib["identifier"] in existing.get(project_name, {}).get("identifiers", {})
2159
+ not update
2160
+ and unique_identifiers.get(project_name)
2161
+ and contrib["identifier"]
2162
+ in existing.get(project_name, {}).get("identifiers", {})
2049
2163
  ):
2050
2164
  continue
2051
2165
 
2052
- contribs[project_name].append({
2053
- k: deepcopy(contrib[k])
2054
- for k in fields if k in contrib
2055
- })
2166
+ contrib_copy = {}
2167
+ for k in fields:
2168
+ if k in contrib:
2169
+ flat = {}
2170
+ for kk, vv in flatten(contrib[k], reducer="dot").items():
2171
+ if isinstance(vv, bool):
2172
+ flat[kk] = "Yes" if vv else "No"
2173
+ elif isinstance(vv, str):
2174
+ flat[kk] = vv
2175
+
2176
+ contrib_copy[k] = deepcopy(unflatten(flat, splitter="dot"))
2177
+
2178
+ contribs[project_name].append(contrib_copy)
2056
2179
 
2057
2180
  for component in COMPONENTS:
2058
2181
  elements = contrib.get(component, [])
2059
2182
  nelems = len(elements)
2060
2183
 
2061
2184
  if nelems > MAX_ELEMS:
2062
- raise MPContribsClientError(f"Too many {component} ({nelems} > {MAX_ELEMS})!")
2185
+ raise MPContribsClientError(
2186
+ f"Too many {component} ({nelems} > {MAX_ELEMS})!"
2187
+ )
2063
2188
 
2064
2189
  if update and not nelems:
2065
2190
  continue # nothing to update for this component
@@ -2075,7 +2200,9 @@ class Client(SwaggerClient):
2075
2200
  is_table = isinstance(element, (pd.DataFrame, Table))
2076
2201
  is_attachment = isinstance(element, (str, Path, Attachment))
2077
2202
  if component == "structures" and not is_structure:
2078
- raise MPContribsClientError(f"Use pymatgen Structure for {component}!")
2203
+ raise MPContribsClientError(
2204
+ f"Use pymatgen Structure for {component}!"
2205
+ )
2079
2206
  elif component == "tables" and not is_table:
2080
2207
  raise MPContribsClientError(
2081
2208
  f"Use pandas DataFrame or mpontribs.client.Table for {component}!"
@@ -2095,7 +2222,9 @@ class Client(SwaggerClient):
2095
2222
 
2096
2223
  if "properties" in dct:
2097
2224
  if dct["properties"]:
2098
- logger.warning("storing structure properties not supported, yet!")
2225
+ logger.warning(
2226
+ "storing structure properties not supported, yet!"
2227
+ )
2099
2228
  del dct["properties"]
2100
2229
  elif is_table:
2101
2230
  table = element
@@ -2121,8 +2250,11 @@ class Client(SwaggerClient):
2121
2250
  dct["name"] = element.name
2122
2251
 
2123
2252
  dupe = bool(
2124
- digest in digests[project_name][component] or
2125
- digest in existing.get(project_name, {}).get(component, {}).get("md5s", [])
2253
+ digest in digests[project_name][component]
2254
+ or digest
2255
+ in existing.get(project_name, {})
2256
+ .get(component, {})
2257
+ .get("md5s", [])
2126
2258
  )
2127
2259
 
2128
2260
  if not ignore_dupes and dupe:
@@ -2133,9 +2265,13 @@ class Client(SwaggerClient):
2133
2265
  digests[project_name][component].add(digest)
2134
2266
  contribs[project_name][-1][component].append(dct)
2135
2267
 
2136
- valid, error = self._is_valid_payload("Contribution", contribs[project_name][-1])
2268
+ valid, error = self._is_valid_payload(
2269
+ "Contribution", contribs[project_name][-1]
2270
+ )
2137
2271
  if not valid:
2138
- raise MPContribsClientError(f"{contrib['identifier']} invalid: {error}!")
2272
+ raise MPContribsClientError(
2273
+ f"{contrib['identifier']} invalid: {error}!"
2274
+ )
2139
2275
 
2140
2276
  # submit contributions
2141
2277
  if contribs:
@@ -2146,7 +2282,7 @@ class Client(SwaggerClient):
2146
2282
  future = self.session.post(
2147
2283
  f"{self.url}/contributions/",
2148
2284
  headers=self.headers,
2149
- hooks={'response': _response_hook},
2285
+ hooks={"response": _response_hook},
2150
2286
  data=payload,
2151
2287
  )
2152
2288
  setattr(future, "track_id", track_id)
@@ -2156,7 +2292,7 @@ class Client(SwaggerClient):
2156
2292
  future = self.session.put(
2157
2293
  f"{self.url}/contributions/{pk}/",
2158
2294
  headers=self.headers,
2159
- hooks={'response': _response_hook},
2295
+ hooks={"response": _response_hook},
2160
2296
  data=payload,
2161
2297
  )
2162
2298
  setattr(future, "track_id", pk)
@@ -2174,24 +2310,33 @@ class Client(SwaggerClient):
2174
2310
  if "id" in c:
2175
2311
  pk = c.pop("id")
2176
2312
  if not c:
2177
- logger.error(f"SKIPPED: update of {project_name}/{pk} empty.")
2313
+ logger.error(
2314
+ f"SKIPPED: update of {project_name}/{pk} empty."
2315
+ )
2178
2316
 
2179
2317
  payload = ujson.dumps(c).encode("utf-8")
2180
2318
  if len(payload) < MAX_PAYLOAD:
2181
2319
  futures.append(put_future(pk, payload))
2182
2320
  else:
2183
- logger.error(f"SKIPPED: update of {project_name}/{pk} too large.")
2321
+ logger.error(
2322
+ f"SKIPPED: update of {project_name}/{pk} too large."
2323
+ )
2184
2324
  else:
2185
2325
  next_post_chunk = post_chunk + [c]
2186
2326
  next_payload = ujson.dumps(next_post_chunk).encode("utf-8")
2187
- if len(next_post_chunk) > nmax or len(next_payload) >= MAX_PAYLOAD:
2327
+ if (
2328
+ len(next_post_chunk) > nmax
2329
+ or len(next_payload) >= MAX_PAYLOAD
2330
+ ):
2188
2331
  if post_chunk:
2189
2332
  payload = ujson.dumps(post_chunk).encode("utf-8")
2190
2333
  futures.append(post_future(idx, payload))
2191
2334
  post_chunk = []
2192
2335
  idx += 1
2193
2336
  else:
2194
- logger.error(f"SKIPPED: contrib {project_name}/{n} too large.")
2337
+ logger.error(
2338
+ f"SKIPPED: contrib {project_name}/{n} too large."
2339
+ )
2195
2340
  continue
2196
2341
 
2197
2342
  post_chunk.append(c)
@@ -2204,23 +2349,38 @@ class Client(SwaggerClient):
2204
2349
  break # nothing to do
2205
2350
 
2206
2351
  responses = _run_futures(
2207
- futures, total=ncontribs-total_processed, timeout=timeout, desc="Submit"
2352
+ futures,
2353
+ total=ncontribs - total_processed,
2354
+ timeout=timeout,
2355
+ desc="Submit",
2208
2356
  )
2209
2357
  processed = sum(r.get("count", 0) for r in responses.values())
2210
2358
  total_processed += processed
2211
2359
 
2212
- if total_processed != ncontribs and retries < RETRIES and \
2213
- unique_identifiers.get(project_name):
2214
- logger.info(f"{total_processed}/{ncontribs} processed -> retrying ...")
2360
+ if (
2361
+ total_processed != ncontribs
2362
+ and retries < RETRIES
2363
+ and unique_identifiers.get(project_name)
2364
+ ):
2365
+ logger.info(
2366
+ f"{total_processed}/{ncontribs} processed -> retrying ..."
2367
+ )
2215
2368
  existing[project_name] = self.get_all_ids(
2216
- dict(project=project_name), include=COMPONENTS, timeout=timeout
2369
+ dict(project=project_name),
2370
+ include=COMPONENTS,
2371
+ timeout=timeout,
2217
2372
  ).get(project_name, {"identifiers": set()})
2218
- unique_identifiers[project_name] = self.projects.getProjectByName(
2219
- pk=project_name, _fields=["unique_identifiers"]
2220
- ).result()["unique_identifiers"]
2221
- existing_ids = existing.get(project_name, {}).get("identifiers", [])
2373
+ unique_identifiers[project_name] = (
2374
+ self.projects.getProjectByName(
2375
+ pk=project_name, _fields=["unique_identifiers"]
2376
+ ).result()["unique_identifiers"]
2377
+ )
2378
+ existing_ids = existing.get(project_name, {}).get(
2379
+ "identifiers", []
2380
+ )
2222
2381
  contribs[project_name] = [
2223
- c for c in contribs[project_name]
2382
+ c
2383
+ for c in contribs[project_name]
2224
2384
  if c["identifier"] not in existing_ids
2225
2385
  ]
2226
2386
  retries += 1
@@ -2228,7 +2388,9 @@ class Client(SwaggerClient):
2228
2388
  contribs[project_name] = [] # abort retrying
2229
2389
  if total_processed != ncontribs:
2230
2390
  if retries >= RETRIES:
2231
- logger.error(f"{project_name}: Tried {RETRIES} times - abort.")
2391
+ logger.error(
2392
+ f"{project_name}: Tried {RETRIES} times - abort."
2393
+ )
2232
2394
  elif not unique_identifiers.get(project_name):
2233
2395
  logger.info(
2234
2396
  f"{project_name}: resubmit failed contributions manually"
@@ -2238,7 +2400,9 @@ class Client(SwaggerClient):
2238
2400
  dt = (toc - tic) / 60
2239
2401
  self.init_columns()
2240
2402
  self._reinit()
2241
- logger.info(f"It took {dt:.1f}min to submit {total_processed}/{total} contributions.")
2403
+ logger.info(
2404
+ f"It took {dt:.1f}min to submit {total_processed}/{total} contributions."
2405
+ )
2242
2406
  else:
2243
2407
  logger.info("Nothing to submit.")
2244
2408
 
@@ -2248,7 +2412,7 @@ class Client(SwaggerClient):
2248
2412
  outdir: Union[str, Path] = DEFAULT_DOWNLOAD_DIR,
2249
2413
  overwrite: bool = False,
2250
2414
  include: List[str] = None,
2251
- timeout: int = -1
2415
+ timeout: int = -1,
2252
2416
  ) -> int:
2253
2417
  """Download a list of contributions as .json.gz file(s)
2254
2418
 
@@ -2296,8 +2460,12 @@ class Client(SwaggerClient):
2296
2460
  continue
2297
2461
 
2298
2462
  paths = self._download_resource(
2299
- resource=component, ids=ids, fmt=fmt,
2300
- outdir=outdir, overwrite=overwrite, timeout=timeout
2463
+ resource=component,
2464
+ ids=ids,
2465
+ fmt=fmt,
2466
+ outdir=outdir,
2467
+ overwrite=overwrite,
2468
+ timeout=timeout,
2301
2469
  )
2302
2470
  logger.debug(
2303
2471
  f"Downloaded {len(ids)} {component} for '{name}' in {len(paths)} file(s)."
@@ -2314,8 +2482,12 @@ class Client(SwaggerClient):
2314
2482
  continue
2315
2483
 
2316
2484
  paths = self._download_resource(
2317
- resource="contributions", ids=cids, fmt=fmt,
2318
- outdir=outdir, overwrite=overwrite, timeout=timeout
2485
+ resource="contributions",
2486
+ ids=cids,
2487
+ fmt=fmt,
2488
+ outdir=outdir,
2489
+ overwrite=overwrite,
2490
+ timeout=timeout,
2319
2491
  )
2320
2492
  logger.debug(
2321
2493
  f"Downloaded {len(cids)} contributions for '{name}' in {len(paths)} file(s)."
@@ -2341,7 +2513,7 @@ class Client(SwaggerClient):
2341
2513
  outdir: Union[str, Path] = DEFAULT_DOWNLOAD_DIR,
2342
2514
  overwrite: bool = False,
2343
2515
  timeout: int = -1,
2344
- fmt: str = "json"
2516
+ fmt: str = "json",
2345
2517
  ) -> Path:
2346
2518
  """Download a list of structures as a .json.gz file
2347
2519
 
@@ -2356,8 +2528,12 @@ class Client(SwaggerClient):
2356
2528
  paths of output files
2357
2529
  """
2358
2530
  return self._download_resource(
2359
- resource="structures", ids=ids, fmt=fmt,
2360
- outdir=outdir, overwrite=overwrite, timeout=timeout
2531
+ resource="structures",
2532
+ ids=ids,
2533
+ fmt=fmt,
2534
+ outdir=outdir,
2535
+ overwrite=overwrite,
2536
+ timeout=timeout,
2361
2537
  )
2362
2538
 
2363
2539
  def download_tables(
@@ -2366,7 +2542,7 @@ class Client(SwaggerClient):
2366
2542
  outdir: Union[str, Path] = DEFAULT_DOWNLOAD_DIR,
2367
2543
  overwrite: bool = False,
2368
2544
  timeout: int = -1,
2369
- fmt: str = "json"
2545
+ fmt: str = "json",
2370
2546
  ) -> Path:
2371
2547
  """Download a list of tables as a .json.gz file
2372
2548
 
@@ -2381,8 +2557,12 @@ class Client(SwaggerClient):
2381
2557
  paths of output files
2382
2558
  """
2383
2559
  return self._download_resource(
2384
- resource="tables", ids=ids, fmt=fmt,
2385
- outdir=outdir, overwrite=overwrite, timeout=timeout
2560
+ resource="tables",
2561
+ ids=ids,
2562
+ fmt=fmt,
2563
+ outdir=outdir,
2564
+ overwrite=overwrite,
2565
+ timeout=timeout,
2386
2566
  )
2387
2567
 
2388
2568
  def download_attachments(
@@ -2391,7 +2571,7 @@ class Client(SwaggerClient):
2391
2571
  outdir: Union[str, Path] = DEFAULT_DOWNLOAD_DIR,
2392
2572
  overwrite: bool = False,
2393
2573
  timeout: int = -1,
2394
- fmt: str = "json"
2574
+ fmt: str = "json",
2395
2575
  ) -> Path:
2396
2576
  """Download a list of attachments as a .json.gz file
2397
2577
 
@@ -2406,8 +2586,12 @@ class Client(SwaggerClient):
2406
2586
  paths of output files
2407
2587
  """
2408
2588
  return self._download_resource(
2409
- resource="attachments", ids=ids, fmt=fmt,
2410
- outdir=outdir, overwrite=overwrite, timeout=timeout
2589
+ resource="attachments",
2590
+ ids=ids,
2591
+ fmt=fmt,
2592
+ outdir=outdir,
2593
+ overwrite=overwrite,
2594
+ timeout=timeout,
2411
2595
  )
2412
2596
 
2413
2597
  def _download_resource(
@@ -2417,7 +2601,7 @@ class Client(SwaggerClient):
2417
2601
  outdir: Union[str, Path] = DEFAULT_DOWNLOAD_DIR,
2418
2602
  overwrite: bool = False,
2419
2603
  timeout: int = -1,
2420
- fmt: str = "json"
2604
+ fmt: str = "json",
2421
2605
  ) -> Path:
2422
2606
  """Helper to download a list of resources as .json.gz file
2423
2607
 
@@ -2450,7 +2634,9 @@ class Client(SwaggerClient):
2450
2634
  _, total_pages = self.get_totals(
2451
2635
  query=query, resource=resource, op="download", timeout=timeout
2452
2636
  )
2453
- queries = self._split_query(query, resource=resource, op="download", pages=total_pages)
2637
+ queries = self._split_query(
2638
+ query, resource=resource, op="download", pages=total_pages
2639
+ )
2454
2640
  paths, futures = [], []
2455
2641
 
2456
2642
  for query in queries:
@@ -2459,9 +2645,9 @@ class Client(SwaggerClient):
2459
2645
  paths.append(path)
2460
2646
 
2461
2647
  if not path.exists() or overwrite:
2462
- futures.append(self._get_future(
2463
- path, query, rel_url=f"{resource}/download/gz"
2464
- ))
2648
+ futures.append(
2649
+ self._get_future(path, query, rel_url=f"{resource}/download/gz")
2650
+ )
2465
2651
 
2466
2652
  if futures:
2467
2653
  responses = _run_futures(futures, timeout=timeout)