lockss-pybasic 0.1.0.dev11__tar.gz → 0.1.0.dev13__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: lockss-pybasic
3
- Version: 0.1.0.dev11
3
+ Version: 0.1.0.dev13
4
4
  Summary: Basic Python utilities
5
5
  License: BSD-3-Clause
6
6
  Author: Thib Guicherd-Callin
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "lockss-pybasic"
3
- version = "0.1.0-dev11"
3
+ version = "0.1.0-dev13"
4
4
  description = "Basic Python utilities"
5
5
  authors = [
6
6
  { name = "Thib Guicherd-Callin", email = "thib@cs.stanford.edu" }
@@ -36,4 +36,4 @@ ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
36
36
  POSSIBILITY OF SUCH DAMAGE.
37
37
  '''.strip()
38
38
 
39
- __version__ = '0.1.0-dev11'
39
+ __version__ = '0.1.0-dev13'
@@ -34,73 +34,43 @@ Command line utilities.
34
34
 
35
35
  from abc import ABC, abstractmethod
36
36
  import sys
37
- from typing import Any, Dict, Generic, List, Optional, TypeVar
37
+ from typing import Any, Dict, Generic, Optional, TypeVar
38
38
 
39
39
  from pydantic.v1 import BaseModel, Field
40
40
  from pydantic_argparse import ArgumentParser
41
41
 
42
42
 
43
- class Printable(ABC):
44
-
45
- def print(self, file=sys.stdout):
46
- print(self.get_display(), file=file)
43
+ class ActionCommand(ABC, BaseModel):
47
44
 
48
45
  @abstractmethod
49
- def get_display(self):
46
+ def action(self):
50
47
  pass
51
48
 
52
49
 
53
- class CopyrightCommand:
54
-
55
- @staticmethod
56
- def make(copyright):
57
- class CopyrightModel(Printable, BaseModel):
58
- def get_display(self):
59
- return copyright
60
- return CopyrightModel
61
-
62
- @staticmethod
63
- def field():
64
- return Field(description='print the copyright and exit')
65
-
66
-
67
- class LicenseCommand:
50
+ class StringCommand(ActionCommand):
68
51
 
69
52
  @staticmethod
70
- def make(license):
71
- class LicenseModel(Printable, BaseModel):
72
- def get_display(self):
73
- return license
74
- return LicenseModel
75
-
76
- @staticmethod
77
- def field():
78
- return Field(description='print the software license and exit')
79
-
53
+ def type(display_str: str):
54
+ class _StringCommand(StringCommand):
55
+ def action(self, file=sys.stdout):
56
+ print(display_str, file=file)
57
+ return _StringCommand
80
58
 
81
- class VersionCommand:
82
59
 
83
- @staticmethod
84
- def make(version):
85
- class VersionModel(Printable, BaseModel):
86
- def get_display(self):
87
- return version
88
- return VersionModel
89
-
90
- @staticmethod
91
- def field():
92
- return Field(description='print the version number and exit')
60
+ COPYRIGHT_DESCRIPTION = 'print the copyright and exit'
61
+ LICENSE_DESCRIPTION = 'print the software license and exit'
62
+ VERSION_DESCRIPTION = 'print the version number and exit'
93
63
 
94
64
 
95
- ModelT = TypeVar('ModelT')
65
+ BaseModelT = TypeVar('BaseModelT', bound=BaseModel)
96
66
 
97
67
 
98
- class BaseCli(Generic[ModelT], ABC):
68
+ class BaseCli(Generic[BaseModelT]):
99
69
 
100
70
  def __init__(self, **extra):
101
71
  super().__init__()
102
- self.args: ModelT = None
103
- self.parser: ArgumentParser = None
72
+ self.args: Optional[BaseModelT] = None
73
+ self.parser: Optional[ArgumentParser] = None
104
74
  self.extra: Dict[str, Any] = dict(**extra)
105
75
 
106
76
  def run(self):
@@ -110,32 +80,58 @@ class BaseCli(Generic[ModelT], ABC):
110
80
  self.args = self.parser.parse_typed_args()
111
81
  self.dispatch()
112
82
 
113
- @abstractmethod
114
83
  def dispatch(self):
115
- pass
84
+ field_names = self.args.__class__.__fields__.keys()
85
+ for field_name in field_names:
86
+ field_value = getattr(self.args, field_name)
87
+ if issubclass(type(field_value), BaseModel):
88
+ func = getattr(self, f'_{field_name}')
89
+ func(field_value)
90
+ break
91
+ else:
92
+ self.parser.error(f'unknown command; expected one of {', '.join(field_names)}')
93
+
94
+
95
+ def at_most_one_from_enum(model_cls, values: Dict[str, Any], enum_cls):
96
+ enum_names = [field_name for field_name, model_field in model_cls.__fields__.items() if model_field.field_info.extra.get('enum') == enum_cls]
97
+ ret = list()
98
+ for field_name in enum_names:
99
+ if values.get(field_name):
100
+ ret.append(field_name)
101
+ if (length := len(ret)) > 1:
102
+ raise ValueError(f'at most one of {', '.join([option_name(enum_name) for enum_name in enum_names])} is allowed, got {', '.join([option_name(enum_name) for enum_name in ret])}')
103
+ return values
116
104
 
117
105
 
118
- def _matchy_length(values: Dict[str, Any], *names: str) -> int:
119
- return len(name for name in names if values.get(name)])
106
+ def get_from_enum(model_inst, enum_cls, default=None):
107
+ enum_names = [field_name for field_name, model_field in model_inst.__class__.__fields__.items() if model_field.field_info.extra.get('enum') == enum_cls]
108
+ for field_name in enum_names:
109
+ if getattr(model_inst, field_name):
110
+ return enum_cls.from_member(field_name)
111
+ return default
120
112
 
121
113
 
122
114
  def at_most_one(values: Dict[str, Any], *names: str):
123
- if (length := _matchy_length(values, names)) > 1:
115
+ if (length := _matchy_length(values, *names)) > 1:
124
116
  raise ValueError(f'at most one of {', '.join([option_name(name) for name in names])} is allowed, got {length}')
125
117
  return values
126
118
 
127
119
 
128
120
  def exactly_one(values: Dict[str, Any], *names: str):
129
- if (length := _matchy_length(values, names)) != 1:
121
+ if (length := _matchy_length(values, *names)) != 1:
130
122
  raise ValueError(f'exactly one of {', '.join([option_name(name) for name in names])} is required, got {length}')
131
123
  return values
132
124
 
133
125
 
134
126
  def one_or_more(values: Dict[str, Any], *names: str):
135
- if _matchy_length(values, names) == 0:
127
+ if _matchy_length(values, *names) == 0:
136
128
  raise ValueError(f'one or more of {', '.join([option_name(name) for name in names])} is required')
137
129
  return values
138
130
 
139
131
 
140
132
  def option_name(name: str):
141
133
  return f'{('-' if len(name) == 1 else '--')}{name.replace('_', '-')}'
134
+
135
+
136
+ def _matchy_length(values: Dict[str, Any], *names: str) -> int:
137
+ return len([name for name in names if values.get(name)])