awx-zipline-ai 0.2.1__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. agent/ttypes.py +6 -6
  2. ai/chronon/airflow_helpers.py +20 -23
  3. ai/chronon/cli/__init__.py +0 -0
  4. ai/chronon/cli/compile/__init__.py +0 -0
  5. ai/chronon/cli/compile/column_hashing.py +40 -17
  6. ai/chronon/cli/compile/compile_context.py +13 -17
  7. ai/chronon/cli/compile/compiler.py +59 -36
  8. ai/chronon/cli/compile/conf_validator.py +251 -99
  9. ai/chronon/cli/compile/display/__init__.py +0 -0
  10. ai/chronon/cli/compile/display/class_tracker.py +6 -16
  11. ai/chronon/cli/compile/display/compile_status.py +10 -10
  12. ai/chronon/cli/compile/display/diff_result.py +79 -14
  13. ai/chronon/cli/compile/fill_templates.py +3 -8
  14. ai/chronon/cli/compile/parse_configs.py +10 -17
  15. ai/chronon/cli/compile/parse_teams.py +38 -34
  16. ai/chronon/cli/compile/serializer.py +3 -9
  17. ai/chronon/cli/compile/version_utils.py +42 -0
  18. ai/chronon/cli/git_utils.py +2 -13
  19. ai/chronon/cli/logger.py +0 -2
  20. ai/chronon/constants.py +1 -1
  21. ai/chronon/group_by.py +47 -47
  22. ai/chronon/join.py +46 -32
  23. ai/chronon/logger.py +1 -2
  24. ai/chronon/model.py +9 -4
  25. ai/chronon/query.py +2 -2
  26. ai/chronon/repo/__init__.py +1 -2
  27. ai/chronon/repo/aws.py +17 -31
  28. ai/chronon/repo/cluster.py +121 -50
  29. ai/chronon/repo/compile.py +14 -8
  30. ai/chronon/repo/constants.py +1 -1
  31. ai/chronon/repo/default_runner.py +32 -54
  32. ai/chronon/repo/explore.py +70 -73
  33. ai/chronon/repo/extract_objects.py +6 -9
  34. ai/chronon/repo/gcp.py +89 -88
  35. ai/chronon/repo/gitpython_utils.py +3 -2
  36. ai/chronon/repo/hub_runner.py +145 -55
  37. ai/chronon/repo/hub_uploader.py +2 -1
  38. ai/chronon/repo/init.py +12 -5
  39. ai/chronon/repo/join_backfill.py +19 -5
  40. ai/chronon/repo/run.py +42 -39
  41. ai/chronon/repo/serializer.py +4 -12
  42. ai/chronon/repo/utils.py +72 -63
  43. ai/chronon/repo/zipline.py +3 -19
  44. ai/chronon/repo/zipline_hub.py +211 -39
  45. ai/chronon/resources/__init__.py +0 -0
  46. ai/chronon/resources/gcp/__init__.py +0 -0
  47. ai/chronon/resources/gcp/group_bys/__init__.py +0 -0
  48. ai/chronon/resources/gcp/group_bys/test/data.py +13 -17
  49. ai/chronon/resources/gcp/joins/__init__.py +0 -0
  50. ai/chronon/resources/gcp/joins/test/data.py +4 -8
  51. ai/chronon/resources/gcp/sources/__init__.py +0 -0
  52. ai/chronon/resources/gcp/sources/test/data.py +9 -6
  53. ai/chronon/resources/gcp/teams.py +9 -21
  54. ai/chronon/source.py +2 -4
  55. ai/chronon/staging_query.py +60 -19
  56. ai/chronon/types.py +3 -2
  57. ai/chronon/utils.py +21 -68
  58. ai/chronon/windows.py +2 -4
  59. {awx_zipline_ai-0.2.1.dist-info → awx_zipline_ai-0.3.1.dist-info}/METADATA +48 -24
  60. awx_zipline_ai-0.3.1.dist-info/RECORD +96 -0
  61. awx_zipline_ai-0.3.1.dist-info/top_level.txt +4 -0
  62. gen_thrift/__init__.py +0 -0
  63. {ai/chronon → gen_thrift}/api/ttypes.py +327 -197
  64. {ai/chronon/api → gen_thrift}/common/ttypes.py +9 -39
  65. gen_thrift/eval/ttypes.py +660 -0
  66. {ai/chronon → gen_thrift}/hub/ttypes.py +12 -131
  67. {ai/chronon → gen_thrift}/observability/ttypes.py +343 -180
  68. {ai/chronon → gen_thrift}/planner/ttypes.py +326 -45
  69. ai/chronon/eval/__init__.py +0 -122
  70. ai/chronon/eval/query_parsing.py +0 -19
  71. ai/chronon/eval/sample_tables.py +0 -100
  72. ai/chronon/eval/table_scan.py +0 -186
  73. ai/chronon/orchestration/ttypes.py +0 -4406
  74. ai/chronon/resources/gcp/README.md +0 -174
  75. ai/chronon/resources/gcp/zipline-cli-install.sh +0 -54
  76. awx_zipline_ai-0.2.1.dist-info/RECORD +0 -93
  77. awx_zipline_ai-0.2.1.dist-info/licenses/LICENSE +0 -202
  78. awx_zipline_ai-0.2.1.dist-info/top_level.txt +0 -3
  79. /jars/__init__.py → /__init__.py +0 -0
  80. {awx_zipline_ai-0.2.1.dist-info → awx_zipline_ai-0.3.1.dist-info}/WHEEL +0 -0
  81. {awx_zipline_ai-0.2.1.dist-info → awx_zipline_ai-0.3.1.dist-info}/entry_points.txt +0 -0
  82. {ai/chronon → gen_thrift}/api/__init__.py +0 -0
  83. {ai/chronon/api/common → gen_thrift/api}/constants.py +0 -0
  84. {ai/chronon/api → gen_thrift}/common/__init__.py +0 -0
  85. {ai/chronon/api → gen_thrift/common}/constants.py +0 -0
  86. {ai/chronon/fetcher → gen_thrift/eval}/__init__.py +0 -0
  87. {ai/chronon/fetcher → gen_thrift/eval}/constants.py +0 -0
  88. {ai/chronon/hub → gen_thrift/fetcher}/__init__.py +0 -0
  89. {ai/chronon/hub → gen_thrift/fetcher}/constants.py +0 -0
  90. {ai/chronon → gen_thrift}/fetcher/ttypes.py +0 -0
  91. {ai/chronon/observability → gen_thrift/hub}/__init__.py +0 -0
  92. {ai/chronon/observability → gen_thrift/hub}/constants.py +0 -0
  93. {ai/chronon/orchestration → gen_thrift/observability}/__init__.py +0 -0
  94. {ai/chronon/orchestration → gen_thrift/observability}/constants.py +0 -0
  95. {ai/chronon → gen_thrift}/planner/__init__.py +0 -0
  96. {ai/chronon → gen_thrift}/planner/constants.py +0 -0
ai/chronon/repo/utils.py CHANGED
@@ -1,3 +1,4 @@
+import functools
 import json
 import os
 import re
@@ -7,6 +8,8 @@ import xml.etree.ElementTree as ET
 from datetime import datetime, timedelta
 from enum import Enum
 
+from click import style
+
 from ai.chronon.cli.compile.parse_teams import EnvOrConfigAttribute
 from ai.chronon.logger import get_logger
 from ai.chronon.repo.constants import (
@@ -17,6 +20,7 @@ from ai.chronon.repo.constants import (
 
 LOG = get_logger()
 
+
 class JobType(Enum):
     SPARK = "spark"
     FLINK = "flink"
@@ -52,9 +56,11 @@ def get_environ_arg(env_name, ignoreError=False) -> str:
         raise ValueError(f"Please set {env_name} environment variable")
     return value
 
+
 def get_customer_warehouse_bucket() -> str:
     return f"zipline-warehouse-{get_customer_id()}"
 
+
 def get_customer_id() -> str:
     return get_environ_arg("CUSTOMER_ID")
 
@@ -94,17 +100,13 @@ def download_only_once(url, path, skip_download=False):
         LOG.info(
             """Files sizes of {url} vs. {path}
             Remote size: {remote_size}
-            Local size : {local_size}""".format(
-                **locals()
-            )
+            Local size : {local_size}""".format(**locals())
         )
         if local_size == remote_size:
             LOG.info("Sizes match. Assuming it's already downloaded.")
             should_download = False
         if should_download:
-            LOG.info(
-                "Different file from remote at local: " + path + ". Re-downloading.."
-            )
+            LOG.info("Different file from remote at local: " + path + ". Re-downloading..")
             check_call("curl {} -o {} --connect-timeout 10".format(url, path))
     else:
         LOG.info("No file at: " + path + ". Downloading..")
@@ -126,9 +128,7 @@ def download_jar(
     )
     scala_version = SCALA_VERSION_FOR_SPARK[spark_version]
     maven_url_prefix = os.environ.get("CHRONON_MAVEN_MIRROR_PREFIX", None)
-    default_url_prefix = (
-        "https://s01.oss.sonatype.org/service/local/repositories/public/content"
-    )
+    default_url_prefix = "https://s01.oss.sonatype.org/service/local/repositories/public/content"
     url_prefix = maven_url_prefix if maven_url_prefix else default_url_prefix
     base_url = "{}/ai/chronon/spark_{}_{}".format(url_prefix, jar_type, scala_version)
     LOG.info("Downloading jar from url: " + base_url)
@@ -137,9 +137,7 @@ def download_jar(
     if version == "latest":
         version = None
     if version is None:
-        metadata_content = check_output(
-            "curl -s {}/maven-metadata.xml".format(base_url)
-        )
+        metadata_content = check_output("curl -s {}/maven-metadata.xml".format(base_url))
         meta_tree = ET.fromstring(metadata_content)
         versions = [
             node.text
@@ -152,11 +150,13 @@ def download_jar(
             )
         ]
         version = versions[-1]
-    jar_url = "{base_url}/{version}/spark_{jar_type}_{scala_version}-{version}-assembly.jar".format(
-        base_url=base_url,
-        version=version,
-        scala_version=scala_version,
-        jar_type=jar_type,
+    jar_url = (
+        "{base_url}/{version}/spark_{jar_type}_{scala_version}-{version}-assembly.jar".format(
+            base_url=base_url,
+            version=version,
+            scala_version=scala_version,
+            jar_type=jar_type,
+        )
     )
     jar_path = os.path.join("/tmp", extract_filename_from_path(jar_url))
     download_only_once(jar_url, jar_path, skip_download)
@@ -170,6 +170,7 @@ def get_teams_json_file_path(repo_path):
 def get_teams_py_file_path(repo_path):
     return os.path.join(repo_path, "teams.py")
 
+
 def set_runtime_env_v3(params, conf):
     effective_mode = params.get("mode")
 
@@ -181,9 +182,14 @@ def set_runtime_env_v3(params, conf):
     if os.path.isfile(conf_path):
         with open(conf_path, "r") as infile:
             conf_json = json.load(infile)
-        metadata = conf_json.get("metaData", {}) or conf_json  # user may just pass metadata as the entire json
+        metadata = (
+            conf_json.get("metaData", {}) or conf_json
+        )  # user may just pass metadata as the entire json
         env = metadata.get("executionInfo", {}).get("env", {})
-        runtime_env.update(env.get(EnvOrConfigAttribute.ENV,{}).get(effective_mode,{}) or env.get("common", {}))
+        runtime_env.update(
+            env.get(EnvOrConfigAttribute.ENV, {}).get(effective_mode, {})
+            or env.get("common", {})
+        )
         # Also set APP_NAME
         try:
             _, conf_type, team, _ = conf.split("/")[-4:]
@@ -206,23 +212,14 @@ def set_runtime_env_v3(params, conf):
         except Exception:
             LOG.warn(
                 "Failed to set APP_NAME due to invalid conf path: {}, please ensure to supply the "
-                "relative path to zipline/ folder".format(
-                    conf
-                )
+                "relative path to zipline/ folder".format(conf)
            )
    else:
        if not params.get("app_name") and not os.environ.get("APP_NAME"):
            # Provide basic app_name when no conf is defined.
            # Modes like metadata-upload and metadata-export can rely on conf-type or folder rather than a conf.
            runtime_env["APP_NAME"] = "_".join(
-                [
-                    k
-                    for k in [
-                        "chronon",
-                        effective_mode.replace("-", "_")
-                    ]
-                    if k is not None
-                ]
+                [k for k in ["chronon", effective_mode.replace("-", "_")] if k is not None]
            )
    for key, value in runtime_env.items():
        if key not in os.environ and value is not None:
@@ -230,6 +227,7 @@ def set_runtime_env_v3(params, conf):
             print(f"Setting to environment: {key}={value}")
             os.environ[key] = value
 
+
 # TODO: delete this when we cutover
 def set_runtime_env(params):
     """
@@ -263,20 +261,15 @@ def set_runtime_env(params):
     if effective_mode and "streaming" in effective_mode:
         effective_mode = "streaming"
     if params["repo"]:
-
         # Break if teams.json and teams.py exists
         teams_json_file = get_teams_json_file_path(params["repo"])
         teams_py_file = get_teams_py_file_path(params["repo"])
 
         if os.path.exists(teams_json_file) and os.path.exists(teams_py_file):
-            raise ValueError(
-                "Both teams.json and teams.py exist. Please only use teams.py."
-            )
+            raise ValueError("Both teams.json and teams.py exist. Please only use teams.py.")
 
         if os.path.exists(teams_json_file):
-            set_runtime_env_teams_json(
-                environment, params, effective_mode, teams_json_file
-            )
+            set_runtime_env_teams_json(environment, params, effective_mode, teams_json_file)
         if params["app_name"]:
             environment["cli_args"]["APP_NAME"] = params["app_name"]
         else:
@@ -289,11 +282,7 @@ def set_runtime_env(params):
                 for k in [
                     "chronon",
                     conf_type,
-                    (
-                        params["mode"].replace("-", "_")
-                        if params["mode"]
-                        else None
-                    ),
+                    (params["mode"].replace("-", "_") if params["mode"] else None),
                 ]
                 if k is not None
             ]
@@ -321,6 +310,7 @@ def set_runtime_env(params):
             LOG.info(f"From <{set_key}> setting {key}={value}")
             os.environ[key] = value
 
+
 # TODO: delete this when we cutover
 def set_runtime_env_teams_json(environment, params, effective_mode, teams_json_file):
     if os.path.exists(teams_json_file):
@@ -346,9 +336,7 @@ def set_runtime_env_teams_json(environment, params, effective_mode, teams_json_file):
             context = params["env"]
         else:
             context = "dev"
-        LOG.info(
-            f"Context: {context} -- conf_type: {conf_type} -- team: {team}"
-        )
+        LOG.info(f"Context: {context} -- conf_type: {conf_type} -- team: {team}")
         conf_path = os.path.join(params["repo"], params["conf"])
         if os.path.isfile(conf_path):
             with open(conf_path, "r") as conf_file:
@@ -362,9 +350,7 @@
                 )
 
                 old_env = (
-                    conf_json.get("metaData")
-                    .get("modeToEnvMap", {})
-                    .get(effective_mode, {})
+                    conf_json.get("metaData").get("modeToEnvMap", {}).get(effective_mode, {})
                 )
 
                 environment["conf_env"] = new_env if new_env else old_env
@@ -375,8 +361,8 @@
                     "backfill-left",
                     "backfill-final",
                 ]:
-                    environment["conf_env"]["CHRONON_CONFIG_ADDITIONAL_ARGS"] = (
-                        " ".join(custom_json(conf_json).get("additional_args", []))
+                    environment["conf_env"]["CHRONON_CONFIG_ADDITIONAL_ARGS"] = " ".join(
+                        custom_json(conf_json).get("additional_args", [])
                     )
                 environment["cli_args"]["APP_NAME"] = APP_NAME_TEMPLATE.format(
                     mode=effective_mode,
@@ -384,18 +370,14 @@
                     context=context,
                     name=conf_json["metaData"]["name"],
                 )
-                environment["team_env"] = (
-                    teams_json[team].get(context, {}).get(effective_mode, {})
-                )
+                environment["team_env"] = teams_json[team].get(context, {}).get(effective_mode, {})
                 # fall-back to prod env even in dev mode when dev env is undefined.
                 environment["production_team_env"] = (
                     teams_json[team].get("production", {}).get(effective_mode, {})
                 )
                 # By default use production env.
                 environment["default_env"] = (
-                    teams_json.get("default", {})
-                    .get("production", {})
-                    .get(effective_mode, {})
+                    teams_json.get("default", {}).get("production", {}).get(effective_mode, {})
                 )
                 environment["cli_args"]["CHRONON_CONF_PATH"] = conf_path
             if params["app_name"]:
@@ -444,9 +426,7 @@ def split_date_range(start_date, end_date, parallelism):
     end_date = datetime.strptime(end_date, "%Y-%m-%d")
     if start_date > end_date:
         raise ValueError("Start date should be earlier than end date")
-    total_days = (
-        end_date - start_date
-    ).days + 1  # +1 to include the end_date in the range
+    total_days = (end_date - start_date).days + 1  # +1 to include the end_date in the range
 
     # Check if parallelism is greater than total_days
     if parallelism > total_days:
@@ -461,12 +441,41 @@ def split_date_range(start_date, end_date, parallelism):
             split_end = end_date
         else:
             split_end = split_start + timedelta(days=split_size - 1)
-        date_ranges.append(
-            (split_start.strftime("%Y-%m-%d"), split_end.strftime("%Y-%m-%d"))
-        )
+        date_ranges.append((split_start.strftime("%Y-%m-%d"), split_end.strftime("%Y-%m-%d")))
     return date_ranges
 
+
 def get_metadata_name_from_conf(repo_path, conf_path):
     with open(os.path.join(repo_path, conf_path), "r") as conf_file:
         data = json.load(conf_file)
-        return data.get("metaData", {}).get("name", None)
+        return data.get("metaData", {}).get("name", None)
+
+
+def handle_conf_not_found(log_error=True, callback=None):
+    def wrapper(func):
+        @functools.wraps(func)
+        def wrapped(*args, **kwargs):
+            try:
+                return func(*args, **kwargs)
+            except FileNotFoundError as e:
+                if log_error:
+                    print(style(f"File not found in {func.__name__}: {e}", fg="red"))
+                if callback:
+                    callback(*args, **kwargs)
+                return
+
+        return wrapped
+
+    return wrapper
+
+
+def print_possible_confs(conf, repo, *args, **kwargs):
+    conf_location = os.path.join(repo, conf)
+    conf_dirname = os.path.dirname(conf_location)
+    if os.path.exists(conf_dirname):
+        print(
+            f"Possible confs from {style(conf_dirname, fg='yellow')}: \n -",
+            "\n - ".join([name for name in os.listdir(conf_dirname)]),
+        )
+    else:
+        print(f"Directory does not exist: {style(conf_dirname, fg='yellow')}")
ai/chronon/repo/zipline.py CHANGED
@@ -9,24 +9,6 @@ from ai.chronon.repo.hub_runner import hub
 from ai.chronon.repo.init import main as init_main
 from ai.chronon.repo.run import main as run_main
 
-LOGO = """
-  [... ASCII-art "Zipline" banner, 15 lines, omitted here ...]
-  """
-
 
 def _set_package_version():
     try:
@@ -37,7 +19,9 @@ def _set_package_version():
     return package_version
 
 
-@click.group(help="The Zipline CLI. A tool for authoring and running Zipline pipelines in the cloud. For more information, see: https://chronon.ai/")
+@click.group(
+    help="The Zipline CLI. A tool for compiling and running Zipline pipelines. For more information, see: https://zipline.ai/docs"
+)
 @click.version_option(version=_set_package_version())
 @click.pass_context
 def zipline(ctx):
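The banner constant is dropped and the group declaration is reformatted with updated help text. For reference, a stripped-down sketch of an equivalent Click group is shown below; the add_command wiring is an assumption for illustration and is not part of this diff.

import click

@click.group(
    help="The Zipline CLI. A tool for compiling and running Zipline pipelines. For more information, see: https://zipline.ai/docs"
)
@click.version_option(version="0.3.1")  # the real CLI resolves this via _set_package_version()
@click.pass_context
def zipline(ctx):
    pass

# Assumed wiring: subcommand groups imported at the top of the module (hub, run, init)
# would typically be attached with zipline.add_command(...).

if __name__ == "__main__":
    zipline()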
ai/chronon/repo/zipline_hub.py CHANGED
@@ -1,105 +1,277 @@
+import json
 import os
+from datetime import date, datetime, timedelta, timezone
 from typing import Optional
 
 import google.auth
 import requests
 from google.auth.transport.requests import Request
+from google.cloud import iam_credentials_v1
 
 
 class ZiplineHub:
-    def __init__(self, base_url):
+    def __init__(self, base_url, sa_name=None):
         if not base_url:
             raise ValueError("Base URL for ZiplineHub cannot be empty.")
         self.base_url = base_url
-        if self.base_url.startswith("https") and self.base_url.endswith(".app"):
+        if self.base_url.startswith("https"):
             print("\n 🔐 Using Google Cloud authentication for ZiplineHub.")
 
             # First try to get ID token from environment (GitHub Actions)
-            self.id_token = os.getenv('GCP_ID_TOKEN')
+            self.id_token = os.getenv("GCP_ID_TOKEN")
             if self.id_token:
                 print(" 🔑 Using ID token from environment")
-            else:
+            elif sa_name is not None:
                 # Fallback to Google Cloud authentication
+                print(" 🔑 Generating ID token from service account credentials")
+                credentials, project_id = google.auth.default()
+                self.project_id = project_id
+                credentials.refresh(Request())
+
+                self.sa = f"{sa_name}@{project_id}.iam.gserviceaccount.com"
+            else:
                 print(" 🔑 Generating ID token from default credentials")
                 credentials, project_id = google.auth.default()
                 credentials.refresh(Request())
+                self.sa = None
                 self.id_token = credentials.id_token
 
-    def call_diff_api(self, names_to_hashes: dict[str, str]) -> Optional[list[str]]:
-        url = f"{self.base_url}/upload/v1/diff"
+    def _generate_jwt_payload(self, service_account_email: str, resource_url: str) -> str:
+        """Generates JWT payload for service account.
 
-        diff_request = {
-            'namesToHashes': names_to_hashes
+        Creates a properly formatted JWT payload with standard claims (iss, sub, aud,
+        iat, exp) needed for IAP authentication.
+
+        Args:
+            service_account_email (str): Specifies service account JWT is created for.
+            resource_url (str): Specifies scope of the JWT, the URL that the JWT will
+                be allowed to access.
+
+        Returns:
+            str: JSON string containing the JWT payload with properly formatted claims.
+        """
+        # Create current time and expiration time (1 hour later) in UTC
+        iat = datetime.now(tz=timezone.utc)
+        exp = iat + timedelta(seconds=3600)
+
+        # Convert datetime objects to numeric timestamps (seconds since epoch)
+        # as required by JWT standard (RFC 7519)
+        payload = {
+            "iss": service_account_email,
+            "sub": service_account_email,
+            "aud": resource_url,
+            "iat": int(iat.timestamp()),
+            "exp": int(exp.timestamp()),
         }
-        headers = {'Content-Type': 'application/json'}
-        if hasattr(self, 'id_token'):
-            headers['Authorization'] = f'Bearer {self.id_token}'
+
+        return json.dumps(payload)
+
+    def _sign_jwt(self, target_sa: str, resource_url: str) -> str:
+        """Signs JWT payload using ADC and IAM credentials API.
+
+        Uses Google Cloud's IAM Credentials API to sign a JWT. This requires the
+        caller to have iap.webServiceVersions.accessViaIap permission on the target
+        service account.
+
+        Args:
+            target_sa (str): Service Account JWT is being created for.
+                iap.webServiceVersions.accessViaIap permission is required.
+            resource_url (str): Audience of the JWT, and scope of the JWT token.
+                This is the url of the IAP protected application.
+
+        Returns:
+            str: A signed JWT that can be used to access IAP protected apps.
+                Use in Authorization header as: 'Bearer <signed_jwt>'
+        """
+        # Get default credentials from environment or application credentials
+        source_credentials, project_id = google.auth.default()
+
+        # Initialize IAM credentials client with source credentials
+        iam_client = iam_credentials_v1.IAMCredentialsClient(credentials=source_credentials)
+
+        # Generate the service account resource name
+        # Use '-' as placeholder as per API requirements
+        name = iam_client.service_account_path("-", target_sa)
+
+        # Create and sign the JWT payload
+        payload = self._generate_jwt_payload(target_sa, resource_url)
+
+        request = iam_credentials_v1.SignJwtRequest(
+            name=name,
+            payload=payload,
+        )
+        # Sign the JWT using the IAM credentials API
+        response = iam_client.sign_jwt(request=request)
+
+        return response.signed_jwt
+
+    def call_diff_api(self, names_to_hashes: dict[str, str]) -> Optional[list[str]]:
+        url = f"{self.base_url}/upload/v2/diff"
+
+        diff_request = {"namesToHashes": names_to_hashes}
+        headers = {"Content-Type": "application/json"}
+        if self.base_url.startswith("https") and hasattr(self, "sa") and self.sa is not None:
+            headers["Authorization"] = f"Bearer {self._sign_jwt(self.sa, url)}"
+        elif self.base_url.startswith("https"):
+            headers["Authorization"] = f"Bearer {self.id_token}"
         try:
             response = requests.post(url, json=diff_request, headers=headers)
             response.raise_for_status()
             diff_response = response.json()
-            return diff_response['diff']
+            return diff_response["diff"]
         except requests.RequestException as e:
-            print(f" Error calling diff API: {e}")
+            if e.response is not None and e.response.status_code == 401 and self.sa is None:
+                print(
+                    " ❌ Error calling diff API. Unauthorized and no service account provided. Make sure the environment has default credentials set up or provide a service account name as SA_NAME in teams.py."
+                )
+            elif e.response is not None and e.response.status_code == 401 and self.sa is not None:
+                print(
+                    f" ❌ Error calling diff API. Unauthorized with provided service account: {self.sa}. Make sure the service account has the 'iap.webServiceVersions.accessViaIap' permission."
+                )
+            else:
+                print(f" ❌ Error calling diff API: {e}")
             raise e
 
     def call_upload_api(self, diff_confs, branch: str):
-        url = f"{self.base_url}/upload/v1/confs"
+        url = f"{self.base_url}/upload/v2/confs"
 
         upload_request = {
-            'diffConfs': diff_confs,
-            'branch': branch,
+            "diffConfs": diff_confs,
+            "branch": branch,
         }
-        headers = {'Content-Type': 'application/json'}
-        if hasattr(self, 'id_token'):
-            headers['Authorization'] = f'Bearer {self.id_token}'
+        headers = {"Content-Type": "application/json"}
+        if self.base_url.startswith("https") and hasattr(self, "sa") and self.sa is not None:
+            headers["Authorization"] = f"Bearer {self._sign_jwt(self.sa, url)}"
+        elif self.base_url.startswith("https"):
+            headers["Authorization"] = f"Bearer {self.id_token}"
 
         try:
             response = requests.post(url, json=upload_request, headers=headers)
             response.raise_for_status()
             return response.json()
         except requests.RequestException as e:
-            print(f" Error calling upload API: {e}")
+            if e.response is not None and e.response.status_code == 401 and self.sa is None:
+                print(
+                    " ❌ Error calling upload API. Unauthorized and no service account provided. Make sure the environment has default credentials set up or provide a service account name as SA_NAME in teams.py."
+                )
+            elif e.response is not None and e.response.status_code == 401 and self.sa is not None:
+                print(
+                    f" ❌ Error calling upload API. Unauthorized with provided service account: {self.sa}. Make sure the service account has the 'iap.webServiceVersions.accessViaIap' permission."
+                )
+            else:
+                print(f" ❌ Error calling upload API: {e}")
+            raise e
+
+    def call_schedule_api(self, modes, branch, conf_name, conf_hash):
+        url = f"{self.base_url}/schedule/v2/schedules"
+
+        schedule_request = {
+            "modeSchedules": modes,
+            "branch": branch,
+            "confName": conf_name,
+            "confHash": conf_hash,
+        }
+
+        headers = {"Content-Type": "application/json"}
+        if self.base_url.startswith("https") and hasattr(self, "sa") and self.sa is not None:
+            headers["Authorization"] = f"Bearer {self._sign_jwt(self.sa, url)}"
+        elif self.base_url.startswith("https"):
+            headers["Authorization"] = f"Bearer {self.id_token}"
+
+        try:
+            response = requests.post(url, json=schedule_request, headers=headers)
+            response.raise_for_status()
+            return response.json()
+        except requests.RequestException as e:
+            if e.response is not None and e.response.status_code == 401 and self.sa is None:
+                print(
+                    " ❌ Error deploying schedule. Unauthorized and no service account provided. Make sure the environment has default credentials set up or provide a service account name as SA_NAME in teams.py."
+                )
+            elif e.response is not None and e.response.status_code == 401 and self.sa is not None:
+                print(
+                    f" ❌ Error deploying schedule. Unauthorized with provided service account: {self.sa}. Make sure the service account has the 'iap.webServiceVersions.accessViaIap' permission."
+                )
+            else:
+                print(f" ❌ Error deploying schedule: {e}")
             raise e
 
     def call_sync_api(self, branch: str, names_to_hashes: dict[str, str]) -> Optional[list[str]]:
-        url = f"{self.base_url}/upload/v1/sync"
+        url = f"{self.base_url}/upload/v2/sync"
 
         sync_request = {
             "namesToHashes": names_to_hashes,
             "branch": branch,
         }
-        headers = {'Content-Type': 'application/json'}
-        if hasattr(self, 'id_token'):
-            headers['Authorization'] = f'Bearer {self.id_token}'
+        headers = {"Content-Type": "application/json"}
+        if self.base_url.startswith("https") and hasattr(self, "sa") and self.sa is not None:
+            headers["Authorization"] = f"Bearer {self._sign_jwt(self.sa, url)}"
+        elif self.base_url.startswith("https"):
+            headers["Authorization"] = f"Bearer {self.id_token}"
+
         try:
             response = requests.post(url, json=sync_request, headers=headers)
             response.raise_for_status()
             return response.json()
         except requests.RequestException as e:
-            print(f" Error calling diff API: {e}")
+            if e.response is not None and e.response.status_code == 401 and self.sa is None:
+                print(
+                    " ❌ Error calling sync API. Unauthorized and no service account provided. Make sure the environment has default credentials set up or provide a service account name as SA_NAME in teams.py."
+                )
+            elif e.response is not None and e.response.status_code == 401 and self.sa is not None:
+                print(
+                    f" ❌ Error calling sync API. Unauthorized with provided service account: {self.sa}. Make sure the service account has the 'iap.webServiceVersions.accessViaIap' permission."
+                )
+            else:
+                print(f" ❌ Error calling sync API: {e}")
             raise e
 
-    def call_workflow_start_api(self, conf_name, mode, branch, user, start, end, conf_hash):
-        url = f"{self.base_url}/workflow/start"
-
+    def call_workflow_start_api(
+        self,
+        conf_name,
+        mode,
+        branch,
+        user,
+        conf_hash,
+        start=None,
+        end=None,
+        skip_long_running=False,
+    ):
+        url = f"{self.base_url}/workflow/v2/start"
+        end_dt = end.strftime("%Y-%m-%d") if end else date.today().strftime("%Y-%m-%d")
+        start_dt = (
+            start.strftime("%Y-%m-%d")
+            if start
+            else (date.today() - timedelta(days=14)).strftime("%Y-%m-%d")
+        )
         workflow_request = {
-            'confName': conf_name,
-            'confHash': conf_hash,
-            'mode': mode,
-            'branch': branch,
-            'user': user,
-            'start': start,
-            'end': end,
+            "confName": conf_name,
+            "confHash": conf_hash,
+            "mode": mode,
+            "branch": branch,
+            "user": user,
+            "start": start_dt,
+            "end": end_dt,
+            "skipLongRunningNodes": skip_long_running,
         }
-        headers = {'Content-Type': 'application/json'}
-        if hasattr(self, 'id_token'):
-            headers['Authorization'] = f'Bearer {self.id_token}'
+        headers = {"Content-Type": "application/json"}
+        if self.base_url.startswith("https") and hasattr(self, "sa") and self.sa is not None:
+            headers["Authorization"] = f"Bearer {self._sign_jwt(self.sa, url)}"
+        elif self.base_url.startswith("https"):
+            headers["Authorization"] = f"Bearer {self.id_token}"
 
         try:
             response = requests.post(url, json=workflow_request, headers=headers)
             response.raise_for_status()
             return response.json()
         except requests.RequestException as e:
-            print(f" Error calling workflow start API: {e}")
+            if e.response is not None and e.response.status_code == 401 and self.sa is None:
+                print(
+                    " ❌ Error calling workflow start API. Unauthorized and no service account provided. Make sure the environment has default credentials set up or provide a service account name as SA_NAME in teams.py."
+                )
+            elif e.response is not None and e.response.status_code == 401 and self.sa is not None:
+                print(
+                    f" ❌ Error calling workflow start API. Unauthorized with provided service account: {self.sa}. Make sure the service account has the 'iap.webServiceVersions.accessViaIap' permission."
+                )
+            else:
+                print(f" ❌ Error calling workflow start API: {e}")
             raise e
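Putting the ZiplineHub changes together: the endpoints move from v1 to v2, 401 responses now produce targeted hints, and when an https base URL plus sa_name are given, requests are authorized with a JWT signed through the IAM Credentials API rather than a plain ID token. Below is a minimal caller-side sketch, assuming Application Default Credentials are available; the URL, service-account name, conf name, and hash are placeholders.

from datetime import date, timedelta

from ai.chronon.repo.zipline_hub import ZiplineHub

# Placeholder base URL and service-account name; with both set, each request is
# signed via _sign_jwt instead of relying on the default-credentials ID token.
hub = ZiplineHub("https://hub.example-customer.zipline.ai", sa_name="zipline-deployer")

# Ask the hub which confs differ from what it already knows about.
changed = hub.call_diff_api({"test.data.v1": "abc123"})

# Start a workflow; start/end default to the trailing 14 days when omitted.
hub.call_workflow_start_api(
    conf_name="test.data.v1",
    mode="backfill",
    branch="main",
    user="jane",
    conf_hash="abc123",
    end=date.today() - timedelta(days=1),
)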