datajoint 0.14.1__py3-none-any.whl → 0.14.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of datajoint might be problematic; see the package's page on its public registry for more details.

datajoint/dependencies.py CHANGED
@@ -5,28 +5,64 @@ from collections import defaultdict
5
5
  from .errors import DataJointError
6
6
 
7
7
 
8
- def unite_master_parts(lst):
8
+ def extract_master(part_table):
9
9
  """
10
- re-order a list of table names so that part tables immediately follow their master tables without breaking
11
- the topological order.
12
- Without this correction, a simple topological sort may insert other descendants between master and parts.
13
- The input list must be topologically sorted.
14
- :example:
15
- unite_master_parts(
16
- ['`s`.`a`', '`s`.`a__q`', '`s`.`b`', '`s`.`c`', '`s`.`c__q`', '`s`.`b__q`', '`s`.`d`', '`s`.`a__r`']) ->
17
- ['`s`.`a`', '`s`.`a__q`', '`s`.`a__r`', '`s`.`b`', '`s`.`b__q`', '`s`.`c`', '`s`.`c__q`', '`s`.`d`']
10
+ given a part table name, return master part. None if not a part table
18
11
  """
19
- for i in range(2, len(lst)):
20
- name = lst[i]
21
- match = re.match(r"(?P<master>`\w+`.`#?\w+)__\w+`", name)
22
- if match: # name is a part table
23
- master = match.group("master")
24
- for j in range(i - 1, -1, -1):
25
- if lst[j] == master + "`" or lst[j].startswith(master + "__"):
26
- # move from the ith position to the (j+1)th position
27
- lst[j + 1 : i + 1] = [name] + lst[j + 1 : i]
28
- break
29
- return lst
12
+ match = re.match(r"(?P<master>`\w+`.`#?\w+)__\w+`", part_table)
13
+ return match["master"] + "`" if match else None
14
+
15
+
16
+ def topo_sort(graph):
17
+ """
18
+ topological sort of a dependency graph that keeps part tables together with their masters
19
+ :return: list of table names in topological order
20
+ """
21
+
22
+ graph = nx.DiGraph(graph) # make a copy
23
+
24
+ # collapse alias nodes
25
+ alias_nodes = [node for node in graph if node.isdigit()]
26
+ for node in alias_nodes:
27
+ try:
28
+ direct_edge = (
29
+ next(x for x in graph.in_edges(node))[0],
30
+ next(x for x in graph.out_edges(node))[1],
31
+ )
32
+ except StopIteration:
33
+ pass # a disconnected alias node
34
+ else:
35
+ graph.add_edge(*direct_edge)
36
+ graph.remove_nodes_from(alias_nodes)
37
+
38
+ # Add parts' dependencies to their masters' dependencies
39
+ # to ensure correct topological ordering of the masters.
40
+ for part in graph:
41
+ # find the part's master
42
+ if (master := extract_master(part)) in graph:
43
+ for edge in graph.in_edges(part):
44
+ parent = edge[0]
45
+ if master not in (parent, extract_master(parent)):
46
+ # if parent is neither master nor part of master
47
+ graph.add_edge(parent, master)
48
+ sorted_nodes = list(nx.topological_sort(graph))
49
+
50
+ # bring parts up to their masters
51
+ pos = len(sorted_nodes) - 1
52
+ placed = set()
53
+ while pos > 1:
54
+ part = sorted_nodes[pos]
55
+ if (master := extract_master(part)) not in graph or part in placed:
56
+ pos -= 1
57
+ else:
58
+ placed.add(part)
59
+ insert_pos = sorted_nodes.index(master) + 1
60
+ if pos > insert_pos:
61
+ # move the part to the position immediately after its master
62
+ del sorted_nodes[pos]
63
+ sorted_nodes.insert(insert_pos, part)
64
+
65
+ return sorted_nodes
30
66
 
31
67
 
32
68
  class Dependencies(nx.DiGraph):
@@ -131,6 +167,10 @@ class Dependencies(nx.DiGraph):
131
167
  raise DataJointError("DataJoint can only work with acyclic dependencies")
132
168
  self._loaded = True
133
169
 
