sprinkles-opts 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,387 @@
1
+ # typed: strict
2
+ # frozen_string_literal: true
3
+
4
+ require 'optparse'
5
+ require 'sorbet-runtime'
6
+
7
module Sprinkles; module Opts; end; end

module Sprinkles::Opts
  # Base class for declarative command-line parsing.
  #
  # Subclasses declare their options with `const` and then call `parse`,
  # which returns an instance of the subclass whose reader methods (one per
  # declared option) return the parsed values. Options declared with neither
  # `short:` nor `long:` are positional arguments.
  class GetOpt
    extend T::Sig
    extend T::Helpers
    abstract!

    # Program name used in the generated `Usage:` banner. Subclasses may
    # override this; by default it is `$PROGRAM_NAME`.
    sig { overridable.returns(String) }
    def self.program_name
      $PROGRAM_NAME
    end

    # Metadata describing one declared option (a flag or a positional field).
    class Option < T::Struct
      extend T::Sig

      const :name, Symbol
      const :short, T.nilable(String)
      const :long, T.nilable(String)
      const :type, T.untyped
      const :placeholder, T.nilable(String)
      const :factory, T.nilable(T.proc.returns(T.untyped))
      const :description, T.nilable(String)

      # True when the option may be omitted: it has a default `factory`, or
      # its type is a union that includes NilClass (e.g. `T.nilable(...)`).
      sig { returns(T::Boolean) }
      def optional?
        return true unless factory.nil?
        return true if type.is_a?(T::Types::Union) && type.types.any? { |t| t == T::Utils.coerce(NilClass) }

        false
      end

      # True when the option has neither a short nor a long flag.
      sig { returns(T::Boolean) }
      def positional?
        short.nil? && long.nil?
      end

      # True when the option's type is a `T::Array[...]` or `T::Set[...]`,
      # i.e. it collects every occurrence rather than a single value.
      sig { returns(T::Boolean) }
      def repeated?
        type.is_a?(T::Types::TypedArray) || type.is_a?(T::Types::TypedSet)
      end

      # Placeholder text for this option's value in help output.
      #
      # For a T::Enum type the possible serialized values are listed as
      # `<a|b|c>`; otherwise the declared `placeholder` is used, falling
      # back to `default`.
      sig { params(default: String).returns(String) }
      def get_placeholder(default='VALUE')
        if type.is_a?(Class) && type < T::Enum
          # if the type is an enum, we can enumerate the possible
          # values in a rich way
          possible_values = type.values.map(&:serialize).join("|")
          return "<#{possible_values}>"
        end

        placeholder || default
      end

      # Arguments for `OptionParser#on` describing this (non-positional)
      # option. Booleans become `-x` / `--[no-]long` switches; every other
      # type takes a mandatory argument.
      sig { returns(T::Array[String]) }
      def optparse_args
        args = []
        if type == T::Boolean
          args << "-#{short}" if short
          args << "--[no-]#{long}" if long
        else
          args << "-#{short}#{get_placeholder}" if short
          args << "--#{long}=#{get_placeholder}" if long
        end
        args << description if description
        args
      end
    end

    # for appeasing Sorbet, even though this isn't how we're using the
    # props methods. (we're also not allowing `prop` at all, only
    # `const`.)
    sig { params(rest: T.untyped).returns(T.untyped) }
    def self.decorator(*rest); end

    # The options declared on this class, in declaration order.
    sig { returns(T::Array[Option]) }
    private_class_method def self.fields
      @fields = T.let(@fields, T.nilable(T::Array[Option]))
      @fields ||= []
    end

    # Declares an option called `name` of Sorbet type `type` and defines a
    # reader method `name` returning its parsed value.
    #
    # `short` / `long` are the flag spellings without leading dashes; if both
    # are omitted the option is positional. `factory` supplies a default
    # value. `placeholder` and `description` are used in help output.
    #
    # Raises (RuntimeError) when the declaration is invalid; see `validate!`.
    sig do
      params(
        name: Symbol,
        type: T.untyped,
        short: String,
        long: String,
        factory: T.nilable(T.proc.returns(T.untyped)),
        placeholder: String,
        description: String,
        without_accessors: TrueClass
      )
        .returns(T.untyped)
    end
    def self.const(
      name,
      type,
      short: '',
      long: '',
      factory: nil,
      placeholder: '',
      description: '',
      without_accessors: true
    )
      # we don't want to let the user pass in nil explicitly, so the
      # default values here are all '' instead, but we will treat ''
      # as if the argument was not provided
      short = nil if short.empty?
      long = nil if long.empty?

      placeholder = nil if placeholder.empty?

      opt = Option.new(
        name: name,
        type: type,
        short: short,
        long: long,
        factory: factory,
        placeholder: placeholder,
        description: description
      )
      validate!(opt)
      fields << opt
      self.define_method(name) { instance_variable_get("@#{name}") }
    end

    # Rejects invalid option declarations: dash-prefixed flag names, the
    # reserved `-h`/`--help`, unsupported types, and positional fields
    # declared in an order that would make argument matching ambiguous.
    sig { params(opt: Option).void }
    private_class_method def self.validate!(opt)
      raise 'Do not start options with -' if opt.short&.start_with?('-')
      raise 'Do not start options with -' if opt.long&.start_with?('-')
      if (opt.short == 'h') || (opt.long == 'help')
        raise <<~RB
          The options `-h` and `--help` are reserved by Sprinkles::Opts::GetOpt
        RB
      end
      if !valid_type?(opt.type)
        raise "`#{opt.type}` is not a valid parameter type"
      end

      # The remaining checks only concern positional fields. The invariant
      # we keep is that all mandatory positional fields come first, followed
      # by either optional positional fields or a single trailing repeated
      # one: this makes matching up positional arguments a _lot_ easier and
      # less surprising.
      return unless opt.positional?

      if @seen_repeated_positional
        raise "The positional parameter `#{opt.name}` comes after the "\
              "repeated parameter `#{@seen_repeated_positional.name}`"
      end

      if opt.repeated?
        # This must be checked before `opt` itself is recorded as optional
        # (below); otherwise a repeated field with its own default factory
        # would be rejected for "coming after" itself.
        if @seen_optional_positional
          raise "The repeated parameter `#{opt.name}` comes after an "\
                "optional parameter."
        end

        @seen_repeated_positional = T.let(opt, T.nilable(Option))
      end

      if !opt.optional? && @seen_optional_positional
        # this means we're looking at a _mandatory_ positional field
        # coming after an _optional_ positional field. To make things
        # easy, we simply reject this case.
        prev = fields
               .select { |f| f.positional? && f.optional? }
               .map { |f| "`#{f.name}`" }
               .join(", ")
        raise "`#{opt.name}` is a mandatory positional field "\
              "but it comes after the optional field(s) #{prev}"
      end

      if opt.optional?
        @seen_optional_positional = T.let(true, T.nilable(TrueClass))
      end
    end

    # True if `type` is one this class knows how to parse: String, Symbol,
    # Integer, Float, T::Boolean, a T::Enum subclass, a nilable version of
    # one of these, or a T::Array / T::Set of one of these.
    sig { params(type: T.untyped).returns(T::Boolean) }
    private_class_method def self.valid_type?(type)
      type = type.raw_type if type.is_a?(T::Types::Simple)
      # true if the type is one of the valid types
      return true if type == String || type == Symbol || type == Integer || type == Float || type == T::Boolean
      # allow enumeration types
      return true if type.is_a?(Class) && type < T::Enum
      # true if it's a nilable valid type
      if type.is_a?(T::Types::Union)
        other_types = type.types.to_set - [T::Utils.coerce(NilClass)]
        return false if other_types.size > 1
        return valid_type?(other_types.first)
      elsif type.is_a?(T::Types::TypedArray) || type.is_a?(T::Types::TypedSet)
        return valid_type?(type.type)
      end

      # otherwise we probably don't handle it
      false
    end

    # Converts the command-line string `value` to an instance of `type`.
    #
    # Raises KeyError for an unknown T::Enum value (from `deserialize`) and
    # ArgumentError for malformed Integer/Float input, rather than silently
    # turning e.g. `abc` into 0. Raises RuntimeError for unsupported types.
    sig { params(value: String, type: T.untyped).returns(T.untyped) }
    def self.convert_str(value, type)
      type = type.raw_type if type.is_a?(T::Types::Simple)
      if type.is_a?(T::Types::Union)
        # Right now, the assumption is that this is mostly used for
        # `T.nilable`, but with a bit of work we maybe could support
        # other kinds of unions
        possible_types = type.types.to_set - [T::Utils.coerce(NilClass)]
        raise 'TODO: generic union types' if possible_types.size > 1

        convert_str(value, possible_types.first)
      elsif type.is_a?(Class) && type < T::Enum
        type.deserialize(value)
      elsif type == String
        value
      elsif type == Symbol
        value.to_sym
      elsif type == Integer
        # explicit base 10 so a leading zero (`010`) isn't read as octal
        Integer(value, 10)
      elsif type == Float
        Float(value)
      else
        raise "Don't know how to convert a string to #{type}"
      end
    end

    # Builds an instance of the calling subclass from the raw strings
    # collected for each field, applying type conversion, factories,
    # nil/empty defaults, and reporting missing or invalid values via
    # `usage!`. The instance also gets a `_serialize` singleton method
    # returning a Hash of field name => parsed value.
    sig { params(values: T::Hash[Symbol, T::Array[String]]).returns(T.attached_class) }
    private_class_method def self.build_config(values)
      o = new
      serialized = {}
      fields.each do |field|
        if field.type == T::Boolean
          # booleans are never mandatory: absent means false unless a
          # factory provides another default
          default = false
          default = field.factory&.call if !field.factory.nil?
          v = !!values.fetch(field.name, [default]).fetch(0)
        elsif values.include?(field.name)
          raw = values.fetch(field.name)
          # the string currently being converted, so a conversion failure
          # can report the offending value (for repeated fields too)
          current = T.let(nil, T.nilable(String))
          begin
            if field.repeated?
              v = raw.map do |s|
                current = s
                Sprinkles::Opts::GetOpt.convert_str(s, field.type.type)
              end
              # we allow both arrays and sets but we use arrays
              # internally, so convert to a set just in case
              v = v.to_set if field.type.is_a?(T::Types::TypedSet)
            else
              current = raw.fetch(0)
              v = Sprinkles::Opts::GetOpt.convert_str(current, field.type)
            end
          rescue KeyError, ArgumentError => exn
            # KeyError: unknown enum value; ArgumentError: malformed number
            usage!("Invalid value `#{current}` for field `#{field.name}`:\n #{exn.message}")
          end
        elsif !field.factory.nil?
          v = T.must(field.factory).call
        elsif field.optional?
          v = nil
        elsif field.repeated?
          v = field.type.is_a?(T::Types::TypedArray) ? [] : Set.new
        else
          usage!("Expected a value for `#{field.name}`")
        end
        o.instance_variable_set("@#{field.name}", v)
        serialized[field.name] = v
      end
      o.define_singleton_method(:_serialize) { serialized }
      o
    end

    # Assigns the leftover (non-flag) arguments in `argv` to positional
    # fields in declaration order. Any surplus goes to the repeated
    # positional field, if one exists. Calls `usage!` when there are too
    # few or too many arguments.
    sig { params(argv: T::Array[String]).returns(T::Hash[Symbol, T::Array[String]]) }
    private_class_method def self.match_positional_fields(argv)
      pos_fields = fields.select(&:positional?).reject(&:repeated?)
      total_positional = pos_fields.size
      min_positional = total_positional - pos_fields.count(&:optional?)
      # validate! guarantees at most one positional repeated field, and
      # that it never coexists with optional positional fields
      rest = fields.find { |f| f.positional? && f.repeated? }

      pos_values = T::Hash[Symbol, T::Array[String]].new

      usage!("Not enough arguments!") if argv.size < min_positional
      if argv.size > total_positional
        # we only want to complain about too many args if there isn't a
        # repeated arg to grab them
        usage!("Too many arguments!") if rest.nil?
        pos_values[T.must(rest).name] = argv.drop(total_positional)
      end

      pos_fields.zip(argv).each do |field, arg|
        next if arg.nil?
        pos_values[field.name] = [arg]
      end

      pos_values
    end

    # Prints `msg` (when non-empty) and the generated help text, then exits.
    #
    # An empty `msg` means help was explicitly requested (`-h`/`--help`) and
    # exits with status 0. Any other message is an error and exits with
    # status 1 so calling scripts can detect the failure.
    sig { params(msg: String).void }
    private_class_method def self.usage!(msg='')
      raise <<~RB if @opts.nil?
        Internal error: tried to call `usage!` before building option parser!
      RB

      puts msg if !msg.empty?
      puts @opts
      exit(msg.empty? ? 0 : 1)
    end

    # The synopsis shown after the program name in the `Usage:` banner:
    # positional fields, then mandatory flags, then repeated flags, and a
    # trailing `[OPTS...]` if any declared option was not listed.
    sig { returns(String) }
    private_class_method def self.cmdline
      cmd_line = T::Array[String].new
      field_count = 0
      # first, all the positional fields
      fields.each do |field|
        next if !field.positional?
        field_count += 1
        field_name = field.get_placeholder(field.name.to_s.upcase)
        # repeated is checked first so a repeated field with a default
        # factory still shows its `...`
        if field.repeated?
          cmd_line << "[#{field_name} ...]"
        elsif field.optional?
          cmd_line << "[#{field_name}]"
        else
          cmd_line << field_name
        end
      end
      # next, the non-positional but mandatory flags (leaving the optional
      # flags for the other help). Booleans default to false and so are
      # never mandatory.
      fields.each do |field|
        next if field.positional? || field.optional? || field.repeated? || field.type == T::Boolean
        field_count += 1
        if field.long
          cmd_line << "--#{field.long}=#{field.get_placeholder}"
        else
          cmd_line << "-#{field.short}#{field.get_placeholder}"
        end
      end

      # then the repeated flags, with the special `...` to denote their
      # repetition (repeated positional fields were already listed above)
      fields.each do |field|
        next if field.positional? || !field.repeated?
        field_count += 1
        if field.long
          cmd_line << "[--#{field.long}=#{field.get_placeholder} ...]"
        else
          cmd_line << "[-#{field.short}#{field.get_placeholder} ...]"
        end
      end

      # we'll only add the final [OPTS...] if there are other things
      # we haven't listed yet
      cmd_line << "[OPTS...]" if fields.size > field_count
      cmd_line.join(" ")
    end

    # Parses `argv` (default: ARGV) into an instance of the calling subclass.
    #
    # Invalid command lines (unknown flags, missing flag arguments, wrong
    # number of positional arguments, unconvertible values) print a message
    # plus usage and exit with status 1; `-h`/`--help` prints usage and exits
    # with status 0. The caller's `argv` array is not modified.
    sig { params(argv: T::Array[String]).returns(T.attached_class) }
    def self.parse(argv=ARGV)
      # OptionParser#parse! destructively removes the flags it consumes
      argv = argv.clone

      values = T::Hash[Symbol, T::Array[String]].new
      parser = OptionParser.new do |opts|
        @opts = T.let(opts, T.nilable(OptionParser))
        opts.banner = "Usage: #{program_name} #{cmdline}"
        opts.on('-h', '--help', 'Print this help') do
          usage!
        end

        fields.each do |field|
          next if field.positional?
          opts.on(*field.optparse_args) do |v|
            if field.repeated?
              (values[field.name] ||= []) << v
            else
              values[field.name] = [v]
            end
          end
        end
      end

      begin
        parser.parse!(argv)
      rescue OptionParser::ParseError => exn
        # unknown flag, missing argument, etc.: report it like any other
        # usage error instead of crashing with a stack trace
        usage!(exn.message)
      end

      values.merge!(match_positional_fields(argv))

      build_config(values)
    end
  end
end
data/sorbet/config ADDED
@@ -0,0 +1,3 @@
1
+ --dir
2
+ .
3
+ --ignore=/vendor/bundle