datajoint 0.14.1__py3-none-any.whl → 0.14.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This release has been flagged as potentially problematic.


This version of datajoint might be problematic; see the registry's advisory page for more details.

datajoint/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
1
  """
2
- DataJoint for Python is a framework for building data piplines using MySQL databases
2
+ DataJoint for Python is a framework for building data pipelines using MySQL databases
3
3
  to represent pipeline structure and bulk storage systems for large objects.
4
4
  DataJoint is built on the foundation of the relational data model and prescribes a
5
5
  consistent method for organizing, populating, and querying data.
@@ -37,6 +37,7 @@ __all__ = [
37
37
  "Part",
38
38
  "Not",
39
39
  "AndList",
40
+ "Top",
40
41
  "U",
41
42
  "Diagram",
42
43
  "Di",
@@ -51,6 +52,7 @@ __all__ = [
51
52
  "key",
52
53
  "key_hash",
53
54
  "logger",
55
+ "cli",
54
56
  ]
55
57
 
56
58
  from .logging import logger
@@ -61,7 +63,7 @@ from .schemas import Schema
61
63
  from .schemas import VirtualModule, list_schemas
62
64
  from .table import Table, FreeTable
63
65
  from .user_tables import Manual, Lookup, Imported, Computed, Part
64
- from .expression import Not, AndList, U
66
+ from .expression import Not, AndList, U, Top
65
67
  from .diagram import Diagram
66
68
  from .admin import set_password, kill
67
69
  from .blob import MatCell, MatStruct
@@ -70,6 +72,7 @@ from .hash import key_hash
70
72
  from .attribute_adapter import AttributeAdapter
71
73
  from . import errors
72
74
  from .errors import DataJointError
75
+ from .cli import cli
73
76
 
74
77
  ERD = Di = Diagram # Aliases for Diagram
75
78
  schema = Schema # Aliases for Schema
datajoint/admin.py CHANGED
@@ -1,5 +1,6 @@
1
1
  import pymysql
2
2
  from getpass import getpass
3
+ from packaging import version
3
4
  from .connection import conn
4
5
  from .settings import config
5
6
  from .utils import user_choice
@@ -14,9 +15,16 @@ def set_password(new_password=None, connection=None, update_config=None):
14
15
  new_password = getpass("New password: ")
15
16
  confirm_password = getpass("Confirm password: ")
16
17
  if new_password != confirm_password:
17
- logger.warn("Failed to confirm the password! Aborting password change.")
18
+ logger.warning("Failed to confirm the password! Aborting password change.")
18
19
  return
19
- connection.query("SET PASSWORD = PASSWORD('%s')" % new_password)
20
+
21
+ if version.parse(
22
+ connection.query("select @@version;").fetchone()[0]
23
+ ) >= version.parse("5.7"):
24
+ # SET PASSWORD is deprecated as of MySQL 5.7 and removed in 8+
25
+ connection.query("ALTER USER user() IDENTIFIED BY '%s';" % new_password)
26
+ else:
27
+ connection.query("SET PASSWORD = PASSWORD('%s')" % new_password)
20
28
  logger.info("Password updated.")
21
29
 
22
30
  if update_config or (
datajoint/autopopulate.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """This module defines class dj.AutoPopulate"""
2
+
2
3
  import logging
3
4
  import datetime
4
5
  import traceback
@@ -22,7 +23,7 @@ logger = logging.getLogger(__name__.split(".")[0])
22
23
 
23
24
  def _initialize_populate(table, jobs, populate_kwargs):
24
25
  """
25
- Initialize the process for mulitprocessing.
26
+ Initialize the process for multiprocessing.
26
27
  Saves the unpickled copy of the table to the current process and reconnects.
27
28
  """
28
29
  process = mp.current_process()
@@ -118,7 +119,7 @@ class AutoPopulate:
118
119
 
119
120
  def _jobs_to_do(self, restrictions):
120
121
  """
