datajoint 0.14.2__py3-none-any.whl → 0.14.4__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of datajoint might be problematic.
- datajoint/__init__.py +16 -14
- datajoint/admin.py +4 -2
- datajoint/attribute_adapter.py +1 -0
- datajoint/autopopulate.py +62 -20
- datajoint/blob.py +6 -5
- datajoint/cli.py +78 -0
- datajoint/condition.py +38 -5
- datajoint/connection.py +17 -10
- datajoint/declare.py +25 -6
- datajoint/dependencies.py +67 -33
- datajoint/diagram.py +58 -48
- datajoint/expression.py +92 -42
- datajoint/external.py +17 -10
- datajoint/fetch.py +18 -42
- datajoint/hash.py +1 -1
- datajoint/heading.py +14 -11
- datajoint/jobs.py +4 -3
- datajoint/plugin.py +5 -3
- datajoint/s3.py +6 -4
- datajoint/schemas.py +18 -19
- datajoint/settings.py +25 -11
- datajoint/table.py +27 -22
- datajoint/user_tables.py +30 -2
- datajoint/utils.py +2 -1
- datajoint/version.py +4 -1
- datajoint-0.14.4.dist-info/METADATA +703 -0
- datajoint-0.14.4.dist-info/RECORD +34 -0
- {datajoint-0.14.2.dist-info → datajoint-0.14.4.dist-info}/WHEEL +1 -1
- datajoint-0.14.4.dist-info/entry_points.txt +3 -0
- datajoint-0.14.2.dist-info/METADATA +0 -26
- datajoint-0.14.2.dist-info/RECORD +0 -33
- datajoint-0.14.2.dist-info/datajoint.pub +0 -6
- {datajoint-0.14.2.dist-info → datajoint-0.14.4.dist-info/licenses}/LICENSE.txt +0 -0
- {datajoint-0.14.2.dist-info → datajoint-0.14.4.dist-info}/top_level.txt +0 -0
datajoint/__init__.py
CHANGED
@@ -37,6 +37,7 @@ __all__ = [
     "Part",
     "Not",
     "AndList",
+    "Top",
     "U",
     "Diagram",
     "Di",
@@ -51,25 +52,26 @@ __all__ = [
     "key",
     "key_hash",
     "logger",
+    "cli",
 ]
 
-from .
-from .
-from .
-from .connection import conn, Connection
-from .schemas import Schema
-from .schemas import VirtualModule, list_schemas
-from .table import Table, FreeTable
-from .user_tables import Manual, Lookup, Imported, Computed, Part
-from .expression import Not, AndList, U
-from .diagram import Diagram
-from .admin import set_password, kill
+from . import errors
+from .admin import kill, set_password
+from .attribute_adapter import AttributeAdapter
 from .blob import MatCell, MatStruct
+from .cli import cli
+from .connection import Connection, conn
+from .diagram import Diagram
+from .errors import DataJointError
+from .expression import AndList, Not, Top, U
 from .fetch import key
 from .hash import key_hash
-from .
-from . import
-from .
+from .logging import logger
+from .schemas import Schema, VirtualModule, list_schemas
+from .settings import config
+from .table import FreeTable, Table
+from .user_tables import Computed, Imported, Lookup, Manual, Part
+from .version import __version__
 
 ERD = Di = Diagram  # Aliases for Diagram
 schema = Schema  # Aliases for Schema
datajoint/admin.py
CHANGED
@@ -1,10 +1,12 @@
-import pymysql
+import logging
 from getpass import getpass
+
+import pymysql
 from packaging import version
+
 from .connection import conn
 from .settings import config
 from .utils import user_choice
-import logging
 
 logger = logging.getLogger(__name__.split(".")[0])
 
datajoint/attribute_adapter.py
CHANGED
datajoint/autopopulate.py
CHANGED
@@ -1,17 +1,20 @@
 """This module defines class dj.AutoPopulate"""
 
-import logging
+import contextlib
 import datetime
-import traceback
-import random
 import inspect
+import logging
+import multiprocessing as mp
+import random
+import signal
+import traceback
+
+import deepdiff
 from tqdm import tqdm
-
-from .expression import QueryExpression, AndList
+
 from .errors import DataJointError, LostConnectionError
-import signal
-import multiprocessing as mp
-import contextlib
+from .expression import AndList, QueryExpression
+from .hash import key_hash
 
 # noinspection PyExceptionInherit,PyCallingNonCallable
 
@@ -23,7 +26,7 @@ logger = logging.getLogger(__name__.split(".")[0])
 
 def _initialize_populate(table, jobs, populate_kwargs):
     """
-    Initialize the process for
+    Initialize the process for multiprocessing.
     Saves the unpickled copy of the table to the current process and reconnects.
     """
     process = mp.current_process()
@@ -153,6 +156,7 @@ class AutoPopulate:
     def populate(
         self,
         *restrictions,
+        keys=None,
         suppress_errors=False,
         return_exception_objects=False,
         reserve_jobs=False,
@@ -169,6 +173,8 @@ class AutoPopulate:
 
         :param restrictions: a list of restrictions each restrict
             (table.key_source - target.proj())
+        :param keys: The list of keys (dicts) to send to self.make().
+            If None (default), then use self.key_source to query they keys.
         :param suppress_errors: if True, do not terminate execution.
         :param return_exception_objects: return error objects instead of just error messages
         :param reserve_jobs: if True, reserve jobs to populate in asynchronous fashion
@@ -206,7 +212,10 @@ class AutoPopulate:
 
         old_handler = signal.signal(signal.SIGTERM, handler)
 
-        keys = (self._jobs_to_do(restrictions) - self.target).fetch("KEY", limit=limit)
+        if keys is None:
+            keys = (self._jobs_to_do(restrictions) - self.target).fetch(
+                "KEY", limit=limit
+            )
 
         # exclude "error", "ignore" or "reserved" jobs
         if reserve_jobs:
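
The new `keys` argument lets a caller supply an explicit list of keys instead of having populate query `key_source`. A minimal sketch of how it might be called (the table and key names are hypothetical):

    # populate two specific keys rather than everything in (key_source - target)
    Analysis.populate(keys=[{"session_id": 1}, {"session_id": 2}], display_progress=True)
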
@@ -256,13 +265,16 @@ class AutoPopulate:
             # spawn multiple processes
             self.connection.close()  # disconnect parent process from MySQL server
             del self.connection._conn.ctx  # SSLContext is not pickleable
-            with mp.Pool(
-                processes, _initialize_populate, (self, jobs, populate_kwargs)
-            ) as pool, (
-                tqdm(desc="Processes: ", total=nkeys)
-                if display_progress
-                else contextlib.nullcontext()
-            ) as progress_bar:
+            with (
+                mp.Pool(
+                    processes, _initialize_populate, (self, jobs, populate_kwargs)
+                ) as pool,
+                (
+                    tqdm(desc="Processes: ", total=nkeys)
+                    if display_progress
+                    else contextlib.nullcontext()
+                ) as progress_bar,
+            ):
                 for status in pool.imap(_call_populate1, keys, chunksize=1):
                     if status is True:
                         success_list.append(1)
@@ -295,6 +307,7 @@ class AutoPopulate:
        :return: (key, error) when suppress_errors=True,
            True if successfully invoke one `make()` call, otherwise False
        """
+        # use the legacy `_make_tuples` callback.
        make = self._make_tuples if hasattr(self, "_make_tuples") else self.make
 
        if jobs is not None and not jobs.reserve(
@@ -302,17 +315,46 @@ class AutoPopulate:
        ):
            return False
 
-        self.connection.start_transaction()
+        # if make is a generator, it transaction can be delayed until the final stage
+        is_generator = inspect.isgeneratorfunction(make)
+        if not is_generator:
+            self.connection.start_transaction()
+
        if key in self.target:  # already populated
-            self.connection.cancel_transaction()
+            if not is_generator:
+                self.connection.cancel_transaction()
            if jobs is not None:
                jobs.complete(self.target.table_name, self._job_key(key))
            return False
 
        logger.debug(f"Making {key} -> {self.target.full_table_name}")
        self.__class__._allow_insert = True
+
        try:
-            make(dict(key), **(make_kwargs or {}))
+            if not is_generator:
+                make(dict(key), **(make_kwargs or {}))
+            else:
+                # tripartite make - transaction is delayed until the final stage
+                gen = make(dict(key), **(make_kwargs or {}))
+                fetched_data = next(gen)
+                fetch_hash = deepdiff.DeepHash(
+                    fetched_data, ignore_iterable_order=False
+                )[fetched_data]
+                computed_result = next(gen)  # perform the computation
+                # fetch and insert inside a transaction
+                self.connection.start_transaction()
+                gen = make(dict(key), **(make_kwargs or {}))  # restart make
+                fetched_data = next(gen)
+                if (
+                    fetch_hash
+                    != deepdiff.DeepHash(fetched_data, ignore_iterable_order=False)[
+                        fetched_data
+                    ]
+                ):  # rollback due to referential integrity fail
+                    self.connection.cancel_transaction()
+                    return False
+                gen.send(computed_result)  # insert
+
        except (KeyboardInterrupt, SystemExit, Exception) as error:
            try:
                self.connection.cancel_transaction()
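
The generator branch above splits `make` into fetch, compute, and insert phases, deferring the transaction to the final phase and using a `deepdiff` hash to detect upstream data changing in between. A sketch of a `make` written against this protocol, inferred from the `populate1` logic above (table and attribute names are hypothetical):

    import datajoint as dj

    class Analysis(dj.Computed):
        definition = """
        -> Recording   # hypothetical upstream table
        ---
        mean_signal : float
        """

        def make(self, key):
            # phase 1: fetch (runs outside the transaction)
            signal = (Recording & key).fetch1("signal")
            computed = yield (signal,)  # populate1 hashes this yielded value
            if computed is None:
                # phase 2: compute (also outside the transaction)
                computed = dict(key, mean_signal=float(sum(signal)) / len(signal))
                yield computed
            # phase 3: insert (inside the transaction; populate1 re-fetches,
            # verifies the hash, then sends the cached result back in)
            self.insert1(computed)
            yield
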
datajoint/blob.py
CHANGED
@@ -3,17 +3,18 @@
 compatibility with Matlab-based serialization implemented by mYm.
 """
 
-import zlib
-from itertools import repeat
 import collections
-from decimal import Decimal
 import datetime
 import uuid
+import zlib
+from decimal import Decimal
+from itertools import repeat
+
 import numpy as np
+
 from .errors import DataJointError
 from .settings import config
 
-
 deserialize_lookup = {
     0: {"dtype": None, "scalar_type": "UNKNOWN"},
     1: {"dtype": None, "scalar_type": "CELL"},
@@ -204,7 +205,7 @@ class Blob:
             return self.pack_dict(obj)
         if isinstance(obj, str):
             return self.pack_string(obj)
-        if isinstance(obj, bytes):
+        if isinstance(obj, (bytes, bytearray)):
             return self.pack_bytes(obj)
         if isinstance(obj, collections.abc.MutableSequence):
            return self.pack_list(obj)
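
The `pack_bytes` branch now also accepts `bytearray`; a small sketch of the loosened check:

    from datajoint.blob import pack, unpack

    buf = pack(bytearray(b"\x00\x01\x02"))  # 0.14.2 accepted only bytes here
    restored = unpack(buf)                  # round-trips through the mYm-compatible format
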
datajoint/cli.py
ADDED
@@ -0,0 +1,78 @@
+import argparse
+from code import interact
+from collections import ChainMap
+
+import datajoint as dj
+
+
+def cli(args: list = None):
+    """
+    Console interface for DataJoint Python
+
+    :param args: List of arguments to be passed in, defaults to reading stdin
+    :type args: list, optional
+    """
+    parser = argparse.ArgumentParser(
+        prog="datajoint",
+        description="DataJoint console interface.",
+        conflict_handler="resolve",
+    )
+    parser.add_argument(
+        "-V", "--version", action="version", version=f"{dj.__name__} {dj.__version__}"
+    )
+    parser.add_argument(
+        "-u",
+        "--user",
+        type=str,
+        default=dj.config["database.user"],
+        required=False,
+        help="Datajoint username",
+    )
+    parser.add_argument(
+        "-p",
+        "--password",
+        type=str,
+        default=dj.config["database.password"],
+        required=False,
+        help="Datajoint password",
+    )
+    parser.add_argument(
+        "-h",
+        "--host",
+        type=str,
+        default=dj.config["database.host"],
+        required=False,
+        help="Datajoint host",
+    )
+    parser.add_argument(
+        "-s",
+        "--schemas",
+        nargs="+",
+        type=str,
+        required=False,
+        help="A list of virtual module mappings in `db:schema ...` format",
+    )
+    kwargs = vars(parser.parse_args(args))
+    mods = {}
+    if kwargs["user"]:
+        dj.config["database.user"] = kwargs["user"]
+    if kwargs["password"]:
+        dj.config["database.password"] = kwargs["password"]
+    if kwargs["host"]:
+        dj.config["database.host"] = kwargs["host"]
+    if kwargs["schemas"]:
+        for vm in kwargs["schemas"]:
+            d, m = vm.split(":")
+            mods[m] = dj.create_virtual_module(m, d)
+
+    banner = "dj repl\n"
+    if mods:
+        modstr = "\n".join(" - {}".format(m) for m in mods)
+        banner += "\nschema modules:\n\n" + modstr + "\n"
+    interact(banner, local=dict(ChainMap(mods, locals(), globals())))
+
+    raise SystemExit
+
+
+if __name__ == "__main__":
+    cli()
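
A sketch of how the new console interface might be invoked; the host and `db:schema` mapping are made up, and the three-line entry_points.txt added by this release presumably wires `cli` up as a console script:

    # from Python (passing an argument list instead of reading stdin):
    from datajoint.cli import cli
    cli(["--host", "localhost", "-s", "my_database:mydb"])
    # drops into a REPL with the virtual module `mydb` bound to schema `my_database`
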
datajoint/condition.py
CHANGED
@@ -1,14 +1,18 @@
 """ methods for generating SQL WHERE clauses from datajoint restriction conditions """
 
-import inspect
 import collections
-import re
-import uuid
 import datetime
 import decimal
+import inspect
+import json
+import re
+import uuid
+from dataclasses import dataclass
+from typing import List, Union
+
 import numpy
 import pandas
-
+
 from .errors import DataJointError
 
 JSON_PATTERN = re.compile(
@@ -61,6 +65,35 @@ class AndList(list):
         super().append(restriction)
 
 
+@dataclass
+class Top:
+    """
+    A restriction to the top entities of a query.
+    In SQL, this corresponds to ORDER BY ... LIMIT ... OFFSET
+    """
+
+    limit: Union[int, None] = 1
+    order_by: Union[str, List[str]] = "KEY"
+    offset: int = 0
+
+    def __post_init__(self):
+        self.order_by = self.order_by or ["KEY"]
+        self.offset = self.offset or 0
+
+        if self.limit is not None and not isinstance(self.limit, int):
+            raise TypeError("Top limit must be an integer")
+        if not isinstance(self.order_by, (str, collections.abc.Sequence)) or not all(
+            isinstance(r, str) for r in self.order_by
+        ):
+            raise TypeError("Top order_by attributes must all be strings")
+        if not isinstance(self.offset, int):
+            raise TypeError("The offset argument must be an integer")
+        if self.offset and self.limit is None:
+            self.limit = 999999999999  # arbitrary large number to allow query
+        if isinstance(self.order_by, str):
+            self.order_by = [self.order_by]
+
+
 class Not:
     """invert restriction"""
 
@@ -112,7 +145,7 @@ def make_condition(query_expression, condition, columns):
         condition.
     :return: an SQL condition string or a boolean value.
     """
-    from .expression import
+    from .expression import Aggregation, QueryExpression, U
 
    def prep_value(k, v):
        """prepare SQL condition"""
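
`Top` is applied like any other restriction operand; a hedged sketch of its use (the `Session` table and `session_date` attribute are hypothetical):

    import datajoint as dj

    # the five most recent sessions; compiles to ORDER BY ... LIMIT 5 OFFSET 0
    recent = Session & dj.Top(limit=5, order_by="session_date DESC")
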
datajoint/connection.py
CHANGED
@@ -3,20 +3,22 @@ This module contains the Connection class that manages the connection to the database, and
 the ``conn`` function that provides access to a persistent connection in datajoint.
 """
 
+import logging
+import pathlib
+import re
 import warnings
 from contextlib import contextmanager
-import pymysql as client
-import logging
 from getpass import getpass
-import re
-import pathlib
 
-
+import pymysql as client
+
 from . import errors
-from .dependencies import Dependencies
 from .blob import pack, unpack
+from .dependencies import Dependencies
 from .hash import uuid_from_buffer
 from .plugin import connection_plugins
+from .settings import config
+from .version import __version__
 
 logger = logging.getLogger(__name__.split(".")[0])
 query_log_max_length = 300
@@ -190,15 +192,20 @@ class Connection:
         self.conn_info["ssl_input"] = use_tls
         self.conn_info["host_input"] = host_input
         self.init_fun = init_fun
-        logger.info("Connecting {user}@{host}:{port}".format(**self.conn_info))
         self._conn = None
         self._query_cache = None
         connect_host_hook(self)
         if self.is_connected:
-            logger.info(
+            logger.info(
+                "DataJoint {version} connected to {user}@{host}:{port}".format(
+                    version=__version__, **self.conn_info
+                )
+            )
             self.connection_id = self.query("SELECT connection_id()").fetchone()[0]
         else:
-            raise errors.LostConnectionError(
+            raise errors.LostConnectionError(
+                "Connection failed {user}@{host}:{port}".format(**self.conn_info)
+            )
         self._in_transaction = False
         self.schemas = dict()
         self.dependencies = Dependencies(self)
@@ -344,7 +351,7 @@ class Connection:
         except errors.LostConnectionError:
             if not reconnect:
                 raise
-            logger.warning("
+            logger.warning("Reconnecting to MySQL server.")
             connect_host_hook(self)
             if self._in_transaction:
                 self.cancel_transaction()
datajoint/declare.py
CHANGED
@@ -3,12 +3,16 @@ This module hosts functions to convert DataJoint table definitions into mysql table definitions, and to
 declare the corresponding mysql tables.
 """
 
+import logging
 import re
+from hashlib import sha1
+
 import pyparsing as pp
-
-from .errors import DataJointError, _support_filepath_types, FILEPATH_FEATURE_SWITCH
+
 from .attribute_adapter import get_adapter
 from .condition import translate_attribute
+from .errors import FILEPATH_FEATURE_SWITCH, DataJointError, _support_filepath_types
+from .settings import config
 
 UUID_DATA_TYPE = "binary(16)"
 MAX_TABLE_NAME_LENGTH = 64
@@ -161,8 +165,8 @@ def compile_foreign_key(
     :param index_sql: list of INDEX declaration statements, duplicate or redundant indexes are ok.
     """
     # Parse and validate
-    from .table import Table
     from .expression import QueryExpression
+    from .table import Table
 
     try:
         result = foreign_key_parser.parseString(line)
@@ -310,6 +314,19 @@ def declare(full_table_name, definition, context):
         external_stores,
     ) = prepare_declare(definition, context)
 
+    if config.get("add_hidden_timestamp", False):
+        metadata_attr_sql = [
+            "`_{full_table_name}_timestamp` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP"
+        ]
+        attribute_sql.extend(
+            attr.format(
+                full_table_name=sha1(
+                    full_table_name.replace("`", "").encode("utf-8")
+                ).hexdigest()
+            )
+            for attr in metadata_attr_sql
+        )
+
     if not primary_key:
         raise DataJointError("Table must have a primary key")
 
@@ -442,9 +459,11 @@ def compile_index(line, index_sql):
             return f"`{attr}`"
         return f"({attr})"
 
-    match = re.match(
-        r"(?P<unique>unique\s+)?index\s*\(\s*(?P<args>.*)\)", line, re.I
-    ).groupdict()
+    match = re.match(r"(?P<unique>unique\s+)?index\s*\(\s*(?P<args>.*)\)", line, re.I)
+    if match is None:
+        raise DataJointError(f'Table definition syntax error in line "{line}"')
+    match = match.groupdict()
+
    attr_list = re.findall(r"(?:[^,(]|\([^)]*\))+", match["args"])
    index_sql.append(
        "{unique}index ({attrs})".format(
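
With the off-by-default `add_hidden_timestamp` setting enabled, each newly declared table gains a hidden creation-timestamp column whose name embeds a SHA-1 of the backquote-stripped table name. A sketch of computing that name for a hypothetical table:

    import hashlib
    import datajoint as dj

    dj.config["add_hidden_timestamp"] = True  # opt in before declaring tables

    full_table_name = "`my_schema`.`session`"  # hypothetical
    digest = hashlib.sha1(full_table_name.replace("`", "").encode("utf-8")).hexdigest()
    hidden_column = f"_{digest}_timestamp"  # 40 hex digits between the underscores
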
datajoint/dependencies.py
CHANGED
@@ -1,32 +1,70 @@
-import networkx as nx
 import itertools
 import re
 from collections import defaultdict
+
+import networkx as nx
+
 from .errors import DataJointError
 
 
-def unite_master_parts(lst):
+def extract_master(part_table):
+    """
+    given a part table name, return master part. None if not a part table
+    """
+    match = re.match(r"(?P<master>`\w+`.`#?\w+)__\w+`", part_table)
+    return match["master"] + "`" if match else None
+
+
+def topo_sort(graph):
     """
-    Without this correction, a simple topological sort may insert other descendants between master and parts.
-    The input list must be topologically sorted.
-    :example:
-    unite_master_parts(
-        ['`s`.`a`', '`s`.`a__q`', '`s`.`b`', '`s`.`c`', '`s`.`c__q`', '`s`.`b__q`', '`s`.`d`', '`s`.`a__r`']) ->
-        ['`s`.`a`', '`s`.`a__q`', '`s`.`a__r`', '`s`.`b`', '`s`.`b__q`', '`s`.`c`', '`s`.`c__q`', '`s`.`d`']
+    topological sort of a dependency graph that keeps part tables together with their masters
+    :return: list of table names in topological order
     """
+
+    graph = nx.DiGraph(graph)  # make a copy
+
+    # collapse alias nodes
+    alias_nodes = [node for node in graph if node.isdigit()]
+    for node in alias_nodes:
+        try:
+            direct_edge = (
+                next(x for x in graph.in_edges(node))[0],
+                next(x for x in graph.out_edges(node))[1],
+            )
+        except StopIteration:
+            pass  # a disconnected alias node
+        else:
+            graph.add_edge(*direct_edge)
+    graph.remove_nodes_from(alias_nodes)
+
+    # Add parts' dependencies to their masters' dependencies
+    # to ensure correct topological ordering of the masters.
+    for part in graph:
+        # find the part's master
+        if (master := extract_master(part)) in graph:
+            for edge in graph.in_edges(part):
+                parent = edge[0]
+                if master not in (parent, extract_master(parent)):
+                    # if parent is neither master nor part of master
+                    graph.add_edge(parent, master)
+    sorted_nodes = list(nx.topological_sort(graph))
+
+    # bring parts up to their masters
+    pos = len(sorted_nodes) - 1
+    placed = set()
+    while pos > 1:
+        part = sorted_nodes[pos]
+        if (master := extract_master(part)) not in graph or part in placed:
+            pos -= 1
+        else:
+            placed.add(part)
+            insert_pos = sorted_nodes.index(master) + 1
+            if pos > insert_pos:
+                # move the part to the position immediately after its master
+                del sorted_nodes[pos]
+                sorted_nodes.insert(insert_pos, part)
+
+    return sorted_nodes
 
 
 class Dependencies(nx.DiGraph):
@@ -131,6 +169,10 @@ class Dependencies(nx.DiGraph):
             raise DataJointError("DataJoint can only work with acyclic dependencies")
         self._loaded = True
 
+    def topo_sort(self):
+        """:return: list of tables names in topological order"""
+        return topo_sort(self)
+
     def parents(self, table_name, primary=None):
         """
         :param table_name: `schema`.`table`
@@ -167,10 +209,8 @@ class Dependencies(nx.DiGraph):
         :return: all dependent tables sorted in topological order. Self is included.
         """
         self.load(force=False)
-        nodes = self.subgraph(nx.descendants(self, full_table_name))
-        return unite_master_parts(
-            [full_table_name] + list(nx.algorithms.dag.topological_sort(nodes))
-        )
+        nodes = self.subgraph(nx.descendants(self, full_table_name))
+        return [full_table_name] + nodes.topo_sort()
 
     def ancestors(self, full_table_name):
         """
@@ -178,11 +218,5 @@ class Dependencies(nx.DiGraph):
         :return: all dependent tables sorted in topological order. Self is included.
         """
         self.load(force=False)
-        nodes = self.subgraph(nx.ancestors(self, full_table_name))
-        return list(
-            reversed(
-                unite_master_parts(
-                    list(nx.algorithms.dag.topological_sort(nodes)) + [full_table_name]
-                )
-            )
-        )
+        nodes = self.subgraph(nx.ancestors(self, full_table_name))
+        return reversed(nodes.topo_sort() + [full_table_name])