spells-mtg 0.3.1__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of spells-mtg might be problematic. Click here for more details.

spells/draft_data.py CHANGED
@@ -10,15 +10,17 @@ import datetime
10
10
  import functools
11
11
  import hashlib
12
12
  import re
13
- from typing import Callable, TypeVar
13
+ from inspect import signature
14
+ from typing import Callable, TypeVar, Any
14
15
 
15
16
  import polars as pl
17
+ from polars.exceptions import ColumnNotFoundError
16
18
 
17
19
  from spells.external import data_file_path
18
20
  import spells.cache
19
21
  import spells.filter
20
22
  import spells.manifest
21
- from spells.columns import ColumnDefinition, ColumnSpec
23
+ from spells.columns import ColDef, ColSpec
22
24
  from spells.enums import View, ColName, ColType
23
25
 
24
26
 
@@ -33,7 +35,7 @@ def _cache_key(args) -> str:
33
35
 
34
36
 
35
37
  @functools.lru_cache(maxsize=None)
36
- def _get_names(set_code: str) -> tuple[str, ...]:
38
+ def _get_names(set_code: str) -> list[str]:
37
39
  card_fp = data_file_path(set_code, View.CARD)
38
40
  card_view = pl.read_parquet(card_fp)
39
41
  card_names_set = frozenset(card_view.get_column("name").to_list())
@@ -43,7 +45,7 @@ def _get_names(set_code: str) -> tuple[str, ...]:
43
45
  cols = draft_view.collect_schema().names()
44
46
 
45
47
  prefix = "pack_card_"
46
- names = tuple(col[len(prefix) :] for col in cols if col.startswith(prefix))
48
+ names = [col[len(prefix) :] for col in cols if col.startswith(prefix)]
47
49
  draft_names_set = frozenset(names)
48
50
 