121
- :return: the query yeilding the keys to be computed (derived from self.key_source)
122
+ :return: the query yielding the keys to be computed (derived from self.key_source)
122
123
  """
123
124
  if self.restriction:
124
125
  raise DataJointError(
@@ -152,6 +153,7 @@ class AutoPopulate:
152
153
  def populate(
153
154
  self,
154
155
  *restrictions,
156
+ keys=None,
155
157
  suppress_errors=False,
156
158
  return_exception_objects=False,
157
159
  reserve_jobs=False,
@@ -168,6 +170,8 @@ class AutoPopulate:
168
170
 
169
171
  :param restrictions: a list of restrictions each restrict
170
172
  (table.key_source - target.proj())
173
+ :param keys: The list of keys (dicts) to send to self.make().
174
+ If None (default), then use self.key_source to query they keys.
171
175
  :param suppress_errors: if True, do not terminate execution.
172
176
  :param return_exception_objects: return error objects instead of just error messages
173
177
  :param reserve_jobs: if True, reserve jobs to populate in asynchronous fashion
@@ -180,6 +184,9 @@ class AutoPopulate:
180
184
  to be passed down to each ``make()`` call. Computation arguments should be
181
185
  specified within the pipeline e.g. using a `dj.Lookup` table.
182
186
  :type make_kwargs: dict, optional
187
+ :return: a dict with two keys
188
+ "success_count": the count of successful ``make()`` calls in this ``populate()`` call
189
+ "error_list": the error list that is filled if `suppress_errors` is True
183
190
  """
184
191
  if self.connection.in_transaction:
185
192
  raise DataJointError("Populate cannot be called during a transaction.")
@@ -202,14 +209,17 @@ class AutoPopulate:
202
209
 
203
210
  old_handler = signal.signal(signal.SIGTERM, handler)
204
211
 
205
- keys = (self._jobs_to_do(restrictions) - self.target).fetch("KEY", limit=limit)
212
+ if keys is None:
213
+ keys = (self._jobs_to_do(restrictions) - self.target).fetch(
214
+ "KEY", limit=limit
215
+ )
206
216
 
207
- # exclude "error" or "ignore" jobs
217
+ # exclude "error", "ignore" or "reserved" jobs
208
218
  if reserve_jobs:
209
219
  exclude_key_hashes = (
210
220
  jobs
211
221
  & {"table_name": self.target.table_name}
212
- & 'status in ("error", "ignore")'
222
+ & 'status in ("error", "ignore", "reserved")'
213
223
  ).fetch("key_hash")
214
224
  keys = [key for key in keys if key_hash(key) not in exclude_key_hashes]
215
225
 
@@ -222,49 +232,62 @@ class AutoPopulate:
222
232
 
223
233
  keys = keys[:max_calls]
224
234
  nkeys = len(keys)
225
- if not nkeys:
226
- return
227
-
228
- processes = min(_ for _ in (processes, nkeys, mp.cpu_count()) if _)
229
235
 
230
236
  error_list = []
231
- populate_kwargs = dict(
232
- suppress_errors=suppress_errors,
233
- return_exception_objects=return_exception_objects,
234
- make_kwargs=make_kwargs,
235
- )
237
+ success_list = []
236
238
 
