dbt-common 0.1.1__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dbt_common/__about__.py CHANGED
@@ -1 +1 @@
1
- version = "0.1.1"
1
+ version = "0.1.5"
@@ -28,7 +28,11 @@ class BlockData:
28
28
 
29
29
  class BlockTag:
30
30
  def __init__(
31
- self, block_type_name: str, block_name: str, contents: Optional[str] = None, full_block: Optional[str] = None
31
+ self,
32
+ block_type_name: str,
33
+ block_name: str,
34
+ contents: Optional[str] = None,
35
+ full_block: Optional[str] = None,
32
36
  ) -> None:
33
37
  self.block_type_name = block_type_name
34
38
  self.block_name = block_name
@@ -106,7 +110,9 @@ class TagIterator:
106
110
  self.pos: int = 0
107
111
 
108
112
  def linepos(self, end: Optional[int] = None) -> str:
109
- """Given an absolute position in the input text, return a pair of
113
+ """Return relative position in line.
114
+
115
+ Given an absolute position in the input data, return a pair of
110
116
  line number + relative position to the start of the line.
111
117
  """
112
118
  end_val: int = self.pos if end is None else end
@@ -148,7 +154,9 @@ class TagIterator:
148
154
  return match
149
155
 
150
156
  def handle_expr(self, match: re.Match) -> None:
151
- """Handle an expression. At this point we're at a string like:
157
+ """Handle an expression.
158
+
159
+ At this point we're at a string like:
152
160
  {{ 1 + 2 }}
153
161
  ^ right here
154
162
 
@@ -180,6 +188,7 @@ class TagIterator:
180
188
 
181
189
  def _expect_block_close(self) -> None:
182
190
  """Search for the tag close marker.
191
+
183
192
  To the right of the type name, there are a few possibilities:
184
193
  - a name (handled by the regex's 'block_name')
185
194
  - any number of: `=`, `(`, `)`, strings, etc (arguments)
@@ -191,7 +200,9 @@ class TagIterator:
191
200
  are quote and `%}` - nothing else can hide the %} and be valid jinja.