170
+ def topo_sort(self):
171
+ """:return: list of tables names in topological order"""
172
+ return topo_sort(self)
173
+
134
174
  def parents(self, table_name, primary=None):
135
175
  """
136
176
  :param table_name: `schema`.`table`
@@ -167,10 +207,8 @@ class Dependencies(nx.DiGraph):
167
207
  :return: all dependent tables sorted in topological order. Self is included.
168
208
  """
169
209
  self.load(force=False)
170
- nodes = self.subgraph(nx.algorithms.dag.descendants(self, full_table_name))
171
- return unite_master_parts(
172
- [full_table_name] + list(nx.algorithms.dag.topological_sort(nodes))
173
- )
210
+ nodes = self.subgraph(nx.descendants(self, full_table_name))
211
+ return [full_table_name] + nodes.topo_sort()
174
212
 
175
213
  def ancestors(self, full_table_name):
176
214
  """
@@ -178,11 +216,5 @@ class Dependencies(nx.DiGraph):
178
216
  :return: all dependent tables sorted in topological order. Self is included.
179
217
  """
180
218
  self.load(force=False)
181
- nodes = self.subgraph(nx.algorithms.dag.ancestors(self, full_table_name))
182
- return list(
183
- reversed(
184
- unite_master_parts(
185
- list(nx.algorithms.dag.topological_sort(nodes)) + [full_table_name]
186
- )
187
- )
188
- )
219
+ nodes = self.subgraph(nx.ancestors(self, full_table_name))
220
+ return reversed(nodes.topo_sort() + [full_table_name])
datajoint/diagram.py CHANGED
@@ -1,12 +1,11 @@
1
1
  import networkx as nx
2
- import re
3
2
  import functools
4
3
  import io
5
4
  import logging
6
5
  import inspect
7
6
  from .table import Table
8
- from .dependencies import unite_master_parts
9
- from .user_tables import Manual, Imported, Computed, Lookup, Part
7
+ from .dependencies import topo_sort
8
+ from .user_tables import Manual, Imported, Computed, Lookup, Part, _get_tier, _AliasNode
10
9
  from .errors import DataJointError
11
10
  from .table import lookup_class_name
12
11
 
@@ -27,29 +26,6 @@ except:
27
26
 
28
27
 
29
28
  logger = logging.getLogger(__name__.split(".")[0])
30
- user_table_classes = (Manual, Lookup, Computed, Imported, Part)
31
-
32
-
33
- class _AliasNode:
34
- """
35
- special class to indicate aliased foreign keys
36
- """
37
-
38
- pass
39
-
40
-
41
- def _get_tier(table_name):
42
- if not table_name.startswith("`"):
43
- return _AliasNode
44
- else:
45
- try:
46
- return next(
47
- tier
48
- for tier in user_table_classes
49
- if re.fullmatch(tier.tier_regexp, table_name.split("`")[-2])
50
- )
51
- except StopIteration:
52
- return None
53
29
 
54
30
 
55
31
  if not diagram_active:
@@ -59,8 +35,7 @@ if not diagram_active:
59
35
  Entity relationship diagram, currently disabled due to the lack of required packages: matplotlib and pygraphviz.
60
36
 
61
37
  To enable Diagram feature, please install both matplotlib and pygraphviz. For instructions on how to install
62
- these two packages, refer to http://docs.datajoint.io/setup/Install-and-connect.html#python and
63
- http://tutorials.datajoint.io/setting-up/datajoint-python.html
38
+ these two packages, refer to https://datajoint.com/docs/core/datajoint-python/0.14/client/install/
64
39
  """
65
40
 
66
41
  def __init__(self, *args, **kwargs):
@@ -72,19 +47,22 @@ else:
72
47
 
73
48
  class Diagram(nx.DiGraph):
74
49
  """
75
- Entity relationship diagram.
50
+ Schema diagram showing tables and foreign keys between in the form of a directed
51
+ acyclic graph (DAG). The diagram is derived from the connection.dependencies object.
76
52
 
77
53
  Usage:
78
54
 
79
55
  >>> diag = Diagram(source)
80
56
 
81
- source can be a base table object, a base table class, a schema, or a module that has a schema.
57
+ source can be a table object, a table class, a schema, or a module that has a schema.
82
58
 