237
- if processes == 1:
238
- for key in (
239
- tqdm(keys, desc=self.__class__.__name__) if display_progress else keys
240
- ):
241
- error = self._populate1(key, jobs, **populate_kwargs)
242
- if error is not None:
243
- error_list.append(error)
244
- else:
245
- # spawn multiple processes
246
- self.connection.close() # disconnect parent process from MySQL server
247
- del self.connection._conn.ctx # SSLContext is not pickleable
248
- with mp.Pool(
249
- processes, _initialize_populate, (self, jobs, populate_kwargs)
250
- ) as pool, (
251
- tqdm(desc="Processes: ", total=nkeys)
252
- if display_progress
253
- else contextlib.nullcontext()
254
- ) as progress_bar:
255
- for error in pool.imap(_call_populate1, keys, chunksize=1):
256
- if error is not None:
257
- error_list.append(error)
258
- if display_progress:
259
- progress_bar.update()
260
- self.connection.connect() # reconnect parent process to MySQL server
239
+ if nkeys:
240
+ processes = min(_ for _ in (processes, nkeys, mp.cpu_count()) if _)
241
+
242
+ populate_kwargs = dict(
243
+ suppress_errors=suppress_errors,
244
+ return_exception_objects=return_exception_objects,
245
+ make_kwargs=make_kwargs,
246
+ )
247
+
248
+ if processes == 1:
249
+ for key in (
250
+ tqdm(keys, desc=self.__class__.__name__)
251
+ if display_progress
252
+ else keys
253
+ ):
254
+ status = self._populate1(key, jobs, **populate_kwargs)
255
+ if status is True:
256
+ success_list.append(1)
257
+ elif isinstance(status, tuple):
258
+ error_list.append(status)
259
+ else:
260
+ assert status is False
261
+ else:
262
+ # spawn multiple processes
263
+ self.connection.close() # disconnect parent process from MySQL server
264
+ del self.connection._conn.ctx # SSLContext is not pickleable
265
+ with mp.Pool(
266
+ processes, _initialize_populate, (self, jobs, populate_kwargs)
267
+ ) as pool, (
268
+ tqdm(desc="Processes: ", total=nkeys)
269
+ if display_progress
270
+ else contextlib.nullcontext()
271
+ ) as progress_bar:
272
+ for status in pool.imap(_call_populate1, keys, chunksize=1):
273
+ if status is True:
274
+ success_list.append(1)
275
+ elif isinstance(status, tuple):
276
+ error_list.append(status)
277
+ else:
278
+ assert status is False
279
+ if display_progress:
280
+ progress_bar.update()
281
+ self.connection.connect() # reconnect parent process to MySQL server
261
282
 
262
283
  # restore original signal handler:
263
284
  if reserve_jobs:
264
285
  signal.signal(signal.SIGTERM, old_handler)
265
286
 
266
- if suppress_errors:
267
- return error_list
287
+ return {
288
+ "success_count": sum(success_list),
289
+ "error_list": error_list,
290
+ }
268
291
 