192
201
  """
193
202
  while True:
194
- end_match = self._expect_match('tag close ("%}")', QUOTE_START_PATTERN, TAG_CLOSE_PATTERN)
203
+ end_match = self._expect_match(
204
+ 'tag close ("%}")', QUOTE_START_PATTERN, TAG_CLOSE_PATTERN
205
+ )
195
206
  self.advance(end_match.end())
196
207
  if end_match.groupdict().get("tag_close") is not None:
197
208
  return
@@ -207,7 +218,9 @@ class TagIterator:
207
218
  return match.end()
208
219
 
209
220
  def handle_tag(self, match: re.Match) -> Tag:
210
- """The tag could be one of a few things:
221
+ """Determine tag type.
222
+
223
+ The tag could be one of a few things:
211
224
 
212
225
  {% mytag %}
213
226
  {% mytag x = y %}
@@ -229,11 +242,15 @@ class TagIterator:
229
242
  else:
230
243
  self.advance(match.end())
231
244
  self._expect_block_close()
232
- return Tag(block_type_name=block_type_name, block_name=block_name, start=start_pos, end=self.pos)
245
+ return Tag(
246
+ block_type_name=block_type_name, block_name=block_name, start=start_pos, end=self.pos
247
+ )
233
248
 
234
249
  def find_tags(self) -> Iterator[Tag]:
235
250
  while True:
236
- match = self._first_match(BLOCK_START_PATTERN, COMMENT_START_PATTERN, EXPR_START_PATTERN)
251
+ match = self._first_match(
252
+ BLOCK_START_PATTERN, COMMENT_START_PATTERN, EXPR_START_PATTERN
253
+ )
237
254
  if match is None:
238
255
  break
239
256
 
@@ -253,7 +270,8 @@ class TagIterator:
253
270
  yield self.handle_tag(match)
254
271
  else:
255
272
  raise DbtInternalError(
256
- "Invalid regex match in next_block, expected block start, " "expr start, or comment start"
273
+ "Invalid regex match in next_block, expected block start, "
274
+ "expr start, or comment start"
257
275
  )
258
276
 
259
277
  def __iter__(self) -> Iterator[Tag]:
@@ -349,4 +367,6 @@ class BlockIterator:
349
367
  def lex_for_blocks(
350
368
  self, allowed_blocks: Optional[Set[str]] = None, collect_raw_data: bool = True
351
369
  ) -> List[Union[BlockData, BlockTag]]:
352
- return list(self.find_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data))
370
+ return list(
371
+ self.find_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data)
372
+ )
@@ -1,13 +1,13 @@
1
1
  from codecs import BOM_UTF8
2
2
 
3
- import agate
3
+ import agate # type: ignore
4
4
  import datetime
5
5
  import isodate
6
6
  import json
7
7
  from typing import Iterable, List, Dict, Union, Optional, Any
8
8
 
9
9
  from dbt_common.exceptions import DbtRuntimeError
10
- from dbt_common.utils import ForgivingJSONEncoder
10
+ from dbt_common.utils.encoding import ForgivingJSONEncoder
11
11
 
12
12
  BOM = BOM_UTF8.decode("utf-8") # '\ufeff'
13
13
 
@@ -17,7 +17,7 @@ class Integer(agate.data_types.DataType):
17
17
  # by default agate will cast none as a Number
18
18
  # but we need to cast it as an Integer to preserve
19
19
  # the type when merging and unioning tables
20
- if type(d) == int or d is None:
20
+ if type(d) == int or d is None: # noqa [E721]
21
21
  return d
22
22
  else:
23
23
  raise agate.exceptions.CastError('Can not parse value "%s" as Integer.' % d)
@@ -30,7 +30,7 @@ class Number(agate.data_types.Number):
30
30
  # undo the change in https://github.com/wireservice/agate/pull/733
31
31
  # i.e. do not cast True and False to numeric 1 and 0
32
32
  def cast(self, d):
33
- if type(d) == bool:
33
+ if type(d) == bool: # noqa [E721]
34
34
  raise agate.exceptions.CastError("Do not cast True to 1 or False to 0.")
35
35
  else:
36
36
  return super().cast(d)
@@ -59,14 +59,15 @@ class ISODateTime(agate.data_types.DateTime):
59
59
  def build_type_tester(
60
60
  text_columns: Iterable[str], string_null_values: Optional[Iterable[str]] = ("null", "")
61
61
  ) -> agate.TypeTester:
62
-
63
62
  types = [
64
63
  Integer(null_values=("null", "")),
65
64
  Number(null_values=("null", "")),
66
65
  agate.data_types.Date(null_values=("null", ""), date_format="%Y-%m-%d"),
67
66
  agate.data_types.DateTime(null_values=("null", ""), datetime_format="%Y-%m-%d %H:%M:%S"),
68
67
  ISODateTime(null_values=("null", "")),
69
- agate.data_types.Boolean(true_values=("true",), false_values=("false",), null_values=("null", "")),
68
+ agate.data_types.Boolean(
69
+ true_values=("true",), false_values=("false",), null_values=("null", "")
70
+ ),
70
71
  agate.data_types.Text(null_values=string_null_values),
71
72
  ]
72
73
  force = {k: agate.data_types.Text(null_values=string_null_values) for k in text_columns}
@@ -92,13 +93,13 @@ def table_from_rows(
92
93
 
93
94
 
94
95
  def table_from_data(data, column_names: Iterable[str]) -> agate.Table:
95
- "Convert a list of dictionaries into an Agate table"
96
+ """Convert a list of dictionaries into an Agate table.
96
97
 
97
- # The agate table is generated from a list of dicts, so the column order
98
- # from `data` is not preserved. We can use `select` to reorder the columns
99
- #
100
- # If there is no data, create an empty table with the specified columns
98
+ The agate table is generated from a list of dicts, so the column order
99
+ from `data` is not preserved. We can use `select` to reorder the columns
101
100
 