49
51
  assert (
@@ -52,34 +54,52 @@ def _get_names(set_code: str) -> tuple[str, ...]:
52
54
  return names
53
55
 
54
56
 
55
- def _hydrate_col_defs(set_code: str, col_spec_map: dict[str, ColumnSpec]):
56
- def get_views(spec: ColumnSpec) -> set[View]:
57
- if spec.name == ColName.NAME or spec.col_type in (
58
- ColType.AGG,
59
- ColType.CARD_SUM,
60
- ):
61
- return set()
62
- if spec.col_type == ColType.CARD_ATTR:
63
- return {View.CARD}
64
- if spec.views is not None:
65
- return set(spec.views)
66
- assert (
67
- spec.dependencies is not None
68
- ), f"Col {spec.name} should have dependencies"
69
-
70
- views = functools.reduce(
71
- lambda prev, curr: prev.intersection(curr),
72
- [get_views(col_spec_map[dep]) for dep in spec.dependencies],
73
- )
57
+ def _get_card_context(set_code: str, col_spec_map: dict[str, ColSpec], card_context: pl.DataFrame | dict[str, dict[str, Any]] | None) -> dict[str, dict[str, Any]]:
58
+ card_attr_specs = {col:spec for col, spec in col_spec_map.items() if spec.col_type == ColType.CARD_ATTR or spec.name == ColName.NAME}
59
+ col_def_map = _hydrate_col_defs(set_code, card_attr_specs, card_only=True)
74
60
 
75
- return views
61
+ columns = list(col_def_map.keys())
62
+
63
+ fp = data_file_path(set_code, View.CARD)
64
+ card_df = pl.read_parquet(fp)
65
+ select_rows = _view_select(
66
+ card_df, frozenset(columns), col_def_map, is_agg_view=False
67
+ ).to_dicts()
68
+
69
+ loaded_context = {row[ColName.NAME]: row for row in select_rows}
70
+
71
+ if card_context is not None:
72
+ if isinstance(card_context, pl.DataFrame):
73
+ try:
74
+ card_context = {row[ColName.NAME]: row for row in card_context.to_dicts()}
75
+ except ColumnNotFoundError:
76
+ raise ValueError("card_context DataFrame must have column 'name'")
77
+
78
+ names = list(loaded_context.keys())
79
+ for name in names:
80
+ assert name in card_context, f"card_context must include a row for each card name. {name} missing."
81
+ for col, value in card_context[name].items():
82
+ loaded_context[name][col] = value
83
+
84
+ return loaded_context
85
+
86
+
87
+ def _determine_expression(spec: ColSpec, names: list[str], card_context: dict[str, dict]) -> pl.Expr | tuple[pl.Expr, ...]:
88
+ def seed_params(expr):
89
+ params = {}
90
+
91
+ sig_params = signature(expr).parameters
92
+ if 'names' in sig_params:
93
+ params['names'] = names
94
+ if 'card_context' in sig_params:
95
+ params['card_context'] = card_context
96
+ return params
97
+
98
+ if spec.col_type == ColType.NAME_SUM:
99
+ if spec.expr is not None:
100
+ assert isinstance(spec.expr, Callable), f"NAME_SUM column {spec.name} must have a callable `expr` accepting a `name` argument"
101
+ unnamed_exprs = [spec.expr(**{'name': name, **seed_params(spec.expr)}) for name in names]
76
102
 
77
- names = _get_names(set_code)
78
- assert len(names) > 0, "there should be names"
79
- hydrated = {}
80
- for key, spec in col_spec_map.items():
81
- if spec.col_type == ColType.NAME_SUM and spec.exprMap is not None:
82
- unnamed_exprs = map(spec.exprMap, names)
83
103
  expr = tuple(
84
104
  map(
85
105
  lambda ex, name: ex.alias(f"{spec.name}_{name}"),
@@ -87,14 +107,76 @@ def _hydrate_col_defs(set_code: str, col_spec_map: dict[str, ColumnSpec]):
87
107
  names,
88
108
  )
89
109
  )
90
- elif spec.expr is not None:
91
- expr = spec.expr.alias(spec.name)
92
-
93
110
  else:
94
- if spec.col_type == ColType.NAME_SUM:
95
- expr = tuple(map(lambda name: pl.col(f"{spec.name}_{name}"), names))
111
+ expr = tuple(map(lambda name: pl.col(f"{spec.name}_{name}"), names))
112
+
113
+ elif spec.expr is not None:
114
+ if isinstance(spec.expr, Callable):
115
+ params = seed_params(spec.expr)
116
+ if spec.col_type == ColType.PICK_SUM and 'name' in signature(spec.expr).parameters:
117
+ expr = pl.lit(None)
118
+ for name in names:
119
+ name_params = {'name': name, **params}
120
+ expr = pl.when(pl.col(ColName.PICK) == name).then(spec.expr(**name_params)).otherwise(expr)
96
121
  else:
97
- expr = pl.col(spec.name)
122
+ expr = spec.expr(**params)
123
+ else:
124
+ expr = spec.expr
125
+ expr = expr.alias(spec.name)
126
+ else:
127
+ expr = pl.col(spec.name)
128
+
129
+ return expr
130
+
131
+
132
+ def _infer_dependencies(name: str, expr: pl.Expr | tuple[pl.Expr,...], col_spec_map: dict[str, ColSpec], names: list[str]) -> set[str]:
133
+ dependencies = set()
134
+ tricky_ones = set()
135
+
136
+ if isinstance(expr, pl.Expr):
137
+ dep_cols = [c for c in expr.meta.root_names() if c != name]
138
+ for dep_col in dep_cols:
139
+ if dep_col in col_spec_map.keys():
140
+ dependencies.add(dep_col)
141
+ else:
142
+ tricky_ones.add(dep_col)
143
+ else:
144
+ for idx, exp in enumerate(expr):
145
+ pattern = f"_{names[idx]}$"
146
+ dep_cols = [c for c in exp.meta.root_names() if c != name]
147
+ for dep_col in dep_cols:
148
+ if dep_col in col_spec_map.keys():
149
+ dependencies.add(dep_col)
150
+ elif len(split := re.split(pattern, dep_col)) == 2 and split[0] in col_spec_map:
151
+ dependencies.add(split[0])
152
+ else:
153
+ tricky_ones.add(dep_col)
154
+
155
+ for item in tricky_ones:
156
+ found = False
157
+ for n in names:
158
+ pattern = f"_{n}$"
159
+ if not found and len(split := re.split(pattern, item)) == 2 and split[0] in col_spec_map:
160
+ dependencies.add(split[0])
161
+ found = True
162
+ assert found, f"Could not locate column spec for root col {item}"
163
+
164
+ return dependencies
165
+
166
+
167
+ def _hydrate_col_defs(set_code: str, col_spec_map: dict[str, ColSpec], card_context: pl.DataFrame | dict[str, dict] | None = None, card_only: bool =False):
168
+ names = _get_names(set_code)
169
+
170
+ if card_only:
171
+ card_context = {}
172
+ else:
173
+ card_context = _get_card_context(set_code, col_spec_map, card_context)
174
+
175
+ assert len(names) > 0, "there should be names"
176
+ hydrated = {}
177
+ for key, spec in col_spec_map.items():
178
+ expr = _determine_expression(spec, names, card_context)
179
+ dependencies = _infer_dependencies(key, expr, col_spec_map, names)
98
180
 
99
181
  try:
100
182
  sig_expr = expr if isinstance(expr, pl.Expr) else expr[0]
@@ -107,22 +189,19 @@ def _hydrate_col_defs(set_code: str, col_spec_map: dict[str, ColumnSpec]):
107
189
  else:
108
190
  expr_sig = str(datetime.datetime.now)
109
191
 
110
- dependencies = tuple(spec.dependencies or ())
111
- views = get_views(spec)
112
192
  signature = str(
113
193
  (
114
194
  spec.name,
115
195
  spec.col_type.value,
116
196
  expr_sig,
117
- tuple(view.value for view in views),
118
197
  dependencies,
119
198
  )
120
199
  )
121
200
 
122
- cdef = ColumnDefinition(
201
+ cdef = ColDef(
123
202
  name=spec.name,
124
203
  col_type=spec.col_type,
125
- views=views,
204
+ views=set(spec.views or set()),
126
205
  expr=expr,
127
206
  dependencies=dependencies,
128
207
  signature=signature,
@@ -134,20 +213,15 @@ def _hydrate_col_defs(set_code: str, col_spec_map: dict[str, ColumnSpec]):
134
213
  def _view_select(
135
214
  df: DF,
136
215
  view_cols: frozenset[str],
137
- col_def_map: dict[str, ColumnDefinition],
216
+ col_def_map: dict[str, ColDef],
138
217
  is_agg_view: bool,
139
- is_card_sum: bool = False,
140
218
  ) -> DF:
141
219
  base_cols = frozenset()
142
220
  cdefs = [col_def_map[c] for c in view_cols]
143
221
  select = []
144
222
  for cdef in cdefs:
145
223
  if is_agg_view:
146
- if (
147
- cdef.col_type == ColType.AGG
148
- or cdef.col_type == ColType.CARD_SUM
149
- and is_card_sum
150
- ):
224
+ if cdef.col_type == ColType.AGG:
151
225
  base_cols = base_cols.union(cdef.dependencies)
152
226
  select.append(cdef.expr)
153
227
  else:
@@ -164,7 +238,7 @@ def _view_select(
164
238
  select.append(cdef.expr)
165
239
 
166
240
  if base_cols != view_cols:
167
- df = _view_select(df, base_cols, col_def_map, is_agg_view, is_card_sum)
241
+ df = _view_select(df, base_cols, col_def_map, is_agg_view)
168
242
 
169
243
  return df.select(select)
170
244
 
@@ -274,44 +348,23 @@ def _base_agg_df(
274
348
  )
275
349
 
276
350
 
277
- def card_df(
278
- set_code: str,
279
- extensions: list[ColumnSpec] | None = None,
280
- ):
281
- col_spec_map = dict(spells.columns.col_spec_map)
282
- if extensions is not None:
283
- for spec in extensions:
284
- col_spec_map[spec.name] = spec
285
-
286
- col_def_map = _hydrate_col_defs(set_code, col_spec_map)
287
-
288
- columns = [ColName.NAME] + [
289
- c for c, cdef in col_def_map.items() if cdef.col_type == ColType.CARD_ATTR
290
- ]
291
- fp = data_file_path(set_code, View.CARD)
292
- card_df = pl.read_parquet(fp)
293
- select_df = _view_select(
294
- card_df, frozenset(columns), col_def_map, is_agg_view=False
295
- )
296
- return select_df.select(columns)
297
-
298
-
299
351
  def summon(
300
352
  set_code: str,
301
353
  columns: list[str] | None = None,
302
354
  group_by: list[str] | None = None,
303
355
  filter_spec: dict | None = None,
304
- extensions: list[ColumnSpec] | None = None,
356
+ extensions: list[ColSpec] | None = None,
305
357
  use_streaming: bool = False,
306
358
  read_cache: bool = True,
307
359
  write_cache: bool = True,
360
+ card_context: pl.DataFrame | dict[str, dict] | None = None
308
361
  ) -> pl.DataFrame:
309
362
  col_spec_map = dict(spells.columns.col_spec_map)
310
363
  if extensions is not None:
311
364
  for spec in extensions:
312
365
  col_spec_map[spec.name] = spec
313
366
 
314
- col_def_map = _hydrate_col_defs(set_code, col_spec_map)
367
+ col_def_map = _hydrate_col_defs(set_code, col_spec_map, card_context)
315
368
  m = spells.manifest.create(col_def_map, columns, group_by, filter_spec)
316
369
 
317
370
  calc_fn = functools.partial(_base_agg_df, set_code, m, use_streaming=use_streaming)
@@ -337,12 +390,6 @@ def summon(
337
390
  select_df = _view_select(card_df, card_cols, m.col_def_map, is_agg_view=False)
338
391
  agg_df = agg_df.join(select_df, on="name", how="outer", coalesce=True)
339
392
 
340
- if m.card_sum:
341
- card_sum_df = _view_select(
342
- agg_df, m.card_sum, m.col_def_map, is_agg_view=True, is_card_sum=True
343
- )
344
- agg_df = pl.concat([agg_df, card_sum_df], how="horizontal")
345
-
346
393
  if ColName.NAME not in m.group_by:
347
394
  agg_df = agg_df.group_by(m.group_by).sum()
348
395
 
spells/enums.py CHANGED
@@ -19,7 +19,6 @@ class ColType(StrEnum):
19
19
  NAME_SUM = "name_sum"
20
20
  AGG = "agg"
21
21
  CARD_ATTR = "card_attr"
22
- CARD_SUM = "card_sum"
23
22
 
24
23
 
25
24
  class ColName(StrEnum):
@@ -61,6 +60,7 @@ class ColName(StrEnum):
61
60
  PICK_NUM = "pick_num" # pick_number plus 1
62
61
  TAKEN_AT = "taken_at"
63
62
  NUM_TAKEN = "num_taken"
63
+ NUM_DRAFTS = "num_drafts"
64
64
  PICK = "pick"
65
65
  PICK_MAINDECK_RATE = "pick_maindeck_rate"
66
66
  PICK_SIDEBOARD_IN_RATE = "pick_sideboard_in_rate"
spells/manifest.py CHANGED
@@ -3,18 +3,17 @@ from dataclasses import dataclass
3
3
  import spells.columns
4
4
  import spells.filter
5
5
  from spells.enums import View, ColName, ColType
6
- from spells.columns import ColumnDefinition
6
+ from spells.columns import ColDef
7
7
 
8
8
 
9
9
  @dataclass(frozen=True)
10
10
  class Manifest:
11
11
  columns: tuple[str, ...]
12
- col_def_map: dict[str, ColumnDefinition]
12
+ col_def_map: dict[str, ColDef]
13
13
  base_view_group_by: frozenset[str]
14
14
  view_cols: dict[View, frozenset[str]]
15
15
  group_by: tuple[str, ...]
16
16
  filter: spells.filter.Filter | None
17
- card_sum: frozenset[str]
18
17
 
19
18
  def __post_init__(self):
20
19
  # No name filter check
@@ -40,21 +39,13 @@ class Manifest:
40
39
  ), f"Invalid groupby {col}!"
41
40
 
42
41
  for view, cols_for_view in self.view_cols.items():
43
- # cols_for_view are actually in view check
44
42
  for col in cols_for_view:
45
- assert (
46
- view in self.col_def_map[col].views
47
- ), f"View cols generated incorrectly, {col} not in view {view}"
48
43
  # game sum cols on in game, and no NAME groupby
49
44
  assert self.col_def_map[col].col_type != ColType.GAME_SUM or (
50
45
  view == View.GAME and ColName.NAME not in self.base_view_group_by
51
46
  ), f"Invalid manifest for GAME_SUM column {col}"
52
47
  if view != View.CARD:
53
48
  for col in self.base_view_group_by:
54
- # base_view_groupbys in view check
55
- assert (
56
- col == ColName.NAME or view in self.col_def_map[col].views
57
- ), f"Groupby {col} not in view {view}!"
58
49
  # base_view_groupbys in view_cols for view
59
50
  assert (
60
51
  col == ColName.NAME or col in cols_for_view
@@ -94,8 +85,8 @@ class Manifest:
94
85
 
95
86
  def _resolve_view_cols(
96
87
  col_set: frozenset[str],
97
- col_def_map: dict[str, ColumnDefinition],
98
- ) -> tuple[dict[View, frozenset[str]], frozenset[str]]:
88
+ col_def_map: dict[str, ColDef],
89
+ ) -> dict[View, frozenset[str]]:
99
90
  """
100
91
  For each view ('game', 'draft', and 'card'), return the columns
101
92
  that must be present at the aggregation step. 'name' need not be
@@ -104,7 +95,6 @@ def _resolve_view_cols(
104
95
  MAX_DEPTH = 1000
105
96
  unresolved_cols = col_set
106
97
  view_resolution = {}
107
- card_sum = frozenset()
108
98
 
109
99
  iter_num = 0
110
100
  while unresolved_cols and iter_num < MAX_DEPTH:
@@ -116,9 +106,9 @@ def _resolve_view_cols(
116
106
  view_resolution[View.DRAFT] = view_resolution.get(
117
107
  View.DRAFT, frozenset()
118
108
  ).union({ColName.PICK})
119
- if cdef.col_type == ColType.CARD_SUM:
120
- card_sum = card_sum.union({col})
121
- if cdef.views:
109
+ if cdef.col_type == ColType.CARD_ATTR:
110
+ view_resolution[View.CARD] = view_resolution.get(View.CARD, frozenset()).union({col})
111
+ elif cdef.views:
122
112
  for view in cdef.views:
123
113
  view_resolution[view] = view_resolution.get(
124
114
  view, frozenset()
@@ -128,18 +118,42 @@ def _resolve_view_cols(
128
118
  raise ValueError(
129
119
  f"Invalid column def: {col} has neither views nor dependencies!"
130
120
  )
131
- for dep in cdef.dependencies:
132
- next_cols = next_cols.union({dep})
121
+ if cdef.col_type != ColType.AGG:
122
+ fully_resolved = True
123
+ col_views = frozenset({View.GAME, View.DRAFT, View.CARD})
124
+ for dep in cdef.dependencies:
125
+ dep_views = frozenset()
126
+ for view, view_cols in view_resolution.items():
127
+ if dep in view_cols:
128
+ dep_views = dep_views.union({view})
129
+ if not dep_views:
130
+ fully_resolved = False
131
+ next_cols = next_cols.union({dep})
132
+ else:
133
+ col_views = col_views.intersection(dep_views)
134
+ if fully_resolved:
135
+ assert len(col_views), f"Column {col} can't be defined in any views!"
136
+ for view in col_views:
137
+ if view not in view_resolution:
138
+ print(cdef)
139
+ assert False, f"Something went wrong with col {col}"
140
+
141
+ view_resolution[view] = view_resolution[view].union({col})
142
+ else:
143
+ next_cols = next_cols.union({col})
144
+ else:
145
+ for dep in cdef.dependencies:
146
+ next_cols = next_cols.union({dep})
133
147
  unresolved_cols = next_cols
134
148
 
135
149
  if iter_num >= MAX_DEPTH:
136
150
  raise ValueError("broken dependency chain in column spec, loop probable")
137
151
 
138
- return view_resolution, card_sum
152
+ return view_resolution
139
153
 
140
154
 
141
155
  def create(
142
- col_def_map: dict[str, ColumnDefinition],
156
+ col_def_map: dict[str, ColDef],
143
157
  columns: list[str] | None = None,
144
158
  group_by: list[str] | None = None,
145
159
  filter_spec: dict | None = None,
@@ -148,7 +162,7 @@ def create(
148
162
  if columns is None:
149
163
  cols = tuple(spells.columns.default_columns)
150
164
  if ColName.NAME not in gbs:
151
- cols = tuple(c for c in cols if c not in [ColName.COLOR, ColName.RARITY])
165
+ cols = tuple(c for c in cols if col_def_map[c].col_type != ColType.CARD_ATTR)
152
166
  else:
153
167
  cols = tuple(columns)
154
168
 
@@ -159,12 +173,8 @@ def create(
159
173
  if m_filter is not None:
160
174
  col_set = col_set.union(m_filter.lhs)
161
175
 
162
- view_cols, card_sum = _resolve_view_cols(col_set, col_def_map)
163
176
  base_view_group_by = frozenset()
164
177
 
165
- if card_sum:
166
- base_view_group_by = base_view_group_by.union({ColName.NAME})
167
-
168
178
  for col in gbs:
169
179
  cdef = col_def_map[col]
170
180
  if cdef.col_type == ColType.GROUP_BY:
@@ -172,14 +182,23 @@ def create(
172
182
  elif cdef.col_type == ColType.CARD_ATTR:
173
183
  base_view_group_by = base_view_group_by.union({ColName.NAME})
174
184
 
185
+ view_cols = _resolve_view_cols(col_set, col_def_map)
186
+
175
187
  needed_views = frozenset()
176
- for view, cols_for_view in view_cols.items():
177
- for col in cols_for_view:
178
- if col_def_map[col].views == {view}: # only found in this view
179
- needed_views = needed_views.union({view})
188
+ if View.CARD in view_cols:
189
+ needed_views = needed_views.union({View.CARD})
180
190
 
181
- if not needed_views:
182
- needed_views = {View.DRAFT}
191
+ draft_view_cols = view_cols.get(View.DRAFT, frozenset())
192
+ game_view_cols = view_cols.get(View.GAME, frozenset())
193
+
194
+ base_cols = draft_view_cols.union(game_view_cols)
195
+
196
+ if base_cols == draft_view_cols:
197
+ needed_views = needed_views.union({View.DRAFT})
198
+ elif base_cols == game_view_cols:
199
+ needed_views = needed_views.union({View.GAME})
200
+ else:
201
+ needed_views = needed_views.union({View.GAME, View.DRAFT})
183
202
 
184
203
  view_cols = {v: view_cols[v] for v in needed_views}
185
204
 
@@ -190,5 +209,4 @@ def create(
190
209
  view_cols=view_cols,
191
210
  group_by=gbs,
192
211
  filter=m_filter,
193
- card_sum=card_sum,
194
212
  )