269
292
  def _populate1(
270
293
  self, key, jobs, suppress_errors, return_exception_objects, make_kwargs=None
@@ -275,55 +298,61 @@ class AutoPopulate:
275
298
  :param key: dict specifying job to populate
276
299
  :param suppress_errors: bool if errors should be suppressed and returned
277
300
  :param return_exception_objects: if True, errors must be returned as objects
278
- :return: (key, error) when suppress_errors=True, otherwise None
301
+ :return: (key, error) when suppress_errors=True,
302
+ True if successfully invoke one `make()` call, otherwise False
279
303
  """
304
+ # use the legacy `_make_tuples` callback.
280
305
  make = self._make_tuples if hasattr(self, "_make_tuples") else self.make
281
306
 
282
- if jobs is None or jobs.reserve(self.target.table_name, self._job_key(key)):
283
- self.connection.start_transaction()
284
- if key in self.target: # already populated
307
+ if jobs is not None and not jobs.reserve(
308
+ self.target.table_name, self._job_key(key)
309
+ ):
310
+ return False
311
+
312
+ self.connection.start_transaction()
313
+ if key in self.target: # already populated
314
+ self.connection.cancel_transaction()
315
+ if jobs is not None:
316
+ jobs.complete(self.target.table_name, self._job_key(key))
317
+ return False
318
+
319
+ logger.debug(f"Making {key} -> {self.target.full_table_name}")
320
+ self.__class__._allow_insert = True
321
+ try:
322
+ make(dict(key), **(make_kwargs or {}))
323
+ except (KeyboardInterrupt, SystemExit, Exception) as error:
324
+ try:
285
325
  self.connection.cancel_transaction()
286
- if jobs is not None:
287
- jobs.complete(self.target.table_name, self._job_key(key))
326
+ except LostConnectionError:
327
+ pass
328
+ error_message = "{exception}{msg}".format(
329
+ exception=error.__class__.__name__,
330
+ msg=": " + str(error) if str(error) else "",
331
+ )
332
+ logger.debug(
333
+ f"Error making {key} -> {self.target.full_table_name} - {error_message}"
334
+ )
335
+ if jobs is not None:
336
+ # show error name and error message (if any)
337
+ jobs.error(
338
+ self.target.table_name,
339
+ self._job_key(key),
340
+ error_message=error_message,
341
+ error_stack=traceback.format_exc(),
342
+ )
343
+ if not suppress_errors or isinstance(error, SystemExit):
344
+ raise
288
345
  else:
289
- logger.debug(f"Making {key} -> {self.target.full_table_name}")
290
- self.__class__._allow_insert = True
291
- try:
292
- make(dict(key), **(make_kwargs or {}))
293
- except (KeyboardInterrupt, SystemExit, Exception) as error:
294
- try:
295
- self.connection.cancel_transaction()
296
- except LostConnectionError:
297
- pass
298
- error_message = "{exception}{msg}".format(
299
- exception=error.__class__.__name__,
300
- msg=": " + str(error) if str(error) else "",
301
- )
302
- logger.debug(
303
- f"Error making {key} -> {self.target.full_table_name} - {error_message}"
304
- )
305
- if jobs is not None:
306
- # show error name and error message (if any)
307
- jobs.error(
308
- self.target.table_name,
309
- self._job_key(key),
310
- error_message=error_message,
311
- error_stack=traceback.format_exc(),
312
- )
313
- if not suppress_errors or isinstance(error, SystemExit):
314
- raise
315
- else:
316
- logger.error(error)
317
- return key, error if return_exception_objects else error_message
318
- else:
319
- self.connection.commit_transaction()
320
- logger.debug(
321
- f"Success making {key} -> {self.target.full_table_name}"
322
- )
323
- if jobs is not None:
324
- jobs.complete(self.target.table_name, self._job_key(key))
325
- finally:
326
- self.__class__._allow_insert = False
346
+ logger.error(error)
347
+ return key, error if return_exception_objects else error_message
348
+ else:
349
+ self.connection.commit_transaction()
350
+ logger.debug(f"Success making {key} -> {self.target.full_table_name}")
351
+ if jobs is not None:
352
+ jobs.complete(self.target.table_name, self._job_key(key))
353
+ return True
354
+ finally:
355
+ self.__class__._allow_insert = False
327
356
 
328
357
  def progress(self, *restrictions, display=False):
329
358
  """
datajoint/blob.py CHANGED
@@ -322,9 +322,11 @@ class Blob:
322
322
  + "\0".join(array.dtype.names).encode() # number of fields
323
323
  + b"\0"
324
324
  + b"".join( # field names
325
- self.pack_recarray(array[f])
326
- if array[f].dtype.fields
327
- else self.pack_array(array[f])
325
+ (
326
+ self.pack_recarray(array[f])
327
+ if array[f].dtype.fields
328
+ else self.pack_array(array[f])
329
+ )
328
330
  for f in array.dtype.names
329
331
  )
330
332
  )
@@ -449,7 +451,7 @@ class Blob:
449
451
  )
450
452
 
451
453
  def read_struct(self):
452
- """deserialize matlab stuct"""
454
+ """deserialize matlab struct"""
453
455
  n_dims = self.read_value()
454
456
  shape = self.read_value(count=n_dims)
455
457
  n_elem = np.prod(shape, dtype=int)