101
+ If there is no data, create an empty table with the specified columns
102
+ """
102
103
  if len(data) == 0:
103
104
  return agate.Table([], column_names=column_names)
104
105
  else:
@@ -107,13 +108,13 @@ def table_from_data(data, column_names: Iterable[str]) -> agate.Table:
107
108
 
108
109
 
109
110
  def table_from_data_flat(data, column_names: Iterable[str]) -> agate.Table:
110
- """
111
- Convert a list of dictionaries into an Agate table. This method does not
111
+ """Convert a list of dictionaries into an Agate table.
112
+
113
+ This method does not
112
114
  coerce string values into more specific types (eg. '005' will not be
113
115
  coerced to '5'). Additionally, this method does not coerce values to
114
116
  None (eg. '' or 'null' will retain their string literal representations).
115
117
  """
116
-
117
118
  rows = []
118
119
  text_only_columns = set()
119
120
  for _row in data:
@@ -130,18 +131,21 @@ def table_from_data_flat(data, column_names: Iterable[str]) -> agate.Table:
130
131
 
131
132
  rows.append(row)
132
133
 
133
- return table_from_rows(rows=rows, column_names=column_names, text_only_columns=text_only_columns)
134
+ return table_from_rows(
135
+ rows=rows, column_names=column_names, text_only_columns=text_only_columns
136
+ )
134
137
 
135
138
 
136
139
  def empty_table():
137
- "Returns an empty Agate table. To be used in place of None"
140
+ """Returns an empty Agate table.
138
141
 
142
+ To be used in place of None
143
+ """
139
144
  return agate.Table(rows=[])
140
145
 
141
146
 
142
147
  def as_matrix(table):
143
- "Return an agate table as a matrix of data sans columns"
144
-
148
+ """Return an agate table as a matrix of data sans columns."""
145
149
  return [r.values() for r in table.rows.values()]
146
150
 
147
151
 
@@ -176,7 +180,8 @@ class ColumnTypeBuilder(Dict[str, NullableAgateType]):
176
180
  elif isinstance(value, _NullMarker):
177
181
  # use the existing value
178
182
  return
179
- # when one table column is Number while another is Integer, force the column to Number on merge
183
+ # when one table column is Number while another is Integer,
184
+ # force the column to Number on merge
180
185
  elif isinstance(value, Integer) and isinstance(existing_type, agate.data_types.Number):
181
186
  # use the existing value
182
187
  return
@@ -203,8 +208,11 @@ class ColumnTypeBuilder(Dict[str, NullableAgateType]):
203
208
 
204
209
 
205
210
  def _merged_column_types(tables: List[agate.Table]) -> Dict[str, agate.data_types.DataType]:
206
- # this is a lot like agate.Table.merge, but with handling for all-null
207
- # rows being "any type".
211
+ """Custom version of agate.Table.merge.
212
+
213
+ this is a lot like agate.Table.merge, but with handling for all-null
214
+ rows being "any type".
215
+ """
208
216
  new_columns: ColumnTypeBuilder = ColumnTypeBuilder()
209
217
  for table in tables:
210
218
  for i in range(len(table.columns)):