83
59
  >>> diag.draw()
84
60
 
85
61
  draws the diagram using pyplot
86
62
 
87
63
  diag1 + diag2 - combines the two diagrams.
64
+ diag1 - diag2 - difference between diagrams
65
+ diag1 * diag2 - intersection of diagrams
88
66
  diag + n - expands n levels of successors
89
67
  diag - n - expands n levels of predecessors
90
68
  Thus dj.Diagram(schema.Table)+1-1 defines the diagram of immediate ancestors and descendants of schema.Table
@@ -94,6 +72,7 @@ else:
94
72
  """
95
73
 
96
74
  def __init__(self, source, context=None):
75
+
97
76
  if isinstance(source, Diagram):
98
77
  # copy constructor
99
78
  self.nodes_to_show = set(source.nodes_to_show)
@@ -154,7 +133,7 @@ else:
154
133
 
155
134
  def add_parts(self):
156
135
  """
157
- Adds to the diagram the part tables of tables already included in the diagram
136
+ Adds to the diagram the part tables of all master tables already in the diagram
158
137
  :return:
159
138
  """
160
139
 
@@ -179,16 +158,6 @@ else:
179
158
  )
180
159
  return self
181
160
 
182
- def topological_sort(self):
183
- """:return: list of nodes in topological order"""
184
- return unite_master_parts(
185
- list(
186
- nx.algorithms.dag.topological_sort(
187
- nx.DiGraph(self).subgraph(self.nodes_to_show)
188
- )
189
- )
190
- )
191
-
192
161
  def __add__(self, arg):
193
162
  """
194
163
  :param arg: either another Diagram or a positive integer.
@@ -256,6 +225,10 @@ else:
256
225
  self.nodes_to_show.intersection_update(arg.nodes_to_show)
257
226
  return self
258
227
 
228
+ def topo_sort(self):
229
+ """return nodes in lexicographical topological order"""
230
+ return topo_sort(self)
231
+
259
232
  def _make_graph(self):