datajoint/cli.py ADDED
@@ -0,0 +1,77 @@
1
+ import argparse
2
+ from code import interact
3
+ from collections import ChainMap
4
+ import datajoint as dj
5
+
6
+
7
+ def cli(args: list = None):
8
+ """
9
+ Console interface for DataJoint Python
10
+
11
+ :param args: List of arguments to be passed in, defaults to reading stdin
12
+ :type args: list, optional
13
+ """
14
+ parser = argparse.ArgumentParser(
15
+ prog="datajoint",
16
+ description="DataJoint console interface.",
17
+ conflict_handler="resolve",
18
+ )
19
+ parser.add_argument(
20
+ "-V", "--version", action="version", version=f"{dj.__name__} {dj.__version__}"
21
+ )
22
+ parser.add_argument(
23
+ "-u",
24
+ "--user",
25
+ type=str,
26
+ default=dj.config["database.user"],
27
+ required=False,
28
+ help="Datajoint username",
29
+ )
30
+ parser.add_argument(
31
+ "-p",
32
+ "--password",
33
+ type=str,
34
+ default=dj.config["database.password"],
35
+ required=False,
36
+ help="Datajoint password",
37
+ )
38
+ parser.add_argument(
39
+ "-h",
40
+ "--host",
41
+ type=str,
42
+ default=dj.config["database.host"],
43
+ required=False,
44
+ help="Datajoint host",
45
+ )
46
+ parser.add_argument(
47
+ "-s",
48
+ "--schemas",
49
+ nargs="+",
50
+ type=str,
51
+ required=False,
52
+ help="A list of virtual module mappings in `db:schema ...` format",
53
+ )
54
+ kwargs = vars(parser.parse_args(args))
55
+ mods = {}
56
+ if kwargs["user"]:
57
+ dj.config["database.user"] = kwargs["user"]
58
+ if kwargs["password"]:
59
+ dj.config["database.password"] = kwargs["password"]
60
+ if kwargs["host"]:
61
+ dj.config["database.host"] = kwargs["host"]
62
+ if kwargs["schemas"]:
63
+ for vm in kwargs["schemas"]:
64
+ d, m = vm.split(":")
65
+ mods[m] = dj.create_virtual_module(m, d)
66
+
67
+ banner = "dj repl\n"
68
+ if mods:
69
+ modstr = "\n".join(" - {}".format(m) for m in mods)
70
+ banner += "\nschema modules:\n\n" + modstr + "\n"
71
+ interact(banner, local=dict(ChainMap(mods, locals(), globals())))
72
+
73
+ raise SystemExit
74
+
75
+
76
+ if __name__ == "__main__":
77
+ cli()
datajoint/condition.py CHANGED
@@ -10,6 +10,8 @@ import numpy
10
10
  import pandas
11
11
  import json
12
12
  from .errors import DataJointError
13
+ from typing import Union, List
14
+ from dataclasses import dataclass
13
15
 