@@ -219,8 +227,9 @@ def _merged_column_types(tables: List[agate.Table]) -> Dict[str, agate.data_type
219
227
 
220
228
 
221
229
  def merge_tables(tables: List[agate.Table]) -> agate.Table:
222
- """This is similar to agate.Table.merge, but it handles rows of all 'null'
223
- values more gracefully during merges.
230
+ """This is similar to agate.Table.merge.
231
+
232
+ This handles rows of all 'null' values more gracefully during merges.
224
233
  """
225
234
  new_columns = _merged_column_types(tables)
226
235
  column_names = tuple(new_columns.keys())
@@ -9,14 +9,15 @@ from itertools import chain, islice
9
9
  from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Union, Set, Type
10
10
  from typing_extensions import Protocol
11
11
 
12
- import jinja2
13
- import jinja2.ext
12
+ import jinja2 # type: ignore
13
+ import jinja2.ext # type: ignore
14
14
  import jinja2.nativetypes # type: ignore
15
- import jinja2.nodes
16
- import jinja2.parser
17
- import jinja2.sandbox
15
+ import jinja2.nodes # type: ignore
16
+ import jinja2.parser # type: ignore
17
+ import jinja2.sandbox # type: ignore
18
18
 
19
- from dbt_common.utils import (
19
+ from dbt_common.tests import test_caching_enabled
20
+ from dbt_common.utils.jinja import (
20
21
  get_dbt_macro_name,
21
22
  get_docs_macro_name,
22
23
  get_materialization_macro_name,
@@ -86,7 +87,13 @@ class MacroFuzzEnvironment(jinja2.sandbox.SandboxedEnvironment):
86
87
  return MacroFuzzParser(self, source, name, filename).parse()
87
88
 
88
89
  def _compile(self, source, filename):
89
- """Override jinja's compilation to stash the rendered source inside
90
+ """
91
+
92
+
93
+
94
+
95
+
96
+ Override jinja's compilation. Use to stash the rendered source inside
90
97
  the python linecache for debugging when the appropriate environment
91
98
  variable is set.
92
99
 
@@ -112,7 +119,10 @@ class MacroFuzzTemplate(jinja2.nativetypes.NativeTemplate):
112
119
  # This custom override makes the assumption that the locals and shared
113
120
  # parameters are not used, so enforce that.
114
121
  if shared or locals:
115
- raise Exception("The MacroFuzzTemplate.new_context() override cannot use the shared or locals parameters.")
122
+ raise Exception(
123
+ "The MacroFuzzTemplate.new_context() override cannot use the "
124
+ "shared or locals parameters."
125
+ )
116
126
 
117
127
  parent = ChainMap(vars, self.globals) if self.globals else vars
118
128
 
@@ -120,7 +130,9 @@ class MacroFuzzTemplate(jinja2.nativetypes.NativeTemplate):
120
130
 
121
131
  def render(self, *args: Any, **kwargs: Any) -> Any:
122
132
  if kwargs or len(args) != 1:
123
- raise Exception("The MacroFuzzTemplate.render() override requires exactly one argument.")
133
+ raise Exception(
134
+ "The MacroFuzzTemplate.render() override requires exactly one argument."
135
+ )
124
136
 
125
137
  ctx = self.new_context(args[0])
126
138
 
@@ -140,16 +152,14 @@ class NativeSandboxEnvironment(MacroFuzzEnvironment):
140
152
 
141
153
 
142
154
  class TextMarker(str):
143
- """A special native-env marker that indicates a value is text and is
144
- not to be evaluated. Use this to prevent your numbery-strings from becoming
145
- numbers!
155
+ """A special native-env marker that indicates a value is text and is not to be evaluated.
156
+
157
+ Use this to prevent your numbery-strings from becoming numbers!
146
158
  """
147
159
 
148
160
 
149
161
  class NativeMarker(str):
150
- """A special native-env marker that indicates the field should be passed to
151
- literal_eval.
152
- """
162
+ """A special native-env marker that indicates the field should be passed to literal_eval."""
153
163
 
154
164
 
155
165
  class BoolMarker(NativeMarker):
@@ -165,7 +175,9 @@ def _is_number(value) -> bool:
165
175
 
166
176
 
167
177
  def quoted_native_concat(nodes):
168
- """This is almost native_concat from the NativeTemplate, except in the
178
+ """Handle special case for native_concat from the NativeTemplate.
179
+
180
+ This is almost native_concat from the NativeTemplate, except in the
169
181
  special case of a single argument that is a quoted string and returns a
170
182
  string, the quotes are re-inserted.
171
183
  """
@@ -201,9 +213,10 @@ class NativeSandboxTemplate(jinja2.nativetypes.NativeTemplate): # mypy: ignore
201
213
  environment_class = NativeSandboxEnvironment # type: ignore
202
214
 
203
215
  def render(self, *args, **kwargs):
204
- """Render the template to produce a native Python type. If the
205
- result is a single node, its value is returned. Otherwise, the
206
- nodes are concatenated as strings. If the result can be parsed
216
+ """Render the template to produce a native Python type.
217
+
218
+ If the result is a single node, its value is returned. Otherwise,
219
+ the nodes are concatenated as strings. If the result can be parsed
207
220
  with :func:`ast.literal_eval`, the parsed value is returned.
208
221
  Otherwise, the string is returned.
209
222
  """
@@ -415,7 +428,9 @@ def create_undefined(node=None):
415
428
 
416
429
  def __getattr__(self, name):
417
430
  if name == "name" or _is_dunder_name(name):
418
- raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, name))
431
+ raise AttributeError(
432
+ "'{}' object has no attribute '{}'".format(type(self).__name__, name)
433
+ )
419
434
 