260
233
  """
261
234
  Make the self.graph - a graph object ready for drawing
@@ -300,6 +273,36 @@ else:
300
273
  nx.relabel_nodes(graph, mapping, copy=False)
301
274
  return graph
302
275
 
276
+ @staticmethod
277
+ def _encapsulate_edge_attributes(graph):
278
+ """
279
+ Modifies the `nx.Graph`'s edge attribute `attr_map` to be a string representation
280
+ of the attribute map, and encapsulates the string in double quotes.
281
+ Changes the graph in place.
282
+
283
+ Implements workaround described in
284
+ https://github.com/pydot/pydot/issues/258#issuecomment-795798099
285
+ """
286
+ for u, v, *_, edgedata in graph.edges(data=True):
287
+ if "attr_map" in edgedata:
288
+ graph.edges[u, v]["attr_map"] = '"{0}"'.format(edgedata["attr_map"])
289
+
290
+ @staticmethod
291
+ def _encapsulate_node_names(graph):
292
+ """
293
+ Modifies the `nx.Graph`'s node names string representations encapsulated in
294
+ double quotes.
295
+ Changes the graph in place.
296
+
297
+ Implements workaround described in
298
+ https://github.com/datajoint/datajoint-python/pull/1176
299
+ """
300
+ nx.relabel_nodes(
301
+ graph,
302
+ {node: '"{0}"'.format(node) for node in graph.nodes()},
303
+ copy=False,
304
+ )
305
+
303
306
  def make_dot(self):
304
307
  graph = self._make_graph()
305
308
  graph.nodes()
@@ -368,6 +371,8 @@ else:
368
371
  for node, d in dict(graph.nodes(data=True)).items()
369
372
  }
370
373
 
374
+ self._encapsulate_node_names(graph)
375
+ self._encapsulate_edge_attributes(graph)
371
376
  dot = nx.drawing.nx_pydot.to_pydot(graph)
372
377
  for node in dot.get_nodes():
373
378
  node.set_shape("circle")
@@ -385,11 +390,15 @@ else:
385
390
  assert issubclass(cls, Table)
386
391
  description = cls().describe(context=self.context).split("\n")
387
392
  description = (
388
- "-" * 30
389
- if q.startswith("---")
390
- else q.replace("->", "&#8594;")
391
- if "->" in q
392
- else q.split(":")[0]
393
+ (
394
+ "-" * 30
395
+ if q.startswith("---")
396
+ else (
397
+ q.replace("->", "&#8594;")
398
+ if "->" in q
399
+ else q.split(":")[0]
400
+ )
401
+ )
393
402
  for q in description
394
403
  if not q.startswith("#")
395
404
  )
@@ -404,9 +413,14 @@ else:
404
413
 
405
414
  for edge in dot.get_edges():
406
415
  # see https://graphviz.org/doc/info/attrs.html
407
- src = edge.get_source().strip('"')
408
- dest = edge.get_destination().strip('"')
416
+ src = edge.get_source()
417
+ dest = edge.get_destination()
409
418
  props = graph.get_edge_data(src, dest)
419
+ if props is None:
420
+ raise DataJointError(
421
+ "Could not find edge with source "
422
+ "'{}' and destination '{}'".format(src, dest)
423
+ )
410
424
  edge.set_color("#00000040")
411
425
  edge.set_style("solid" if props["primary"] else "dashed")
412
426
  master_part = graph.nodes[dest][
datajoint/expression.py CHANGED
@@ -9,6 +9,7 @@ from .fetch import Fetch, Fetch1
9
9
  from .preview import preview, repr_html
10
10
  from .condition import (
11
11
  AndList,
12
+ Top,
12
13
  Not,
13
14
  make_condition,
14
15
  assert_join_compatibility,
@@ -52,6 +53,7 @@ class QueryExpression:
52
53
  _connection = None
53
54
  _heading = None
54
55
  _support = None
56
+ _top = None
55
57
 
56
58
  # If the query will be using distinct
57
59
  _distinct = False
@@ -100,9 +102,11 @@ class QueryExpression:
100
102
 
101
103
  def from_clause(self):
102
104
  support = (
103
- "(" + src.make_sql() + ") as `$%x`" % next(self._subquery_alias_count)
104
- if isinstance(src, QueryExpression)
105
- else src
105
+ (
106
+ "(" + src.make_sql() + ") as `$%x`" % next(self._subquery_alias_count)
107
+ if isinstance(src, QueryExpression)
108
+ else src
109
+ )
106
110
  for src in self.support
107
111
  )
108
112
  clause = next(support)
@@ -119,17 +123,33 @@ class QueryExpression:
119
123
  else " WHERE (%s)" % ")AND(".join(str(s) for s in self.restriction)
120
124
  )
121
125
 
126
+ def sorting_clauses(self):
127
+ if not self._top:
128
+ return ""
129
+ clause = ", ".join(
130
+ _wrap_attributes(
131
+ _flatten_attribute_list(self.primary_key, self._top.order_by)
132
+ )
133
+ )
134
+ if clause:
135
+ clause = f" ORDER BY {clause}"
136
+ if self._top.limit is not None:
137
+ clause += f" LIMIT {self._top.limit}{f' OFFSET {self._top.offset}' if self._top.offset else ''}"
138
+
139
+ return clause
140
+
122
141
  def make_sql(self, fields=None):
123
142
  """
124
143
  Make the SQL SELECT statement.
125
144
 
126
145
  :param fields: used to explicitly set the select attributes
127
146
  """
128
- return "SELECT {distinct}{fields} FROM {from_}{where}".format(
147
+ return "SELECT {distinct}{fields} FROM {from_}{where}{sorting}".format(
129
148
  distinct="DISTINCT " if self._distinct else "",
130
149
  fields=self.heading.as_sql(fields or self.heading.names),
131
150
  from_=self.from_clause(),
132
151
  where=self.where_clause(),
152
+ sorting=self.sorting_clauses(),
133
153
  )
134
154
 
135
155
  # --------- query operators -----------
@@ -187,6 +207,14 @@ class QueryExpression:
187
207
  string, or an AndList.
188
208
  """
189
209
  attributes = set()
210
+ if isinstance(restriction, Top):
211
+ result = (
212
+ self.make_subquery()
213
+ if self._top and not self._top.__eq__(restriction)
214
+ else copy.copy(self)
215
+ ) # make subquery to avoid overwriting existing Top
216
+ result._top = restriction
217
+ return result
190
218
  new_condition = make_condition(self, restriction, attributes)