14
16
  JSON_PATTERN = re.compile(
15
17
  r"^(?P<attr>\w+)(\.(?P<path>[\w.*\[\]]+))?(:(?P<type>[\w(,\s)]+))?$"
@@ -61,6 +63,35 @@ class AndList(list):
61
63
  super().append(restriction)
62
64
 
63
65
 
66
+ @dataclass
67
+ class Top:
68
+ """
69
+ A restriction to the top entities of a query.
70
+ In SQL, this corresponds to ORDER BY ... LIMIT ... OFFSET
71
+ """
72
+
73
+ limit: Union[int, None] = 1
74
+ order_by: Union[str, List[str]] = "KEY"
75
+ offset: int = 0
76
+
77
+ def __post_init__(self):
78
+ self.order_by = self.order_by or ["KEY"]
79
+ self.offset = self.offset or 0
80
+
81
+ if self.limit is not None and not isinstance(self.limit, int):
82
+ raise TypeError("Top limit must be an integer")
83
+ if not isinstance(self.order_by, (str, collections.abc.Sequence)) or not all(
84
+ isinstance(r, str) for r in self.order_by
85
+ ):
86
+ raise TypeError("Top order_by attributes must all be strings")
87
+ if not isinstance(self.offset, int):
88
+ raise TypeError("The offset argument must be an integer")
89
+ if self.offset and self.limit is None:
90
+ self.limit = 999999999999 # arbitrary large number to allow query
91
+ if isinstance(self.order_by, str):
92
+ self.order_by = [self.order_by]
93
+
94
+
64
95
  class Not:
65
96
  """invert restriction"""
66
97
 
datajoint/connection.py CHANGED
@@ -2,6 +2,7 @@
2
2
  This module contains the Connection class that manages the connection to the database, and
3
3
  the ``conn`` function that provides access to a persistent connection in datajoint.
4
4
  """
5
+
5
6
  import warnings
6
7
  from contextlib import contextmanager
7
8
  import pymysql as client
@@ -79,6 +80,8 @@ def translate_query_error(client_error, query):
79
80
  # Integrity errors
80
81
  if err == 1062:
81
82
  return errors.DuplicateError(*args)
83
+ if err == 1217: # MySQL 8 error code
84
+ return errors.IntegrityError(*args)
82
85
  if err == 1451:
83
86
  return errors.IntegrityError(*args)
84
87
  if err == 1452:
@@ -113,7 +116,7 @@ def conn(
113
116
  :param init_fun: initialization function
114
117
  :param reset: whether the connection should be reset or not
115
118
  :param use_tls: TLS encryption option. Valid options are: True (required), False
116
- (required no TLS), None (TLS prefered, default), dict (Manually specify values per
119
+ (required no TLS), None (TLS preferred, default), dict (Manually specify values per
117
120
  https://dev.mysql.com/doc/refman/5.7/en/connection-options.html#encrypted-connection-options).
118
121
  """
119
122
  if not hasattr(conn, "connection") or reset:
datajoint/declare.py CHANGED
@@ -2,12 +2,15 @@
2
2
  This module hosts functions to convert DataJoint table definitions into mysql table definitions, and to
3
3
  declare the corresponding mysql tables.
4
4
  """
5
+
5
6
  import re
6
7
  import pyparsing as pp
7
8
  import logging
9
+ from hashlib import sha1
8
10
  from .errors import DataJointError, _support_filepath_types, FILEPATH_FEATURE_SWITCH
9
11
  from .attribute_adapter import get_adapter
10
12
  from .condition import translate_attribute
13
+ from .settings import config
11
14
 
12
15
  UUID_DATA_TYPE = "binary(16)"
13
16
  MAX_TABLE_NAME_LENGTH = 64
@@ -309,6 +312,19 @@ def declare(full_table_name, definition, context):
309
312
  external_stores,
310
313
  ) = prepare_declare(definition, context)
311
314
 
315
+ if config.get("add_hidden_timestamp", False):
316
+ metadata_attr_sql = [
317
+ "`_{full_table_name}_timestamp` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP"
318
+ ]
319
+ attribute_sql.extend(
320
+ attr.format(
321
+ full_table_name=sha1(
322
+ full_table_name.replace("`", "").encode("utf-8")
323
+ ).hexdigest()
324
+ )
325
+ for attr in metadata_attr_sql
326
+ )
327
+
312
328
  if not primary_key:
313
329
  raise DataJointError("Table must have a primary key")
314
330
 
@@ -382,9 +398,7 @@ def _make_attribute_alter(new, old, primary_key):
382
398
  command=(
383
399
  "ADD"
384
400
  if (old_name or new_name) not in old_names
385
- else "MODIFY"
386
- if not old_name
387
- else "CHANGE `%s`" % old_name
401
+ else "MODIFY" if not old_name else "CHANGE `%s`" % old_name
388
402
  ),
389
403
  new_def=new_def,
390
404
  after="" if after is None else "AFTER `%s`" % after,
@@ -443,9 +457,11 @@ def compile_index(line, index_sql):
443
457
  return f"`{attr}`"
444
458
  return f"({attr})"
445
459
 
446
- match = re.match(
447
- r"(?P<unique>unique\s+)?index\s*\(\s*(?P<args>.*)\)", line, re.I
448
- ).groupdict()
460
+ match = re.match(r"(?P<unique>unique\s+)?index\s*\(\s*(?P<args>.*)\)", line, re.I)
461
+ if match is None:
462
+ raise DataJointError(f'Table definition syntax error in line "{line}"')
463
+ match = match.groupdict()
464
+
449
465
  attr_list = re.findall(r"(?:[^,(]|\([^)]*\))+", match["args"])
450
466
  index_sql.append(
451
467
  "{unique}index ({attrs})".format(