primitive 0.2.9__py3-none-any.whl → 0.2.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
primitive/__about__.py CHANGED
@@ -1,4 +1,4 @@
1
1
  # SPDX-FileCopyrightText: 2024-present Dylan Stein <dylan@primitive.tech>
2
2
  #
3
3
  # SPDX-License-Identifier: MIT
4
- __version__ = "0.2.9"
4
+ __version__ = "0.2.11"
@@ -6,25 +6,29 @@ from loguru import logger
6
6
  from primitive.__about__ import __version__
7
7
  from primitive.utils.actions import BaseAction
8
8
 
9
- from ..utils.exceptions import P_CLI_100
10
9
  from .runner import Runner
11
10
  from .uploader import Uploader
11
+ from ..db import sqlite
12
+ from ..db.models import JobRun
12
13
 
13
14
 
14
15
  class Agent(BaseAction):
15
16
  def execute(
16
17
  self,
17
18
  ):
18
- logger.enable("primitive")
19
19
  logger.remove()
20
20
  logger.add(
21
21
  sink=sys.stderr,
22
- # catch=True,
22
+ format="<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | <level>{level: <8}</level> | <level>{message}</level>",
23
23
  backtrace=True,
24
24
  diagnose=True,
25
+ level="DEBUG" if self.primitive.DEBUG else "INFO",
25
26
  )
26
- logger.info(" [*] primitive")
27
- logger.info(f" [*] Version: {__version__}")
27
+ logger.info("[*] primitive agent")
28
+ logger.info(f"[*] Version: {__version__}")
29
+
30
+ # Initialize the database
31
+ sqlite.init()
28
32
 
29
33
  # Create uploader
30
34
  uploader = Uploader(primitive=self.primitive)
@@ -49,9 +53,6 @@ class Agent(BaseAction):
49
53
  logger.debug("Scanning for files to upload...")
50
54
  uploader.scan()
51
55
 
52
- logger.debug("Syncing children...")
53
- self.primitive.hardware._sync_children()
54
-
55
56
  hardware = self.primitive.hardware.get_own_hardware_details()
56
57
 
57
58
  if hardware["activeReservation"]:
@@ -108,6 +109,7 @@ class Agent(BaseAction):
108
109
  ]
109
110
 
110
111
  if not pending_job_runs:
112
+ self.primitive.hardware.check_in_http(is_online=True)
111
113
  sleep_amount = 5