420
435
  self.name = name
421
436
 
@@ -463,7 +478,6 @@ def get_environment(
463
478
  args["extensions"].append(TestExtension)
464
479
 
465
480
  env_cls: Type[jinja2.Environment]
466
- text_filter: Type
467
481
  if native:
468
482
  env_cls = NativeSandboxEnvironment
469
483
  filters = NATIVE_FILTERS
@@ -491,9 +505,19 @@ def catch_jinja(node=None) -> Iterator[None]:
491
505
  raise
492
506
 
493
507
 
508
+ _TESTING_PARSE_CACHE: Dict[str, jinja2.Template] = {}
509
+
510
+
494
511
  def parse(string):
512
+ str_string = str(string)
513
+ if test_caching_enabled() and str_string in _TESTING_PARSE_CACHE:
514
+ return _TESTING_PARSE_CACHE[str_string]
515
+
495
516
  with catch_jinja():
496
- return get_environment().parse(str(string))
517
+ parsed = get_environment().parse(str(string))
518
+ if test_caching_enabled():
519
+ _TESTING_PARSE_CACHE[str_string] = parsed
520
+ return parsed
497
521
 
498
522
 
499
523
  def get_template(
@@ -515,15 +539,25 @@ def render_template(template, ctx: Dict[str, Any], node=None) -> str:
515
539
  return template.render(ctx)
516
540
 
517
541
 
542
+ _TESTING_BLOCKS_CACHE: Dict[int, List[Union[BlockData, BlockTag]]] = {}
543
+
544
+
545
+ def _get_blocks_hash(text: str, allowed_blocks: Optional[Set[str]], collect_raw_data: bool) -> int:
546
+ """Provides a hash function over the arguments to extract_toplevel_blocks, in order to support caching."""
547
+ allowed_tuple = tuple(sorted(allowed_blocks) or [])
548
+ return text.__hash__() + allowed_tuple.__hash__() + collect_raw_data.__hash__()
549
+
550
+
518
551
  def extract_toplevel_blocks(
519
552
  text: str,
520
553
  allowed_blocks: Optional[Set[str]] = None,
521
554
  collect_raw_data: bool = True,
522
555
  ) -> List[Union[BlockData, BlockTag]]:
523
- """Extract the top-level blocks with matching block types from a jinja
524
- file, with some special handling for block nesting.
556
+ """Extract the top-level blocks with matching block types from a jinja file.
557
+
558
+ Includes some special handling for block nesting.
525
559
 
526
- :param data: The data to extract blocks from.
560
+ :param text: The data to extract blocks from.
527
561
  :param allowed_blocks: The names of the blocks to extract from the file.
528
562
  They may not be nested within if/for blocks. If None, use the default
529
563
  values.
@@ -534,5 +568,19 @@ def extract_toplevel_blocks(
534
568
  :return: A list of `BlockTag`s matching the allowed block types and (if
535
569
  `collect_raw_data` is `True`) `BlockData` objects.
536
570
  """
571
+
572
+ if test_caching_enabled():
573
+ hash = _get_blocks_hash(text, allowed_blocks, collect_raw_data)
574
+ if hash in _TESTING_BLOCKS_CACHE:
575
+ return _TESTING_BLOCKS_CACHE[hash]
576
+
537
577
  tag_iterator = TagIterator(text)
538
- return BlockIterator(tag_iterator).lex_for_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data)
578
+ blocks = BlockIterator(tag_iterator).lex_for_blocks(
579
+ allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data
580
+ )
581
+
582
+ if test_caching_enabled():
583
+ hash = _get_blocks_hash(text, allowed_blocks, collect_raw_data)
584
+ _TESTING_BLOCKS_CACHE[hash] = blocks
585
+
586
+ return blocks
@@ -41,7 +41,8 @@ def find_matching(
41
41
  file_pattern: str,
42
42
  ignore_spec: Optional[PathSpec] = None,
43
43
  ) -> List[Dict[str, Any]]:
44
- """
44
+ """Return file info from paths and patterns.
45
+
45
46
  Given an absolute `root_path`, a list of relative paths to that
46
47
  absolute root path (`relative_paths_to_search`), and a `file_pattern`
47
48
  like '*.sql', returns information about the files. For example:
@@ -78,7 +79,9 @@ def find_matching(
78
79
  relative_path_to_root = os.path.join(relative_path_to_search, relative_path)
79
80
 
80
81
  modification_time = os.path.getmtime(absolute_path)
81
- if reobj.match(local_file) and (not ignore_spec or not ignore_spec.match_file(relative_path_to_root)):
82
+ if reobj.match(local_file) and (
83
+ not ignore_spec or not ignore_spec.match_file(relative_path_to_root)
84
+ ):
82
85
  matching.append(
83
86
  {
84
87
  "searched_path": relative_path_to_search,
@@ -104,7 +107,8 @@ def load_file_contents(path: str, strip: bool = True) -> str:
104
107
 
105
108
  @functools.singledispatch
106
109
  def make_directory(path=None) -> None:
107
- """
110
+ """Handle directory creation with threading.
111
+
108
112
  Make a directory and any intermediate directories that don't already
109
113
  exist. This function handles the case where two threads try to create
110
114
  a directory at once.
@@ -133,7 +137,8 @@ def _(path: Path) -> None:
133
137
 
134
138
 
135
139
  def make_file(path: str, contents: str = "", overwrite: bool = False) -> bool:
136
- """
140
+ """Make a file with `contents` at `path`.
141
+
137
142
  Make a file at `path` assuming that the directory it resides in already
138
143
  exists. The file is saved with contents `contents`
139
144
  """
@@ -147,9 +152,7 @@ def make_file(path: str, contents: str = "", overwrite: bool = False) -> bool:
147
152
 
148
153
 
149
154
  def make_symlink(source: str, link_path: str) -> None:
150
- """
151
- Create a symlink at `link_path` referring to `source`.
152
- """
155
+ """Create a symlink at `link_path` referring to `source`."""
153
156
  if not supports_symlinks():
154
157
  # TODO: why not import these at top?
155
158
  raise dbt_common.exceptions.SymbolicLinkError()
@@ -209,9 +212,7 @@ def _windows_rmdir_readonly(func: Callable[[str], Any], path: str, exc: Tuple[An
209
212
 
210
213
 
211
214
  def resolve_path_from_base(path_to_resolve: str, base_path: str) -> str:
212
- """
213
- If path_to_resolve is a relative path, create an absolute path
214
- with base_path as the base.
215
+ """If path_to_resolve is a relative path, create an absolute path with base_path as the base.
215
216
 
216
217
  If path_to_resolve is an absolute path or a user path (~), just
217
218
  resolve it to an absolute path and return.
@@ -220,8 +221,9 @@ def resolve_path_from_base(path_to_resolve: str, base_path: str) -> str:
220
221
 
221
222
 
222
223
  def rmdir(path: str) -> None:
223
- """
224
- Recursively deletes a directory. Includes an error handler to retry with
224
+ """Recursively deletes a directory.
225
+
226
+ Includes an error handler to retry with
225
227
  different permissions on Windows. Otherwise, removing directories (eg.
226
228
  cloned via git) can cause rmtree to throw a PermissionError exception
227
229
  """
@@ -235,9 +237,7 @@ def rmdir(path: str) -> None:
235
237
 
236
238
 
237
239
  def _win_prepare_path(path: str) -> str:
238
- """Given a windows path, prepare it for use by making sure it is absolute
239
- and normalized.
240
- """
240
+ """Given a windows path, prepare it for use by making sure it is absolute and normalized."""
241
241
  path = os.path.normpath(path)
242
242
 
243
243
  # if a path starts with '\', splitdrive() on it will return '' for the
@@ -281,7 +281,9 @@ def _supports_long_paths() -> bool:
281
281
 
282
282
 
283
283
  def convert_path(path: str) -> str:
284
- """Convert a path that dbt has, which might be >260 characters long, to one
284
+ """Handle path length for windows.
285
+
286
+ Convert a path that dbt has, which might be >260 characters long, to one
285
287
  that will be writable/readable on Windows.
286
288
 
287
289
  On other platforms, this is a no-op.
@@ -387,14 +389,18 @@ def _handle_windows_error(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
387
389
  cls: Type[dbt_common.exceptions.DbtBaseException] = dbt_common.exceptions.base.CommandError
388
390
  if exc.errno == errno.ENOENT:
389
391
  message = (
390
- "Could not find command, ensure it is in the user's PATH " "and that the user has permissions to run it"
392
+ "Could not find command, ensure it is in the user's PATH "
393
+ "and that the user has permissions to run it"
391
394
  )
392
395
  cls = dbt_common.exceptions.ExecutableError
393
396
  elif exc.errno == errno.ENOEXEC:
394
397
  message = "Command was not executable, ensure it is valid"
395
398
  cls = dbt_common.exceptions.ExecutableError
396
399
  elif exc.errno == errno.ENOTDIR:
397
- message = "Unable to cd: path does not exist, user does not have" " permissions, or not a directory"
400
+ message = (
401
+ "Unable to cd: path does not exist, user does not have"
402
+ " permissions, or not a directory"
403
+ )
398
404
  cls = dbt_common.exceptions.WorkingDirectoryError
399
405
  else:
400
406
  message = 'Unknown error: {} (errno={}: "{}")'.format(
@@ -415,7 +421,9 @@ def _interpret_oserror(exc: OSError, cwd: str, cmd: List[str]) -> NoReturn:
415
421
  _handle_posix_error(exc, cwd, cmd)
416
422
 
417
423
  # this should not be reachable, raise _something_ at least!
418
- raise dbt_common.exceptions.DbtInternalError("Unhandled exception in _interpret_oserror: {}".format(exc))
424
+ raise dbt_common.exceptions.DbtInternalError(
425
+ "Unhandled exception in _interpret_oserror: {}".format(exc)
426
+ )
419
427
 
420
428
 
421
429
  def run_cmd(cwd: str, cmd: List[str], env: Optional[Dict[str, Any]] = None) -> Tuple[bytes, bytes]:
@@ -434,7 +442,9 @@ def run_cmd(cwd: str, cmd: List[str], env: Optional[Dict[str, Any]] = None) -> T
434
442
  exe_pth = shutil.which(cmd[0])
435
443
  if exe_pth:
436
444
  cmd = [os.path.abspath(exe_pth)] + list(cmd[1:])
437
- proc = subprocess.Popen(cmd, cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=full_env)
445
+ proc = subprocess.Popen(
446
+ cmd, cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=full_env
447
+ )
438
448
 
439
449
  out, err = proc.communicate()
440
450
  except OSError as exc:
@@ -450,7 +460,9 @@ def run_cmd(cwd: str, cmd: List[str], env: Optional[Dict[str, Any]] = None) -> T
450
460
  return out, err
451
461
 
452
462
 
453
- def download_with_retries(url: str, path: str, timeout: Optional[Union[float, tuple]] = None) -> None:
463
+ def download_with_retries(
464
+ url: str, path: str, timeout: Optional[Union[float, tuple]] = None
465
+ ) -> None:
454
466
  download_fn = functools.partial(download, url, path, timeout)
455
467
  connection_exception_retry(download_fn, 5)
456
468
 
@@ -496,6 +508,7 @@ def untar_package(tar_path: str, dest_dir: str, rename_to: Optional[str] = None)
496
508
 
497
509
  def chmod_and_retry(func, path, exc_info):
498
510
  """Define an error handler to pass to shutil.rmtree.
511
+
499
512
  On Windows, when a file is marked read-only as git likes to do, rmtree will
500
513
  fail. To handle that, on errors try to make the file writable.
501
514
  We want to retry most operations here, but listdir is one that we know will
@@ -513,7 +526,9 @@ def _absnorm(path):
513
526
 
514
527
 
515
528
  def move(src, dst):
516
- """A re-implementation of shutil.move that properly removes the source
529
+ """A re-implementation of shutil.move for windows fun.
530
+
531
+ A re-implementation of shutil.move that properly removes the source
517
532
  directory on windows when it has read-only files in it and the move is
518
533
  between two drives.
519
534
 
@@ -541,7 +556,9 @@ def move(src, dst):
541
556
  if os.path.isdir(src):
542
557
  if _absnorm(dst + "\\").startswith(_absnorm(src + "\\")):
543
558
  # dst is inside src
544
- raise EnvironmentError("Cannot move a directory '{}' into itself '{}'".format(src, dst))
559
+ raise EnvironmentError(
560
+ "Cannot move a directory '{}' into itself '{}'".format(src, dst)
561
+ )
545
562
  shutil.copytree(src, dst, symlinks=True)
546
563
  rmtree(src)
547
564
  else:
@@ -550,8 +567,9 @@ def move(src, dst):
550
567
 
551
568
 
552
569
  def rmtree(path):
553
- """Recursively remove the path. On permissions errors on windows, try to remove
554
- the read-only flag and try again.
570
+ """Recursively remove the path.
571
+
572
+ On permissions errors on windows, try to remove the read-only flag and try again.
555
573
  """
556
574
  path = convert_path(path)
557
575
  return shutil.rmtree(path, onerror=chmod_and_retry)
dbt_common/context.py ADDED
@@ -0,0 +1,48 @@
1
+ from contextvars import ContextVar, copy_context
2
+ from typing import List, Mapping, Optional
3
+
4
+ from dbt_common.constants import SECRET_ENV_PREFIX
5
+
6
+
7
class InvocationContext:
    """Per-invocation state; currently the environment mapping and its secret values."""

    def __init__(self, env: Mapping[str, str]):
        # This class will also eventually manage the invocation_id, flags, event manager, etc.
        self._env = env
        # Filled in on the first `env_secrets` access.
        self._env_secrets: Optional[List[str]] = None

    @property
    def env(self) -> Mapping[str, str]:
        """The environment mapping this context was created with."""
        return self._env

    @property
    def env_secrets(self) -> List[str]:
        """Non-blank values of env entries whose key starts with SECRET_ENV_PREFIX.

        The list is built on first access and reused afterwards.
        """
        cached = self._env_secrets
        if cached is not None:
            return cached

        secrets: List[str] = []
        for key, value in self.env.items():
            if key.startswith(SECRET_ENV_PREFIX) and value.strip():
                secrets.append(value)

        self._env_secrets = secrets
        return secrets
24
+
25
+
26
+ _INVOCATION_CONTEXT_VAR: ContextVar[InvocationContext] = ContextVar("DBT_INVOCATION_CONTEXT_VAR")
27
+
28
+
29
def reliably_get_invocation_var() -> ContextVar:
    """Return the ContextVar that holds the invocation context.

    Looks through a copy of the current context for a variable with the same
    name as `_INVOCATION_CONTEXT_VAR` and returns the first one found. If there
    is none, the module-level `_INVOCATION_CONTEXT_VAR` itself is returned.
    """
    wanted_name = _INVOCATION_CONTEXT_VAR.name
    for candidate in copy_context():
        if candidate.name == wanted_name:
            return candidate
    return _INVOCATION_CONTEXT_VAR
38
+
39
+
40
def set_invocation_context(env: Mapping[str, str]) -> None:
    """Store a new InvocationContext built from `env` in the invocation ContextVar."""
    target_var = reliably_get_invocation_var()
    context = InvocationContext(env)
    target_var.set(context)
43
+
44
+
45
def get_invocation_context() -> InvocationContext:
    """Return the InvocationContext stored in the invocation ContextVar.

    :raises LookupError: If the variable holds no value in the current context.
    """
    invocation_var = reliably_get_invocation_var()
    try:
        return invocation_var.get()
    except LookupError:
        # ContextVar.get() raises a LookupError that carries only the variable
        # itself; re-raise the same exception type with an actionable message.
        raise LookupError(
            "No invocation context has been set; call set_invocation_context() first."
        ) from None