191
219
  if new_condition is True:
192
220
  return self # restriction has no effect, return the same object
@@ -200,8 +228,10 @@ class QueryExpression:
200
228
  pass # all ok
201
229
  # If the new condition uses any new attributes, a subquery is required.
202
230
  # However, Aggregation's HAVING statement works fine with aliased attributes.
203
- need_subquery = isinstance(self, Union) or (
204
- not isinstance(self, Aggregation) and self.heading.new_attributes
231
+ need_subquery = (
232
+ isinstance(self, Union)
233
+ or (not isinstance(self, Aggregation) and self.heading.new_attributes)
234
+ or self._top
205
235
  )
206
236
  if need_subquery:
207
237
  result = self.make_subquery()
@@ -537,19 +567,20 @@ class QueryExpression:
537
567
 
538
568
  def __len__(self):
539
569
  """:return: number of elements in the result set e.g. ``len(q1)``."""
540
- return self.connection.query(
570
+ result = self.make_subquery() if self._top else copy.copy(self)
571
+ return result.connection.query(
541
572
  "SELECT {select_} FROM {from_}{where}".format(
542
573
  select_=(
543
574
  "count(*)"
544
- if any(self._left)
575
+ if any(result._left)
545
576
  else "count(DISTINCT {fields})".format(
546
- fields=self.heading.as_sql(
547
- self.primary_key, include_aliases=False
577
+ fields=result.heading.as_sql(
578
+ result.primary_key, include_aliases=False
548
579
  )
549
580
  )
550
581
  ),
551
- from_=self.from_clause(),
552
- where=self.where_clause(),
582
+ from_=result.from_clause(),
583
+ where=result.where_clause(),
553
584
  )
554
585
  ).fetchone()[0]
555
586
 
@@ -617,18 +648,12 @@ class QueryExpression:
617
648
  # -- move on to next entry.
618
649
  return next(self)
619
650
 
620
- def cursor(self, offset=0, limit=None, order_by=None, as_dict=False):
651
+ def cursor(self, as_dict=False):
621
652
  """
622
653
  See expression.fetch() for input description.
623
654
  :return: query cursor
624
655
  """
625
- if offset and limit is None:
626
- raise DataJointError("limit is required when offset is set")
627
656
  sql = self.make_sql()
628
- if order_by is not None:
629
- sql += " ORDER BY " + ", ".join(order_by)
630
- if limit is not None:
631
- sql += " LIMIT %d" % limit + (" OFFSET %d" % offset if offset else "")
632
657
  logger.debug(sql)
633
658
  return self.connection.query(sql, as_dict=as_dict)
634
659
 
@@ -699,21 +724,26 @@ class Aggregation(QueryExpression):
699
724
  fields = self.heading.as_sql(fields or self.heading.names)
700
725
  assert self._grouping_attributes or not self.restriction
701
726
  distinct = set(self.heading.names) == set(self.primary_key)
702
- return "SELECT {distinct}{fields} FROM {from_}{where}{group_by}".format(
703
- distinct="DISTINCT " if distinct else "",
704
- fields=fields,
705
- from_=self.from_clause(),
706
- where=self.where_clause(),
707
- group_by=""
708
- if not self.primary_key
709
- else (
710
- " GROUP BY `%s`" % "`,`".join(self._grouping_attributes)
711
- + (
727
+ return (
728
+ "SELECT {distinct}{fields} FROM {from_}{where}{group_by}{sorting}".format(
729
+ distinct="DISTINCT " if distinct else "",
730
+ fields=fields,
731
+ from_=self.from_clause(),
732
+ where=self.where_clause(),
733
+ group_by=(
712
734
  ""
713
- if not self.restriction
714
- else " HAVING (%s)" % ")AND(".join(self.restriction)
715
- )
716
- ),
735
+ if not self.primary_key
736
+ else (
737
+ " GROUP BY `%s`" % "`,`".join(self._grouping_attributes)
738
+ + (
739
+ ""
740
+ if not self.restriction
741
+ else " HAVING (%s)" % ")AND(".join(self.restriction)
742
+ )
743
+ )
744
+ ),
745
+ sorting=self.sorting_clauses(),
746
+ )
717
747
  )
718
748
 
719
749
  def __len__(self):
@@ -772,14 +802,19 @@ class Union(QueryExpression):
772
802
  ):
773
803
  # no secondary attributes: use UNION DISTINCT
774
804
  fields = arg1.primary_key
775
- return "SELECT * FROM (({sql1}) UNION ({sql2})) as `_u{alias}`".format(
776
- sql1=arg1.make_sql()
777
- if isinstance(arg1, Union)
778
- else arg1.make_sql(fields),
779
- sql2=arg2.make_sql()
780
- if isinstance(arg2, Union)
781
- else arg2.make_sql(fields),
805
+ return "SELECT * FROM (({sql1}) UNION ({sql2})) as `_u{alias}{sorting}`".format(
806
+ sql1=(
807
+ arg1.make_sql()
808
+ if isinstance(arg1, Union)
809
+ else arg1.make_sql(fields)
810
+ ),
811
+ sql2=(
812
+ arg2.make_sql()
813
+ if isinstance(arg2, Union)
814
+ else arg2.make_sql(fields)
815
+ ),
782
816
  alias=next(self.__count),
817
+ sorting=self.sorting_clauses(),
783
818
  )
784
819
  # with secondary attributes, use union of left join with antijoin
785
820
  fields = self.heading.names
@@ -839,7 +874,7 @@ class U:
839
874
  >>> dj.U().aggr(expr, n='count(*)')
840
875
 
841
876
  The following expressions both yield one element containing the number `n` of distinct values of attribute `attr` in
842
- query expressio `expr`.
877
+ query expression `expr`.
843
878
 
844
879
  >>> dj.U().aggr(expr, n='count(distinct attr)')
845
880
  >>> dj.U().aggr(dj.U('attr').aggr(expr), 'n=count(*)')
@@ -931,3 +966,25 @@ class U:
931
966
  )
932
967
 
933
968
  aggregate = aggr # alias for aggr
969
+
970
+
971
+ def _flatten_attribute_list(primary_key, attrs):
972
+ """
973
+ :param primary_key: list of attributes in primary key
974
+ :param attrs: list of attribute names, which may include "KEY", "KEY DESC" or "KEY ASC"
975
+ :return: generator of attributes where "KEY" is replaced with its component attributes
976
+ """
977
+ for a in attrs:
978
+ if re.match(r"^\s*KEY(\s+[aA][Ss][Cc])?\s*$", a):
979
+ if primary_key:
980
+ yield from primary_key
981
+ elif re.match(r"^\s*KEY\s+[Dd][Ee][Ss][Cc]\s*$", a):
982
+ if primary_key:
983
+ yield from (q + " DESC" for q in primary_key)
984
+ else:
985
+ yield a
986
+
987
+
988
+ def _wrap_attributes(attr):
989
+ for entry in attr: # wrap attribute names in backquotes
990
+ yield re.sub(r"\b((?!asc|desc)\w+)\b", r"`\1`", entry, flags=re.IGNORECASE)
datajoint/external.py CHANGED
@@ -8,7 +8,7 @@ from .hash import uuid_from_buffer, uuid_from_file
8
8
  from .table import Table, FreeTable
9
9
  from .heading import Heading
10
10
  from .declare import EXTERNAL_TABLE_ROOT
11
- from . import s3
11
+ from . import s3, errors
12
12
  from .utils import safe_write, safe_copy
13
13
 
14
14
  logger = logging.getLogger(__name__.split(".")[0])
@@ -141,7 +141,12 @@ class ExternalTable(Table):
141
141
  if self.spec["protocol"] == "s3":
142
142
  return self.s3.get(external_path)
143
143
  if self.spec["protocol"] == "file":
144
- return Path(external_path).read_bytes()
144
+ try:
145
+ return Path(external_path).read_bytes()
146
+ except FileNotFoundError:
147
+ raise errors.MissingExternalFile(
148
+ f"Missing external file {external_path}"
149
+ ) from None
145
150
  assert False
146
151
 
147
152
  def _remove_external_file(self, external_path):