ygg 0.1.57__py3-none-any.whl → 0.1.64__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ygg-0.1.57.dist-info → ygg-0.1.64.dist-info}/METADATA +2 -2
- ygg-0.1.64.dist-info/RECORD +74 -0
- yggdrasil/ai/__init__.py +2 -0
- yggdrasil/ai/session.py +87 -0
- yggdrasil/ai/sql_session.py +310 -0
- yggdrasil/databricks/__init__.py +0 -3
- yggdrasil/databricks/compute/cluster.py +68 -113
- yggdrasil/databricks/compute/command_execution.py +674 -0
- yggdrasil/databricks/compute/exceptions.py +19 -0
- yggdrasil/databricks/compute/execution_context.py +491 -282
- yggdrasil/databricks/compute/remote.py +4 -14
- yggdrasil/databricks/exceptions.py +10 -0
- yggdrasil/databricks/sql/__init__.py +0 -4
- yggdrasil/databricks/sql/engine.py +178 -178
- yggdrasil/databricks/sql/exceptions.py +9 -1
- yggdrasil/databricks/sql/statement_result.py +108 -120
- yggdrasil/databricks/sql/warehouse.py +339 -92
- yggdrasil/databricks/workspaces/io.py +185 -40
- yggdrasil/databricks/workspaces/path.py +114 -100
- yggdrasil/databricks/workspaces/workspace.py +210 -61
- yggdrasil/exceptions.py +7 -0
- yggdrasil/libs/databrickslib.py +22 -18
- yggdrasil/libs/extensions/spark_extensions.py +1 -1
- yggdrasil/libs/pandaslib.py +15 -6
- yggdrasil/libs/polarslib.py +49 -13
- yggdrasil/pyutils/__init__.py +1 -2
- yggdrasil/pyutils/callable_serde.py +12 -19
- yggdrasil/pyutils/exceptions.py +16 -0
- yggdrasil/pyutils/modules.py +6 -7
- yggdrasil/pyutils/python_env.py +16 -21
- yggdrasil/pyutils/waiting_config.py +171 -0
- yggdrasil/requests/msal.py +9 -96
- yggdrasil/types/cast/arrow_cast.py +3 -0
- yggdrasil/types/cast/pandas_cast.py +157 -169
- yggdrasil/types/cast/polars_cast.py +11 -43
- yggdrasil/types/dummy_class.py +81 -0
- yggdrasil/types/file_format.py +6 -2
- yggdrasil/types/python_defaults.py +92 -76
- yggdrasil/version.py +1 -1
- ygg-0.1.57.dist-info/RECORD +0 -66
- yggdrasil/databricks/ai/loki.py +0 -53
- {ygg-0.1.57.dist-info → ygg-0.1.64.dist-info}/WHEEL +0 -0
- {ygg-0.1.57.dist-info → ygg-0.1.64.dist-info}/entry_points.txt +0 -0
- {ygg-0.1.57.dist-info → ygg-0.1.64.dist-info}/licenses/LICENSE +0 -0
- {ygg-0.1.57.dist-info → ygg-0.1.64.dist-info}/top_level.txt +0 -0
- /yggdrasil/{databricks/ai/__init__.py → pyutils/mimetypes.py} +0 -0
|
@@ -497,9 +497,8 @@ class CallableSerde:
|
|
|
497
497
|
kwargs: Optional[Dict[str, Any]] = None,
|
|
498
498
|
*,
|
|
499
499
|
result_tag: str = "__CALLABLE_SERDE_RESULT__",
|
|
500
|
-
prefer: str = "dill",
|
|
501
500
|
byte_limit: int = 64 * 1024,
|
|
502
|
-
dump_env: str = "none",
|
|
501
|
+
dump_env: str = "none", # "none" | "globals" | "closure" | "both"
|
|
503
502
|
filter_used_globals: bool = True,
|
|
504
503
|
env_keys: Optional[Iterable[str]] = None,
|
|
505
504
|
env_variables: Optional[Dict[str, str]] = None,
|
|
@@ -697,26 +696,18 @@ sys.stdout.flush()
|
|
|
697
696
|
string_result.replace("DBXPATH:", "")
|
|
698
697
|
)
|
|
699
698
|
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
with path.open(mode="rb") as f:
|
|
704
|
-
buf = io.BytesIO(f.read_all_bytes())
|
|
705
|
-
|
|
699
|
+
try:
|
|
700
|
+
df = path.read_pandas()
|
|
701
|
+
finally:
|
|
706
702
|
path.rmfile()
|
|
707
|
-
buf.seek(0)
|
|
708
|
-
return pandas.read_parquet(buf)
|
|
709
703
|
|
|
710
|
-
|
|
711
|
-
blob = f.read_all_bytes()
|
|
704
|
+
return df
|
|
712
705
|
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
except (UnicodeEncodeError, binascii.Error) as e:
|
|
719
|
-
raise ValueError("Invalid base64 payload after result tag (corrupted/contaminated).") from e
|
|
706
|
+
# Strict base64 decode (rejects junk chars)
|
|
707
|
+
try:
|
|
708
|
+
blob = base64.b64decode(string_result.encode("ascii"), validate=True)
|
|
709
|
+
except (UnicodeEncodeError, binascii.Error) as e:
|
|
710
|
+
raise ValueError("Invalid base64 payload after result tag (corrupted/contaminated).") from e
|
|
720
711
|
|
|
721
712
|
raw = _decode_result_blob(blob)
|
|
722
713
|
try:
|
|
@@ -725,3 +716,5 @@ sys.stdout.flush()
|
|
|
725
716
|
raise ValueError("Failed to dill.loads decoded payload") from e
|
|
726
717
|
|
|
727
718
|
return result
|
|
719
|
+
|
|
720
|
+
|
yggdrasil/pyutils/exceptions.py
CHANGED
|
@@ -86,6 +86,19 @@ def parse_exception_from_traceback(tb_text: str) -> ParsedException:
|
|
|
86
86
|
return ParsedException(RuntimeError, clean, "RuntimeError")
|
|
87
87
|
|
|
88
88
|
|
|
89
|
+
def missing_module_name(exc: BaseException) -> str | None:
|
|
90
|
+
if isinstance(exc, ModuleNotFoundError):
|
|
91
|
+
if getattr(exc, "name", None):
|
|
92
|
+
return exc.name
|
|
93
|
+
|
|
94
|
+
# fallback: parse from message/args
|
|
95
|
+
msg = exc.args[0] if exc.args else str(exc)
|
|
96
|
+
m = re.search(r"No module named ['\"]([^'\"]+)['\"]", msg)
|
|
97
|
+
return m.group(1) if m else None
|
|
98
|
+
|
|
99
|
+
return None
|
|
100
|
+
|
|
101
|
+
|
|
89
102
|
def raise_parsed_traceback(tb_text: str, *, attach_as_cause: bool = True) -> None:
|
|
90
103
|
"""
|
|
91
104
|
Infer exception from traceback text and raise it.
|
|
@@ -94,6 +107,9 @@ def raise_parsed_traceback(tb_text: str, *, attach_as_cause: bool = True) -> Non
|
|
|
94
107
|
parsed = parse_exception_from_traceback(tb_text)
|
|
95
108
|
exc = parsed.exc_type(parsed.message) if parsed.message else parsed.exc_type()
|
|
96
109
|
|
|
110
|
+
if isinstance(exc, ModuleNotFoundError):
|
|
111
|
+
exc.name = missing_module_name(exc)
|
|
112
|
+
|
|
97
113
|
if attach_as_cause:
|
|
98
114
|
raise exc from RemoteTraceback(tb_text)
|
|
99
115
|
raise exc
|
yggdrasil/pyutils/modules.py
CHANGED
|
@@ -42,7 +42,7 @@ MODULE_PROJECT_NAMES_ALIASES = {
|
|
|
42
42
|
"yggdrasil": "ygg",
|
|
43
43
|
"jwt": "PyJWT",
|
|
44
44
|
}
|
|
45
|
-
|
|
45
|
+
DEFAULT_PIP_INDEX_SETTINGS = None
|
|
46
46
|
|
|
47
47
|
def module_name_to_project_name(module_name: str) -> str:
|
|
48
48
|
"""Map module import names to PyPI project names when they differ.
|
|
@@ -264,6 +264,11 @@ class PipIndexSettings:
|
|
|
264
264
|
Returns:
|
|
265
265
|
Default PipIndexSettings instance.
|
|
266
266
|
"""
|
|
267
|
+
global DEFAULT_PIP_INDEX_SETTINGS
|
|
268
|
+
|
|
269
|
+
if DEFAULT_PIP_INDEX_SETTINGS is None:
|
|
270
|
+
DEFAULT_PIP_INDEX_SETTINGS = get_pip_index_settings()
|
|
271
|
+
|
|
267
272
|
return DEFAULT_PIP_INDEX_SETTINGS
|
|
268
273
|
|
|
269
274
|
@property
|
|
@@ -363,9 +368,3 @@ def get_pip_index_settings() -> PipIndexSettings:
|
|
|
363
368
|
extra_index_urls.append(u)
|
|
364
369
|
|
|
365
370
|
return PipIndexSettings(index_url=index_url, extra_index_urls=extra_index_urls, sources=sources)
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
try:
|
|
369
|
-
DEFAULT_PIP_INDEX_SETTINGS = get_pip_index_settings()
|
|
370
|
-
except:
|
|
371
|
-
DEFAULT_PIP_INDEX_SETTINGS = PipIndexSettings()
|
yggdrasil/pyutils/python_env.py
CHANGED
|
@@ -20,14 +20,13 @@ from dataclasses import dataclass, field
|
|
|
20
20
|
from pathlib import Path
|
|
21
21
|
from typing import Any, Iterable, Iterator, Mapping, MutableMapping, Optional, Union, List, Tuple
|
|
22
22
|
|
|
23
|
-
from
|
|
23
|
+
from .modules import PipIndexSettings
|
|
24
24
|
|
|
25
25
|
log = logging.getLogger(__name__)
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
class PythonEnvError(RuntimeError):
|
|
29
29
|
"""Raised when Python environment operations fail."""
|
|
30
|
-
|
|
31
30
|
pass
|
|
32
31
|
|
|
33
32
|
|
|
@@ -72,6 +71,9 @@ _NON_PIPABLE_RE = re.compile(
|
|
|
72
71
|
re.IGNORECASE,
|
|
73
72
|
)
|
|
74
73
|
|
|
74
|
+
# Snapshot singleton (import-time)
|
|
75
|
+
CURRENT_PYTHON_ENV: "PythonEnv" = None
|
|
76
|
+
|
|
75
77
|
|
|
76
78
|
|
|
77
79
|
def _filter_non_pipable_linux_packages(requirements: Iterable[str]) -> List[str]:
|
|
@@ -441,19 +443,20 @@ class PythonEnv:
|
|
|
441
443
|
Returns:
|
|
442
444
|
PythonEnv representing the current environment.
|
|
443
445
|
"""
|
|
444
|
-
|
|
445
|
-
if venv:
|
|
446
|
-
log.debug("current env from VIRTUAL_ENV=%s", venv)
|
|
447
|
-
return cls(Path(venv))
|
|
446
|
+
global CURRENT_PYTHON_ENV
|
|
448
447
|
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
log.debug("current env inferred from sys.executable=%s", str(exe))
|
|
453
|
-
return cls(parent.parent)
|
|
448
|
+
if CURRENT_PYTHON_ENV is None:
|
|
449
|
+
exe = Path(sys.executable).expanduser().resolve()
|
|
450
|
+
parent = exe.parent
|
|
454
451
|
|
|
455
|
-
|
|
456
|
-
|
|
452
|
+
if parent.name in ("bin", "Scripts"):
|
|
453
|
+
log.debug("current env inferred from sys.executable=%s", str(exe))
|
|
454
|
+
CURRENT_PYTHON_ENV = cls(parent.parent)
|
|
455
|
+
else:
|
|
456
|
+
log.debug("current env fallback to sys.prefix=%s", sys.prefix)
|
|
457
|
+
CURRENT_PYTHON_ENV = cls(Path(sys.prefix))
|
|
458
|
+
|
|
459
|
+
return CURRENT_PYTHON_ENV
|
|
457
460
|
|
|
458
461
|
@classmethod
|
|
459
462
|
def ensure_uv(
|
|
@@ -1507,11 +1510,3 @@ print("RESULT:" + json.dumps(top_level))""".strip()
|
|
|
1507
1510
|
log.error("python_env CLI error: %s", e)
|
|
1508
1511
|
print(f"ERROR: {e}", file=sys.stderr)
|
|
1509
1512
|
return 2
|
|
1510
|
-
|
|
1511
|
-
|
|
1512
|
-
# Snapshot singleton (import-time)
|
|
1513
|
-
CURRENT_PYTHON_ENV: PythonEnv = PythonEnv.get_current()
|
|
1514
|
-
|
|
1515
|
-
|
|
1516
|
-
if __name__ == "__main__":
|
|
1517
|
-
raise SystemExit(PythonEnv.cli())
|
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
import datetime as dt
|
|
2
|
+
import time
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Optional, Union
|
|
5
|
+
|
|
6
|
+
__all__ = ["WaitingConfig", "WaitingConfigArg"]
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def _safe_seconds_tick(ticks: Union[int, float, dt.timedelta]):
|
|
10
|
+
if isinstance(ticks, dt.timedelta):
|
|
11
|
+
return ticks.total_seconds()
|
|
12
|
+
return ticks
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
DEFAULT_TIMEOUT_TICKS = float(20 * 60) # 20 minutes
|
|
16
|
+
WaitingConfigArg = Union["WaitingConfig", dict, int, float, dt.datetime, bool]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass(frozen=True)
|
|
20
|
+
class WaitingConfig:
|
|
21
|
+
timeout: float = DEFAULT_TIMEOUT_TICKS
|
|
22
|
+
interval: float = 2.0
|
|
23
|
+
backoff: float = 1.0
|
|
24
|
+
max_interval: float = 10.0
|
|
25
|
+
|
|
26
|
+
@property
|
|
27
|
+
def timeout_timedelta(self) -> dt.timedelta:
|
|
28
|
+
return dt.timedelta(seconds=self.timeout)
|
|
29
|
+
|
|
30
|
+
@classmethod
|
|
31
|
+
def default(cls):
|
|
32
|
+
return DEFAULT_WAITING_CONFIG
|
|
33
|
+
|
|
34
|
+
@staticmethod
|
|
35
|
+
def _to_seconds(value) -> Optional[float]:
|
|
36
|
+
if value is None:
|
|
37
|
+
return None
|
|
38
|
+
if isinstance(value, dt.timedelta):
|
|
39
|
+
return float(value.total_seconds())
|
|
40
|
+
if isinstance(value, (int, float)):
|
|
41
|
+
return float(value)
|
|
42
|
+
raise TypeError(f"Expected seconds as int/float/timedelta, got {type(value)!r}")
|
|
43
|
+
|
|
44
|
+
@staticmethod
|
|
45
|
+
def _deadline_to_timeout(deadline: dt.datetime) -> float:
|
|
46
|
+
if not isinstance(deadline, dt.datetime):
|
|
47
|
+
raise TypeError(f"deadline must be datetime, got {type(deadline)!r}")
|
|
48
|
+
now = dt.datetime.now(tz=deadline.tzinfo) if deadline.tzinfo else dt.datetime.now()
|
|
49
|
+
return (deadline - now).total_seconds()
|
|
50
|
+
|
|
51
|
+
@classmethod
|
|
52
|
+
def check_arg(
|
|
53
|
+
cls,
|
|
54
|
+
arg: Optional[WaitingConfigArg] = None,
|
|
55
|
+
timeout: Optional[Union[int, float, dt.timedelta]] = None,
|
|
56
|
+
interval: Optional[Union[int, float, dt.timedelta]] = None,
|
|
57
|
+
backoff: Optional[Union[int, float, dt.timedelta]] = None,
|
|
58
|
+
max_interval: Optional[Union[int, float, dt.timedelta]] = None,
|
|
59
|
+
) -> Optional["WaitingConfig"]:
|
|
60
|
+
base_timeout: Optional[float] = None
|
|
61
|
+
base_interval: Optional[float] = None
|
|
62
|
+
base_backoff: Optional[float] = None
|
|
63
|
+
base_max_interval: Optional[float] = None
|
|
64
|
+
|
|
65
|
+
if arg is not None:
|
|
66
|
+
if isinstance(arg, cls):
|
|
67
|
+
if timeout is None and interval is None and backoff is None and max_interval is None:
|
|
68
|
+
return arg
|
|
69
|
+
|
|
70
|
+
base_timeout = arg.timeout
|
|
71
|
+
base_interval = arg.interval
|
|
72
|
+
base_backoff = arg.backoff
|
|
73
|
+
base_max_interval = arg.max_interval
|
|
74
|
+
|
|
75
|
+
elif isinstance(arg, bool):
|
|
76
|
+
base_timeout = DEFAULT_TIMEOUT_TICKS if arg else 0.0
|
|
77
|
+
base_interval = 2.0
|
|
78
|
+
base_backoff = 2.0
|
|
79
|
+
base_max_interval = 15.0
|
|
80
|
+
|
|
81
|
+
elif isinstance(arg, (int, float, dt.timedelta)):
|
|
82
|
+
base_timeout = cls._to_seconds(arg)
|
|
83
|
+
|
|
84
|
+
elif isinstance(arg, dt.datetime):
|
|
85
|
+
base_timeout = float(cls._deadline_to_timeout(arg))
|
|
86
|
+
|
|
87
|
+
elif isinstance(arg, dict):
|
|
88
|
+
if "deadline" in arg and "timeout" in arg:
|
|
89
|
+
raise ValueError("Provide only one of 'deadline' or 'timeout' in WaitingOptions dict.")
|
|
90
|
+
|
|
91
|
+
if "deadline" in arg and arg["deadline"] is not None:
|
|
92
|
+
base_timeout = float(cls._deadline_to_timeout(arg["deadline"]))
|
|
93
|
+
else:
|
|
94
|
+
base_timeout = cls._to_seconds(arg.get("timeout"))
|
|
95
|
+
|
|
96
|
+
base_interval = cls._to_seconds(arg.get("interval"))
|
|
97
|
+
base_backoff = cls._to_seconds(arg.get("backoff"))
|
|
98
|
+
base_max_interval = cls._to_seconds(arg.get("max_interval"))
|
|
99
|
+
|
|
100
|
+
else:
|
|
101
|
+
raise TypeError(f"Unsupported WaitingOptions arg type: {type(arg)!r}")
|
|
102
|
+
|
|
103
|
+
# explicit kwargs win
|
|
104
|
+
final_timeout = cls._to_seconds(timeout) if timeout is not None else base_timeout
|
|
105
|
+
final_interval = cls._to_seconds(interval) if interval is not None else base_interval
|
|
106
|
+
final_backoff = cls._to_seconds(backoff) if backoff is not None else base_backoff
|
|
107
|
+
final_max_interval = cls._to_seconds(max_interval) if max_interval is not None else base_max_interval
|
|
108
|
+
|
|
109
|
+
# defaults to match non-Optional signature
|
|
110
|
+
if final_timeout is None:
|
|
111
|
+
final_timeout = 0.0
|
|
112
|
+
elif final_timeout < 0:
|
|
113
|
+
final_timeout = 0.0
|
|
114
|
+
|
|
115
|
+
if final_interval is None:
|
|
116
|
+
final_interval = 2.0
|
|
117
|
+
|
|
118
|
+
if final_backoff is None:
|
|
119
|
+
final_backoff = 2.0
|
|
120
|
+
elif final_backoff < 1:
|
|
121
|
+
final_backoff = 2.0
|
|
122
|
+
|
|
123
|
+
if final_max_interval is None:
|
|
124
|
+
final_max_interval = 10.0
|
|
125
|
+
|
|
126
|
+
return cls(
|
|
127
|
+
timeout=float(final_timeout),
|
|
128
|
+
interval=float(final_interval),
|
|
129
|
+
backoff=float(final_backoff),
|
|
130
|
+
max_interval=float(final_max_interval),
|
|
131
|
+
)
|
|
132
|
+
|
|
133
|
+
def sleep(self, iteration: int, start: float | None = None) -> None:
|
|
134
|
+
"""
|
|
135
|
+
iteration is 0-based (first wait => iteration=0)
|
|
136
|
+
|
|
137
|
+
- interval == 0 => no sleep
|
|
138
|
+
- backoff >= 1 => interval * backoff**iteration
|
|
139
|
+
- max_interval == 0 => no cap, else cap sleep to max_interval
|
|
140
|
+
- if start is provided and timeout > 0:
|
|
141
|
+
* raise TimeoutError if already out of time
|
|
142
|
+
* cap sleep so we don't oversleep past timeout
|
|
143
|
+
"""
|
|
144
|
+
if iteration < 0:
|
|
145
|
+
raise ValueError(f"iteration must be >= 0, got {iteration}")
|
|
146
|
+
|
|
147
|
+
if self.interval == 0:
|
|
148
|
+
return
|
|
149
|
+
|
|
150
|
+
sleep_s = self.interval * (self.backoff ** int(iteration))
|
|
151
|
+
|
|
152
|
+
if self.max_interval > 0:
|
|
153
|
+
sleep_s = min(sleep_s, self.max_interval)
|
|
154
|
+
|
|
155
|
+
if sleep_s <= 0:
|
|
156
|
+
return
|
|
157
|
+
|
|
158
|
+
if start is not None and self.timeout > 0:
|
|
159
|
+
elapsed = time.time() - float(start)
|
|
160
|
+
remaining = self.timeout - elapsed
|
|
161
|
+
if remaining <= 0:
|
|
162
|
+
raise TimeoutError(f"Timed out waiting after {self.timeout:.3f}s")
|
|
163
|
+
sleep_s = min(sleep_s, remaining)
|
|
164
|
+
|
|
165
|
+
if sleep_s <= 0:
|
|
166
|
+
return
|
|
167
|
+
|
|
168
|
+
time.sleep(sleep_s)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
DEFAULT_WAITING_CONFIG = WaitingConfig()
|
yggdrasil/requests/msal.py
CHANGED
|
@@ -3,12 +3,8 @@
|
|
|
3
3
|
# auth_session.py
|
|
4
4
|
import os
|
|
5
5
|
import time
|
|
6
|
-
from
|
|
7
|
-
|
|
8
|
-
import urllib3
|
|
9
|
-
|
|
10
|
-
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
|
11
|
-
from dataclasses import dataclass
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
from typing import Any, Optional
|
|
12
8
|
|
|
13
9
|
from .session import YGGSession
|
|
14
10
|
|
|
@@ -38,11 +34,11 @@ class MSALAuth:
|
|
|
38
34
|
authority: Optional authority URL override.
|
|
39
35
|
scopes: List of scopes to request.
|
|
40
36
|
"""
|
|
41
|
-
tenant_id: Optional[str] =
|
|
42
|
-
client_id: Optional[str] =
|
|
43
|
-
client_secret: Optional[str] =
|
|
44
|
-
authority: Optional[str] =
|
|
45
|
-
scopes: list[str] | None =
|
|
37
|
+
tenant_id: Optional[str] = field(default_factory=lambda: os.environ.get("AZURE_TENANT_ID"))
|
|
38
|
+
client_id: Optional[str] = field(default_factory=lambda: os.environ.get("AZURE_CLIENT_ID"))
|
|
39
|
+
client_secret: Optional[str] = field(default_factory=lambda: os.environ.get("AZURE_CLIENT_SECRET"))
|
|
40
|
+
authority: Optional[str] = field(default_factory=lambda: os.environ.get("AZURE_AUTHORITY"))
|
|
41
|
+
scopes: list[str] | None = field(default_factory=lambda: os.environ.get("AZURE_SCOPES"))
|
|
46
42
|
|
|
47
43
|
_auth_app: ConfidentialClientApplication | None = None
|
|
48
44
|
_expires_at: float | None = None
|
|
@@ -77,97 +73,15 @@ class MSALAuth:
|
|
|
77
73
|
Returns:
|
|
78
74
|
None.
|
|
79
75
|
"""
|
|
80
|
-
self.tenant_id = self.tenant_id or os.environ.get("AZURE_TENANT_ID")
|
|
81
|
-
self.client_id = self.client_id or os.environ.get("AZURE_CLIENT_ID")
|
|
82
|
-
self.client_secret = self.client_secret or os.environ.get("AZURE_CLIENT_SECRET")
|
|
83
|
-
|
|
84
|
-
self.authority = self.authority or os.environ.get("AZURE_AUTHORITY")
|
|
85
76
|
if not self.authority:
|
|
77
|
+
assert self.tenant_id, "tenant_id is required to build authority URL"
|
|
78
|
+
|
|
86
79
|
self.authority = f"https://login.microsoftonline.com/{self.tenant_id}"
|
|
87
80
|
|
|
88
|
-
self.scopes = self.scopes or os.environ.get("AZURE_SCOPES")
|
|
89
81
|
if self.scopes:
|
|
90
82
|
if isinstance(self.scopes, str):
|
|
91
83
|
self.scopes = self.scopes.split(",")
|
|
92
84
|
|
|
93
|
-
self._validate_config()
|
|
94
|
-
|
|
95
|
-
def _validate_config(self):
|
|
96
|
-
"""Validate that all required configuration is present.
|
|
97
|
-
|
|
98
|
-
Returns:
|
|
99
|
-
None.
|
|
100
|
-
"""
|
|
101
|
-
missing = []
|
|
102
|
-
|
|
103
|
-
if not self.client_id:
|
|
104
|
-
missing.append("azure_client_id (AZURE_CLIENT_ID)")
|
|
105
|
-
if not self.client_secret:
|
|
106
|
-
missing.append("azure_client_secret (AZURE_CLIENT_SECRET)")
|
|
107
|
-
if not self.tenant_id:
|
|
108
|
-
missing.append("azure_client_secret (AZURE_TENANT_ID)")
|
|
109
|
-
if not self.scopes:
|
|
110
|
-
missing.append("scopes (AZURE_SCOPES)")
|
|
111
|
-
|
|
112
|
-
if missing:
|
|
113
|
-
raise ValueError(f"Missing required configuration: {', '.join(missing)}")
|
|
114
|
-
|
|
115
|
-
@classmethod
|
|
116
|
-
def find_in_env(
|
|
117
|
-
cls,
|
|
118
|
-
env: Mapping = None,
|
|
119
|
-
prefix: Optional[str] = None
|
|
120
|
-
) -> "MSALAuth":
|
|
121
|
-
"""Return an MSALAuth built from environment variables if available.
|
|
122
|
-
|
|
123
|
-
Args:
|
|
124
|
-
env: Mapping to read variables from; defaults to os.environ.
|
|
125
|
-
prefix: Optional prefix for variable names.
|
|
126
|
-
|
|
127
|
-
Returns:
|
|
128
|
-
A configured MSALAuth instance or None.
|
|
129
|
-
"""
|
|
130
|
-
if not env:
|
|
131
|
-
env = os.environ
|
|
132
|
-
prefix = prefix or "AZURE_"
|
|
133
|
-
|
|
134
|
-
required = {
|
|
135
|
-
key: env.get(prefix + key.upper())
|
|
136
|
-
for key in (
|
|
137
|
-
"client_id", "client_secret", "tenant_id", "scopes"
|
|
138
|
-
)
|
|
139
|
-
}
|
|
140
|
-
|
|
141
|
-
if all(required.values()):
|
|
142
|
-
scopes = required["scopes"].split(",") if required["scopes"] else None
|
|
143
|
-
return MSALAuth(
|
|
144
|
-
tenant_id=required["tenant_id"],
|
|
145
|
-
client_id=required["client_id"],
|
|
146
|
-
client_secret=required["client_secret"],
|
|
147
|
-
scopes=scopes,
|
|
148
|
-
authority=env.get(prefix + "AUTHORITY"),
|
|
149
|
-
)
|
|
150
|
-
|
|
151
|
-
return None
|
|
152
|
-
|
|
153
|
-
def export_to(self, to: dict = os.environ):
|
|
154
|
-
"""Export the auth configuration to the provided mapping.
|
|
155
|
-
|
|
156
|
-
Args:
|
|
157
|
-
to: Mapping to populate with auth configuration values.
|
|
158
|
-
|
|
159
|
-
Returns:
|
|
160
|
-
None.
|
|
161
|
-
"""
|
|
162
|
-
for key, value in (
|
|
163
|
-
("AZURE_CLIENT_ID", self.client_id),
|
|
164
|
-
("AZURE_CLIENT_SECRET", self.client_secret),
|
|
165
|
-
("AZURE_AUTHORITY", self.authority),
|
|
166
|
-
("AZURE_SCOPES", ",".join(self.scopes)),
|
|
167
|
-
):
|
|
168
|
-
if value:
|
|
169
|
-
to[key] = value
|
|
170
|
-
|
|
171
85
|
@property
|
|
172
86
|
def auth_app(self) -> ConfidentialClientApplication:
|
|
173
87
|
"""Return or initialize the MSAL confidential client.
|
|
@@ -298,7 +212,6 @@ class MSALSession(YGGSession):
|
|
|
298
212
|
super().__init__(*args, **kwargs)
|
|
299
213
|
self.msal_auth = msal_auth
|
|
300
214
|
|
|
301
|
-
|
|
302
215
|
def prepare_request(self, request):
|
|
303
216
|
"""Prepare the request with an Authorization header when needed.
|
|
304
217
|
|