112
114
  logger.debug(
113
115
  f"Waiting for Job Runs... [sleeping {sleep_amount} seconds]"
@@ -120,6 +122,11 @@ class Agent(BaseAction):
120
122
  logger.debug(f"Job Run ID: {job_run['id']}")
121
123
  logger.debug(f"Job Name: {job_run['job']['name']}")
122
124
 
125
+ JobRun.objects.create(
126
+ job_run_id=job_run["id"],
127
+ pid=None,
128
+ )
129
+
123
130
  runner = Runner(
124
131
  primitive=self.primitive,
125
132
  job_run=job_run,
@@ -137,6 +144,7 @@ class Agent(BaseAction):
137
144
  status="request_completed",
138
145
  conclusion="failure",
139
146
  )
147
+ JobRun.objects.filter_by(job_run_id=job_run["id"]).delete()
140
148
  continue
141
149
 
142
150
  try:
@@ -152,17 +160,11 @@ class Agent(BaseAction):
152
160
  runner.cleanup()
153
161
 
154
162
  # NOTE: also run scan here to force upload of artifacts
155
- # This should probably eventuall be another daemon?
163
+ # This should probably eventually be another daemon?
156
164
  uploader.scan()
157
165
 
166
+ JobRun.objects.filter_by(job_run_id=job_run["id"]).delete()
167
+
158
168
  sleep(5)
159
169
  except KeyboardInterrupt:
160
- logger.info(" [*] Stopping primitive...")
161
- try:
162
- self.primitive.hardware.check_in_http(
163
- is_available=False, is_online=False, stopping_agent=True
164
- )
165
- except P_CLI_100 as exception:
166
- logger.error(" [*] Error stopping primitive.")
167
- logger.error(str(exception))
168
- sys.exit()
170
+ logger.info("[*] Stopping primitive agent...")
primitive/agent/runner.py CHANGED
@@ -2,12 +2,12 @@ import asyncio
2
2
  import os
3
3
  import re
4
4
  import shutil
5
- import time
6
5
  import typing
7
6
  from abc import abstractmethod
8
7
  from enum import Enum, IntEnum
9
8
  from pathlib import Path, PurePath
10
9
  from typing import Dict, List, TypedDict
10
+ from ..db.models import JobRun
11
11
 
12
12
  import yaml
13
13
  from loguru import logger
@@ -80,8 +80,6 @@ class Runner:
80
80
  self.modified_env = {}
81
81
  self.file_logger = None
82
82
 
83
- logger.enable("primitive")
84
-
85
83
  # If max_log_size set to <= 0, disable file logging
86
84
  if max_log_size > 0:
87
85
  log_name = f"{self.job['slug']}_{self.job_run['jobRunNumber']}_{{time}}.primitive.log"
@@ -158,22 +156,40 @@ class Runner:
158
156
  self.modified_env = {**self.initial_env}
159
157
 
160
158
  task_failed = False
159
+ cancelled = False
161
160
  conclusion = "success"
162
161
  for task in self.config["executes"]:
162
+ status = self.primitive.jobs.get_job_status(self.job_run["id"])
163
+ status_value = status.data["jobRun"]["status"]
164
+ conclusion_value = status.data["jobRun"]["conclusion"]
165
+
166
+ if status_value == "completed" and conclusion_value == "cancelled":
167
+ cancelled = True
168
+ break
169
+
163
170
  with logger.contextualize(label=task["label"]):
164
171
  with asyncio.Runner() as async_runner:
165
172
  if task_failed := async_runner.run(self.run_task(task)):
166
173
  break
167
174
 
175
+ number_of_files_produced = self.get_number_of_files_produced()
176
+ logger.info(
177
+ f"Produced {number_of_files_produced} files for {self.job['slug']} job"
178
+ )
179
+
180
+ if cancelled:
181
+ logger.warning("Job cancelled by user")
182
+ self.primitive.jobs.job_run_update(
183
+ self.job_run["id"],
184
+ number_of_files_produced=number_of_files_produced,
185
+ )
186
+ return
187
+
168
188
  if task_failed:
169
189
  conclusion = "failure"
170
190
  else:
171
191
  logger.success(f"Completed {self.job['slug']} job")
172
192
 
173
- number_of_files_produced = self.get_number_of_files_produced()
174
- logger.info(
175
- f"Produced {number_of_files_produced} files for {self.job['slug']} job"
176
- )
177
193
  self.primitive.jobs.job_run_update(
178
194
  self.job_run["id"],
179
195
  status="request_completed",
@@ -187,11 +203,19 @@ class Runner:
187
203
 
188
204
  # Logs can be produced even if no artifact stores are created for the job run.
189
205
  job_run_logs_cache = get_logs_cache(self.job_run["id"])
190
- log_files = [
191
- file
192
- for _, _, current_path_files in job_run_logs_cache.walk()
193
- for file in current_path_files
194
- ]
206
+ has_walk = getattr(job_run_logs_cache, "walk", None)
207
+ if has_walk:
208
+ log_files = [
209
+ file
210
+ for _, _, current_path_files in job_run_logs_cache.walk()
211
+ for file in current_path_files
212
+ ]
213
+ else:
214
+ log_files = [
215
+ file
216
+ for _, _, current_path_files in os.walk(job_run_logs_cache)
217
+ for file in current_path_files
218
+ ]
195
219
 
196
220
  number_of_files_produced += len(log_files)
197
221
 
@@ -199,11 +223,19 @@ class Runner:
199
223
  return number_of_files_produced
200
224
 
201
225
  job_run_artifacts_cache = get_artifacts_cache(self.job_run["id"])
202
- artifact_files = [
203
- file
204
- for _, _, current_path_files in job_run_artifacts_cache.walk()
205
- for file in current_path_files
206
- ]
226
+ has_walk = getattr(job_run_artifacts_cache, "walk", None)
227
+ if has_walk:
228
+ artifact_files = [
229
+ file
230
+ for _, _, current_path_files in job_run_artifacts_cache.walk()
231
+ for file in current_path_files
232
+ ]
233
+ else:
234
+ artifact_files = [
235
+ file
236
+ for _, _, current_path_files in os.walk(job_run_artifacts_cache)
237
+ for file in current_path_files
238
+ ]
207
239
 
208
240
  number_of_files_produced += len(artifact_files)
209
241
 
@@ -233,24 +265,24 @@ class Runner:
233
265
  stderr=asyncio.subprocess.PIPE,
234
266
  )
235
267
 
236
- loop = asyncio.get_running_loop()
237
- monitor_task = loop.run_in_executor(None, self.monitor_cmd, process)
268
+ JobRun.objects.filter_by(job_run_id=self.job_run["id"]).update(
269
+ {"pid": process.pid}
270
+ )
238
271
 
239
- stdout_failed, stderr_failed, cancelled = await asyncio.gather(
272
+ stdout_failed, stderr_failed = await asyncio.gather(
240
273
  self.log_cmd(
241
274
  process=process, stream=process.stdout, tags=task.get("tags", {})
242
275
  ),
243
276
  self.log_cmd(
244
277
  process=process, stream=process.stderr, tags=task.get("tags", {})
245
278
  ),
246
- monitor_task,
247
279
  )
248
280
 
249
281
  returncode = await process.wait()
250
282
 
251
- if cancelled:
252
- logger.warning("Job cancelled by user")
253
- return True
283
+ JobRun.objects.filter_by(job_run_id=self.job_run["id"]).update(
284
+ {"pid": None}
285
+ )
254
286
 
255
287
  if returncode > 0:
256
288
  logger.error(
@@ -339,25 +371,6 @@ class Runner:
339
371
 
340
372
  return [line for line in lines if len(line) > 0]
341
373
 
342
- def monitor_cmd(self, process) -> bool:
343
- while process.returncode is None:
344
- status = self.primitive.jobs.get_job_status(self.job_run["id"])
345
-
346
- status_value = status.data["jobRun"]["status"]
347
- conclusion_value = status.data["jobRun"]["conclusion"]
348
-
349
- if status_value == "completed" and conclusion_value == "cancelled":
350
- try:
351
- process.terminate()
352
- except ProcessLookupError:
353
- pass
354
-
355
- return True
356
-
357
- time.sleep(10)
358
-
359
- return False
360
-
361
374
  def cleanup(self) -> None:
362
375
  logger.remove(self.file_logger)
363
376
 
primitive/cli.py CHANGED
@@ -16,6 +16,7 @@ from .jobs.commands import cli as jobs_commands
16
16
  from .organizations.commands import cli as organizations_commands
17
17
  from .projects.commands import cli as projects_commands
18
18
  from .reservations.commands import cli as reservations_commands
19
+ from .monitor.commands import cli as monitor_commands
19
20
 
20
21
 
21
22
  @click.group()
@@ -71,6 +72,7 @@ cli.add_command(organizations_commands, "organizations")
71
72
  cli.add_command(projects_commands, "projects")
72
73
  cli.add_command(reservations_commands, "reservations")
73
74
  cli.add_command(exec_commands, "exec")
75
+ cli.add_command(monitor_commands, "monitor")
74
76
 
75
77
  if __name__ == "__main__":
76
78
  cli(obj={})
primitive/client.py CHANGED
@@ -1,7 +1,8 @@
1
- import sys
2
-
3
1
  from gql import Client
4
2
  from loguru import logger
3
+ from rich.logging import RichHandler
4
+ from rich.traceback import install
5
+ from typing import Optional
5
6
 
6
7
  from .agent.actions import Agent
7
8
  from .auth.actions import Auth
@@ -15,10 +16,9 @@ from .organizations.actions import Organizations
15
16
  from .projects.actions import Projects
16
17
  from .provisioning.actions import Provisioning
17
18
  from .reservations.actions import Reservations
19
+ from .monitor.actions import Monitor
18
20
  from .utils.config import read_config_file
19
21
 
20
- logger.disable("primitive")
21
-
22
22
 
23
23
  class Primitive:
24
24
  def __init__(
@@ -26,24 +26,48 @@ class Primitive:
26
26
  host: str = "api.primitive.tech",
27
27
  DEBUG: bool = False,
28
28
  JSON: bool = False,
29
- token: str = None,
30
- transport: str = None,
29
+ token: Optional[str] = None,
30
+ transport: Optional[str] = None,
31
31
  ) -> None:
32
32
  self.host: str = host
33
- self.session: Client = None
33
+ self.session: Optional[Client] = None
34
34
  self.DEBUG: bool = DEBUG
35
35
  self.JSON: bool = JSON
36
36
 
37
+ # Enable tracebacks with local variables
37
38
  if self.DEBUG:
38
- logger.enable("primitive")
39
- logger.remove()
40
- logger.add(
41
- sink=sys.stderr,
42
- serialize=self.JSON,
43
- catch=True,
44
- backtrace=True,
45
- diagnose=True,
46
- )
39
+ install(show_locals=True)
40
+
41
+ # Configure rich logging handler
42
+ rich_handler = RichHandler(
43
+ rich_tracebacks=self.DEBUG, # Pretty tracebacks
44
+ markup=True, # Allow Rich markup tags
45
+ show_time=self.DEBUG, # Show timestamps
46
+ show_level=self.DEBUG, # Show log levels
47
+ show_path=self.DEBUG, # Hide source path (optional)
48
+ )
49
+
50
+ def formatter(record) -> str:
51
+ match record["level"].name:
52
+ case "ERROR":
53
+ return "[bold red]Error>[/bold red] {name}:{function}:{line} - {message}"
54
+ case "CRITICAL":
55
+ return "[italic bold red]Critical>[/italic bold red] {name}:{function}:{line} - {message}"
56
+ case "WARNING":
57
+ return "[bold yellow]Warning>[/bold yellow] {message}"
58
+ case _:
59
+ return "[#666666]>[/#666666] {message}"
60
+
61
+ logger.remove()
62
+ logger.add(
63
+ sink=rich_handler,
64
+ format="{message}" if self.DEBUG else formatter,
65
+ level="DEBUG" if self.DEBUG else "INFO",
66
+ backtrace=self.DEBUG,
67
+ )
68
+
69
+ # Nothing will print here if DEBUG is false
70
+ logger.debug("Debug mode enabled")
47
71
 
48
72
  # Generate full or partial host config
49
73
  if not token and not transport:
@@ -67,6 +91,7 @@ class Primitive:
67
91
  self.daemons: Daemons = Daemons(self)
68
92
  self.exec: Exec = Exec(self)
69
93
  self.provisioning: Provisioning = Provisioning(self)
94
+ self.monitor: Monitor = Monitor(self)
70
95
 
71
96
  def get_host_config(self):
72
97
  self.full_config = read_config_file()
@@ -1,24 +1,13 @@
1
1
  import platform
2
2
  import typing
3
+ from typing import Dict, Optional, List
3
4
 
4
5
  if typing.TYPE_CHECKING:
5
6
  from ..client import Primitive
6
7
 
7
- from .launch_agents import (
8
- full_launch_agent_install,
9
- full_launch_agent_uninstall,
10
- start_launch_agent,
11
- stop_launch_agent,
12
- view_launch_agent_logs,
13
- )
14
-
15
- from .launch_service import (
16
- full_service_install,
17
- full_service_uninstall,
18
- start_service,
19
- stop_service,
20
- view_service_logs,
21
- )
8
+ from .launch_agents import LaunchAgent
9
+ from .launch_service import LaunchService
10
+ from ..utils.daemons import Daemon
22
11
 
23
12
 
24
13
  class Daemons:
@@ -26,50 +15,47 @@ class Daemons:
26
15
  self.primitive: Primitive = primitive
27
16
  self.os_family = platform.system()
28
17
 
29
- def install(self):
30
- result = True
31
- if self.os_family == "Darwin":
32
- full_launch_agent_install()
33
- elif self.os_family == "Linux":
34
- full_service_install()
35
- elif self.os_family == "Windows":
36
- print("Not Implemented")
37
- return result
38
-
39
- def uninstall(self):
40
- result = True
41
- if self.os_family == "Darwin":
42
- full_launch_agent_uninstall()
43
- elif self.os_family == "Linux":
44
- full_service_uninstall()
45
- elif self.os_family == "Windows":
46
- print("Not Implemented")
47
- return result
48
-
49
- def stop(self) -> bool:
50
- result = True
51
- if self.os_family == "Darwin":
52
- result = stop_launch_agent()
53
- elif self.os_family == "Linux":
54
- stop_service()
55
- elif self.os_family == "Windows":
56
- print("Not Implemented")
57
- return result
58
-
59
- def start(self) -> bool:
60
- result = True
61
- if self.os_family == "Darwin":
62
- result = start_launch_agent()
63
- elif self.os_family == "Linux":
64
- start_service()
65
- elif self.os_family == "Windows":
66
- print("Not Implemented")
67
- return result
68
-
69
- def logs(self):
70
- if self.os_family == "Darwin":
71
- view_launch_agent_logs()
72
- elif self.os_family == "Linux":
73
- view_service_logs()
74
- elif self.os_family == "Windows":
75
- print("Not Implemented")
18
+ match self.os_family:
19
+ case "Darwin":
20
+ self.daemons: Dict[str, Daemon] = {
21
+ "agent": LaunchAgent("tech.primitive.agent"),
22
+ "monitor": LaunchAgent("tech.primitive.monitor"),
23
+ }
24
+ case "Linux":
25
+ self.daemons: Dict[str, Daemon] = {
26
+ "agent": LaunchService("tech.primitive.agent"),
27
+ "monitor": LaunchService("tech.primitive.monitor"),
28
+ }
29
+ case _:
30
+ raise NotImplementedError(f"{self.os_family} is not supported.")
31
+
32
+ def install(self, name: Optional[str]) -> bool:
33
+ if name:
34
+ return self.daemons[name].install()
35
+ else:
36
+ return all([daemon.install() for daemon in self.daemons.values()])
37
+
38
+ def uninstall(self, name: Optional[str]) -> bool:
39
+ if name:
40
+ return self.daemons[name].uninstall()
41
+ else:
42
+ return all([daemon.uninstall() for daemon in self.daemons.values()])
43
+
44
+ def stop(self, name: Optional[str]) -> bool:
45
+ if name:
46
+ return self.daemons[name].stop()
47
+ else:
48
+ return all([daemon.stop() for daemon in self.daemons.values()])
49
+
50
+ def start(self, name: Optional[str]) -> bool:
51
+ if name:
52
+ return self.daemons[name].start()
53
+ else:
54
+ return all([daemon.start() for daemon in self.daemons.values()])
55
+
56
+ def list(self) -> List[Daemon]:
57
+ """List all daemons"""
58
+ return list(self.daemons.values())
59
+
60
+ def logs(self, name: str) -> None:
61
+ self.daemons[name].view_logs()
@@ -1,7 +1,10 @@
1
1
  import click
2
2
 
3
- from ..utils.printer import print_result
4
3
  import typing
4
+ from typing import Optional
5
+ from .ui import render_daemon_list
6
+
7
+ from loguru import logger
5
8
 
6
9
  if typing.TYPE_CHECKING:
7
10
  from ..client import Primitive
@@ -16,50 +19,93 @@ def cli(context):
16
19
 
17
20
  @cli.command("install")
18
21
  @click.pass_context
19
- def install_daemon_command(context):
22
+ @click.argument(
23
+ "name",
24
+ type=str,
25
+ required=False,
26
+ )
27
+ def install_daemon_command(context, name: Optional[str]):
20
28
  """Install the full primitive daemon"""
21
29
  primitive: Primitive = context.obj.get("PRIMITIVE")
22
- result = primitive.daemons.install()
23
- print_result(message=result, context=context)
30
+ installed = primitive.daemons.install(name=name)
31
+
32
+ if installed:
33
+ logger.info(":white_check_mark: daemon(s) installed successfully!")
34
+ else:
35
+ logger.error("Unable to install daemon(s).")
24
36
 
25
37
 
26
38
  @cli.command("uninstall")
27
39
  @click.pass_context
28
- def uninstall_daemon_command(context):
40
+ @click.argument(
41
+ "name",
42
+ type=str,
43
+ required=False,
44
+ )
45
+ def uninstall_daemon_command(context, name: Optional[str]):
29
46
  """Uninstall the full primitive Daemon"""
30
47
  primitive: Primitive = context.obj.get("PRIMITIVE")
31
- result = primitive.daemons.uninstall()
32
- print_result(message=result, context=context)
48
+ uninstalled = primitive.daemons.uninstall(name=name)
49
+
50
+ if uninstalled:
51
+ logger.info(":white_check_mark: daemon(s) uninstalled successfully!")
52
+ else:
53
+ logger.error("Unable to uninstall daemon(s).")
33
54
 
34
55
 
35
56
  @cli.command("stop")
36
57
  @click.pass_context
37
- def stop_daemon_command(context):
58
+ @click.argument(
59
+ "name",
60
+ type=str,
61
+ required=False,
62
+ )
63
+ def stop_daemon_command(context, name: Optional[str]):
38
64
  """Stop primitive Daemon"""
39
65
  primitive: Primitive = context.obj.get("PRIMITIVE")
40
- result = primitive.daemons.stop()
41
- message = "stopping primitive daemon"
42
- if context.obj["JSON"]:
43
- message = result
44
- print_result(message=message, context=context)
66
+ stopped = primitive.daemons.stop(name=name)
67
+
68
+ if stopped:
69
+ logger.info(":white_check_mark: daemon(s) stopped successfully!")
70
+ else:
71
+ logger.error("Unable to stop daemon(s).")
45
72
 
46
73
 
47
74
  @cli.command("start")
48
75
  @click.pass_context
49
- def start_daemon_command(context):
76
+ @click.argument(
77
+ "name",
78
+ type=str,
79
+ required=False,
80
+ )
81
+ def start_daemon_command(context, name: Optional[str]):
50
82
  """Start primitive Daemon"""
51
83
  primitive: Primitive = context.obj.get("PRIMITIVE")
52
- result = primitive.daemons.start()
53
- message = "starting primitive daemon"
54
- if context.obj["JSON"]:
55
- message = result
56
- print_result(message=message, context=context)
84
+ started = primitive.daemons.start(name=name)
85
+
86
+ if started:
87
+ logger.info(":white_check_mark: daemon(s) started successfully!")
88
+ else:
89
+ logger.error("Unable to start daemon(s).")
57
90
 
58
91
 
59
92
  @cli.command("logs")
60
93
  @click.pass_context
61
- def log_daemon_command(context):
94
+ @click.argument(
95
+ "name",
96
+ type=str,
97
+ required=True,
98
+ )
99
+ def log_daemon_command(context, name: str):
62
100
  """Logs from primitive Daemon"""
63
101
  primitive: Primitive = context.obj.get("PRIMITIVE")
64
- result = primitive.daemons.logs()
65
- print_result(message=result, context=context)
102
+ primitive.daemons.logs(name=name)
103
+
104
+
105
+ @cli.command("list")
106
+ @click.pass_context
107
+ def list_daemon_command(context):
108
+ """List all daemons"""
109
+ primitive: Primitive = context.obj.get("PRIMITIVE")
110
+ daemon_list = primitive.daemons.list()
111
+ render_daemon_list(daemons=daemon_list)