protobug 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,746 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "binary_encoding"
4
+
5
+ module Protobug
6
+ class Field
7
attr_accessor :number, :name, :json_name, :cardinality, :oneof, :ivar, :setter,
              :adder, :haser, :clearer

# Describes a single field of a protobuf message.
#
# @param number [Integer] the field number
# @param name [#to_sym] the field name; stored as a Symbol
# @param json_name [String, nil] JSON key; falls back to +name.to_s+ when nil
# @param cardinality [Symbol] e.g. :optional or :repeated; nil/false is rejected
# @param oneof the oneof group this field belongs to, or nil
# @param packed [Boolean] returned by #packed?
# @param proto3_optional [Boolean] returned by #proto3_optional?; defaults to
#   true exactly when +cardinality+ is :optional
# @raise [ArgumentError] when +cardinality+ is nil or false
def initialize(number, name, json_name: nil, cardinality: :optional, oneof: nil, packed: false,
               proto3_optional: cardinality == :optional)
  # Instance variables are assigned in a fixed order: the default #inspect lists
  # them in assignment order, and #inspect is interpolated into error messages.
  @number = number
  @name = name.to_sym
  @json_name = json_name || name.to_s
  raise ArgumentError, "cardinality is required" unless cardinality

  @cardinality = cardinality
  @oneof = oneof
  # Names of the per-field methods that are called on message instances via send
  # (the adder only exists for repeated fields).
  @setter = "#{name}=".to_sym
  @adder = "add_#{name}".to_sym if repeated?
  @ivar = "@#{name}".to_sym
  @clearer = "clear_#{name}".to_sym
  @haser = "#{name}?".to_sym
  @packed = packed
  @proto3_optional = proto3_optional
end
25
+
26
# Pretty-prints the field as a constructor-style expression, for example
#   Protobug::Field.new(1, :foo, cardinality: :optional)
# json_name is shown only when it differs from the field name, and oneof only
# when it is set. Every optional segment writes its own leading ", " separator,
# so no separator is emitted directly after the name.
#
# @param pp [PP] the pretty printer driving the output
def pretty_print(pp)
  pp.group 0, "#{self.class}.new(", ")" do
    pp.text @number.to_s
    pp.breakable(", ")
    pp.text(@name.inspect)
    if json_name != name.name
      pp.breakable(", ")
      pp.text("json_name: ")
      pp.text(@json_name.inspect)
    end
    pp.breakable(", ")
    pp.text("cardinality: ")
    pp.pp(@cardinality)
    if oneof
      pp.breakable(", ")
      pp.text("oneof: ")
      pp.text(@oneof.inspect)
    end
  end
end
47
+
48
# True when this field holds a list of values (cardinality :repeated).
def repeated? = cardinality == :repeated
51
+
52
# The +packed:+ flag given to the constructor (used by #binary_encode for repeated fields).
def packed? = @packed
55
+
56
# True when the field's cardinality is :optional.
def optional? = cardinality == :optional
59
+
60
# The +proto3_optional:+ flag given to the constructor; when true, a singular
# optional field is serialised even if it equals its default.
def proto3_optional? = @proto3_optional
63
+
64
# Defines the +adder+ method (e.g. add_foo) on the message class +message+.
# The generated method validates the value, lazily replaces an UNSET instance
# variable with this field's default collection, appends the value, and
# returns the collection.
def define_adder(message)
  field = self
  message.define_method(adder) do |value|
    field.validate!(value, self)

    list = instance_variable_get(field.ivar)
    list = instance_variable_set(field.ivar, field.default) if UNSET == list
    list << value
  end
end
78
+
79
# Renders this field's value in protobuf text format.
#
# Scalar values (json_scalar?) render as "name: <scalar_to_text>"; nested
# messages render as "name {", their own #to_text indented by one space, then "}".
# A repeated field renders one entry per element, joined by newlines. Every
# non-repeated cardinality (:optional, :required, ...) renders a single entry;
# previously only :optional was matched and other cardinalities returned nil.
#
# @param value the field value (an Array-like for repeated fields)
# @return [String]
def to_text(value)
  scalar = json_scalar?
  render = lambda do |item|
    if scalar
      "#{name}: #{scalar_to_text(item)}"
    else
      "#{name} {\n#{item.to_text.gsub(/^/, " ")}\n}"
    end
  end

  repeated? ? Array(value).map(&render).join("\n") : render.call(value)
end
91
+
92
# Writes this field (tag + payload) to +outbuf+.
#
# Singular fields equal to #default are skipped (returning nil) unless the
# field is a proto3 `optional` field or belongs to a oneof. Repeated fields are
# either written packed (see #binary_encode_packed) or as one tagged record per
# element.
def binary_encode(value, outbuf)
  unless repeated?
    return if !(optional? && proto3_optional?) && !oneof && default == value

    BinaryEncoding.encode_varint((number << 3) | wire_type, outbuf)
    return binary_encode_one(value, outbuf)
  end

  return binary_encode_packed(value, outbuf) if packed?

  value.each do |element|
    BinaryEncoding.encode_varint((number << 3) | wire_type, outbuf)
    binary_encode_one(element, outbuf)
  end
end
109
+
110
# Decodes one occurrence of this field from +binary+ and stores it on +message+.
#
# A repeated field whose own wire type is 0, 1 or 5 may also arrive as a single
# length-delimited record (wire type 2) holding several packed values; each one
# is decoded and passed to the adder. Otherwise the wire type must match the
# field's, and the decoded value goes to the adder (repeated) or setter.
#
# @raise [DecodeError] on a mismatched wire type
def binary_decode(binary, message, registry, wire_type)
  packed_scalars = repeated? && wire_type == 2 && [0, 1, 5].include?(self.wire_type)

  unless packed_scalars
    raise DecodeError, "wrong wire type for #{self}: #{wire_type.inspect}" if wire_type != self.wire_type

    return message.send(adder || setter, binary_decode_one(binary, message, registry, wire_type))
  end

  chunk = StringIO.new(BinaryEncoding.decode_length(binary))
  chunk.binmode
  until chunk.eof?
    message.send(adder, binary_decode_one(chunk, message, registry, self.wire_type))
  end
  nil
end
122
+
123
# Converts this field's value to its JSON representation.
#
# Repeated fields map every element through json_encode_one. A singular value
# equal to #default yields nil (omitted) unless the field is a proto3 `optional`
# field or belongs to a oneof.
def json_encode(value, print_unknown_fields:)
  if repeated?
    return value.map { |element| json_encode_one(element, print_unknown_fields: print_unknown_fields) }
  end
  return if !(optional? && proto3_optional?) && !oneof && default == value

  json_encode_one(value, print_unknown_fields: print_unknown_fields)
end
132
+
133
# Converts a map key to the string used as a JSON object key.
#
# @raise [EncodeError] for keys that are not String, Integer, Float or boolean
def json_key_encode(value)
  case value
  when String then value
  when Integer, Float then value.to_s
  when true then "true"
  when false then "false"
  else raise EncodeError, "unexpected type for map key: #{value.inspect}"
  end
end
147
+
148
# Decodes a parsed-JSON +value+ for this field and stores it on +message+.
#
# Singular: the decoded value goes to the setter unless it is UNSET. Repeated:
# nil is ignored, otherwise every element is decoded and passed to the adder.
#
# @raise [DecodeError] when a repeated field receives something other than an Array
def json_decode(value, message, ignore_unknown_fields, registry)
  unless repeated?
    decoded = json_decode_one(value, ignore_unknown_fields, registry)
    return UNSET == decoded ? nil : message.send(setter, decoded)
  end

  return if value.nil?
  raise DecodeError, "expected Array for #{inspect}, got #{value.inspect}" unless value.is_a?(Array)

  value.map { |element| message.send(adder, json_decode_one(element, ignore_unknown_fields, registry)) }
end
165
+
166
# Checks +value+ before it is stored on +message+.
#
# UNSET is rejected. When this field belongs to a oneof, every other member of
# that oneof is cleared on +message+ via its clearer method.
#
# @raise [DecodeError] when +value+ is UNSET
def validate!(value, message)
  raise DecodeError, "nil is invalid for #{name} in #{message}" if UNSET == value
  return unless oneof

  message.class.oneofs[oneof].each do |member|
    message.send(member.clearer) unless member == self
  end
end
177
+
178
private

# Writes a repeated scalar field in packed form: one tag with wire type 2,
# followed by the concatenation of every element's binary_encode_one output,
# handed to BinaryEncoding.encode_length.
def binary_encode_packed(value, outbuf)
  BinaryEncoding.encode_varint((number << 3) | 2, outbuf)

  payload = "".b
  value.each { |element| binary_encode_one(element, payload) }
  BinaryEncoding.encode_length(payload, outbuf)
end
187
+
188
# A field whose value is an embedded message. It always uses wire type 2;
# the message class is resolved with +registry.fetch(message_type)+.
class MessageField < Field
  attr_reader :message_type

  def initialize(number, name, cardinality:, message_type:, json_name: name, oneof: nil,
                 proto3_optional: cardinality == :optional)
    super(number, name, json_name: json_name, cardinality: cardinality, oneof: oneof,
                        proto3_optional: proto3_optional)
    @message_type = message_type
  end

  def wire_type
    2
  end

  # Resolves the message class for this field (MapField overrides this).
  def type_lookup(registry)
    registry.fetch(message_type)
  end

  def default
    # TODO: message_type.default
    repeated? ? [] : nil
  end

  def binary_encode_one(value, outbuf)
    BinaryEncoding.encode_length(value.class.encode(value), outbuf)
  end

  # Reads one length-delimited record and decodes it with the resolved class.
  # For a singular field that is already set, the current sub-message is passed
  # to decode as +object:+ (presumably so the new bytes merge into it).
  def binary_decode_one(io, message, registry, wire_type)
    payload = BinaryEncoding.read_field_value(io, wire_type)
    decode_options = {}
    decode_options[:object] = message.send(name) if !repeated? && message.send(haser)
    type_lookup(registry).decode(StringIO.new(payload), registry: registry, **decode_options)
  end

  def json_decode_one(value, ignore_unknown_fields, registry)
    type_lookup(registry).decode_json_hash(value, registry: registry, ignore_unknown_fields: ignore_unknown_fields)
  end

  def json_encode_one(value, print_unknown_fields:)
    value.as_json(print_unknown_fields: print_unknown_fields)
  end
end
231
+
232
# A protobuf map field, held in Ruby as a Hash. Each entry is modelled by a
# synthetic message class (@map_class) with "key" as field 1 and "value" as
# field 2; on the wire every entry is one tagged, length-delimited record.
class MapField < MessageField
  # Evaluated before MapField defines #initialize, so this is Field#initialize:
  # MessageField#initialize is bypassed because it requires message_type:.
  SUPER_INITIALIZE = instance_method(:initialize).super_method

  # @param key_type type of the entry's "key" field
  # @param value_type type of the entry's "value" field
  # @param enum_type, message_type forwarded to the "value" field when non-nil
  def initialize(number, name, key_type:, value_type:, json_name: name, oneof: nil, # rubocop:disable Lint/MissingSuper
                 enum_type: nil, message_type: nil)
    SUPER_INITIALIZE.bind_call(
      self, number, name,
      cardinality: :repeated,
      json_name: json_name,
      oneof: oneof
    )

    @map_class = Class.new do
      extend Protobug::Message

      optional(1, "key", type: key_type, proto3_optional: false)
      value_type_kwargs = { enum_type: enum_type, message_type: message_type }
      value_type_kwargs.compact!
      optional(2, "value", type: value_type, **value_type_kwargs, proto3_optional: false)
    end
  end

  def repeated = true
  def default = {}
  def repeated? = true

  # Writes one tagged entry record per key/value pair; a single entry object is
  # reused for every pair.
  def binary_encode(value, outbuf)
    value.each_with_object(@map_class.new) do |(k, v), entry|
      entry.key = k
      entry.value = v
      BinaryEncoding.encode_varint (number << 3) | wire_type, outbuf
      BinaryEncoding.encode_length @map_class.encode(entry), outbuf
    end
  end

  # Encodes the map as a JSON object: keys via json_key_encode, values via the
  # entry's "value" field.
  #
  # json_encode_one is used rather than json_encode: the "value" field is built
  # with proto3_optional: false, so Field#json_encode would return nil for a value
  # equal to its default (e.g. {"a" => 0} became {"a": nil}), which #json_decode
  # then rejects as a nil map value.
  def json_encode(value, print_unknown_fields:)
    value_field = @map_class.fields_by_name.fetch("value")
    value.to_h do |k, v|
      encoded = value_field.json_encode_one(v, print_unknown_fields: print_unknown_fields)
      [json_key_encode(k), encoded]
    end
  end

  def json_decode(value, message, ignore_unknown_fields, registry)
    return if value.nil?

    unless value.is_a?(Hash)
      raise DecodeError,
            "expected Hash for #{inspect}, got #{value.inspect}"
    end

    value.each do |k, v|
      entry = @map_class.decode_json_hash(
        { "key" => k, "value" => v },
        registry: registry,
        ignore_unknown_fields: ignore_unknown_fields
      )
      # can't use haser because default values should also be counted...
      if UNSET == entry.instance_variable_get(:@value)
        next if ignore_unknown_fields && @map_class.fields_by_name.fetch("value").is_a?(EnumField)

        raise DecodeError, "nil values are not allowed in map #{name} in #{message.class}"
      end

      message.send(adder, entry)
    end
  end

  def type_lookup(_registry) = @map_class

  # Defines the adder on +message+: it stores a decoded entry as hash[key] = value,
  # replacing an UNSET instance variable with a fresh Hash first.
  def define_adder(message)
    field = self
    message.define_method(adder) do |msg|
      existing = instance_variable_get(field.ivar)
      if UNSET == existing
        existing = field.default
        instance_variable_set(field.ivar, existing)
      end

      existing[msg.key] = msg.value
    end
  end
end
313
+
314
# bytes: arbitrary binary data. Length-delimited on the wire (wire type 2) and
# base64 text in JSON.
class BytesField < Field
  def self.type = :bytes

  def initialize(number, name, cardinality:, json_name: name, oneof: nil,
                 proto3_optional: cardinality == :optional)
    super(number, name, json_name: json_name, cardinality: cardinality, oneof: oneof,
                        proto3_optional: proto3_optional)
  end

  def binary_encode_one(value, outbuf)
    BinaryEncoding.encode_length value.b, outbuf
  end

  def binary_decode_one(io, _message, _registry, wire_type)
    BinaryEncoding.read_field_value(io, wire_type)
  end

  # Decodes standard or URL-safe base64 into a binary String.
  #
  # The URL-safe alphabet is mapped onto the standard one with the
  # non-destructive String#tr, so the caller's (possibly frozen) JSON string is
  # left untouched.
  #
  # @return [String, UNSET] binary data, or UNSET when +value+ is nil
  # @raise [DecodeError] when +value+ is not a String or base64 decoding fails
  def json_decode_one(value, _ignore_unknown_fields, _registry)
    return UNSET if value.nil?
    raise DecodeError, "expected string, got #{value.inspect}" unless value.is_a?(String)

    standard = value.tr("-_", "+/")
    begin
      standard.unpack1("m").force_encoding(Encoding::BINARY)
    rescue ArgumentError => e
      raise DecodeError, "Invalid URL-encoded base64 #{value.inspect} for #{inspect}: #{e}"
    end
  end

  def json_encode_one(value, print_unknown_fields:) # rubocop:disable Lint/UnusedMethodArgument
    [value].pack("m0")
  end

  def default
    return [] if repeated?

    "".b
  end

  def wire_type = 2
end
357
+
358
# string: UTF-8 text. Shares BytesField's length-delimited wire format, but
# checks UTF-8 validity on decode and uses plain strings (not base64) in JSON.
class StringField < BytesField
  def self.type
    :string
  end

  def initialize(number, name, cardinality:, json_name: name, oneof: nil,
                 proto3_optional: cardinality == :optional)
    super(number, name, json_name: json_name, cardinality: cardinality, oneof: oneof,
                        proto3_optional: proto3_optional)
  end

  def default
    repeated? ? [] : +""
  end

  # Transcodes non-UTF-8 strings to UTF-8 before BytesField writes the bytes.
  def binary_encode_one(value, outbuf)
    utf8 = value.encoding == Encoding::UTF_8 ? value : value.encode("utf-8")
    super(utf8, outbuf)
  end

  # @raise [DecodeError] when the bytes are not valid UTF-8
  def binary_decode_one(io, _message, _registry, wire_type)
    utf8_or_raise(super)
  end

  # @return [String, UNSET] UNSET when +value+ is nil
  # @raise [DecodeError] when +value+ is not a String or not valid UTF-8
  def json_decode_one(value, _ignore_unknown_fields, _registry)
    return UNSET if value.nil?
    raise DecodeError, "expected string, got #{value.inspect}" unless value.is_a?(String)

    utf8_or_raise(value)
  end

  def json_encode_one(value, print_unknown_fields:) # rubocop:disable Lint/UnusedMethodArgument
    value.encode("utf-8")
  end

  private

  # Relabels +str+ as UTF-8 in place (no transcoding) and rejects invalid byte
  # sequences; returns +str+.
  def utf8_or_raise(str)
    str.force_encoding("utf-8") if str.encoding != Encoding::UTF_8
    raise DecodeError, "invalid utf-8 for string" unless str.valid_encoding?

    str
  end
end
401
+
402
# Base class for integer scalar fields. Subclasses define +encoding+
# (:varint, :zigzag or :fixed), +bit_length+ (32 or 64), +signed+,
# +wire_type+ and, for :fixed, the Array#pack directive +binary_pack+.
class IntegerField < Field
  def default
    return [] if repeated?

    0
  end

  # Reads one raw value with BinaryEncoding.read_field_value and converts it to
  # a Ruby Integer according to +encoding+.
  def binary_decode_one(io, _message, _registry, wire_type)
    value = BinaryEncoding.read_field_value(io, wire_type)
    case encoding
    when :zigzag
      BinaryEncoding.decode_zigzag bit_length, value
    when :varint
      # Keep only the low +bit_length+ bits of the varint; for signed types the
      # top remaining bit is the two's-complement sign bit.
      value &= (1 << bit_length) - 1
      value -= 1 << bit_length if signed && value[bit_length - 1] == 1
      value
    when :fixed
      value.unpack1(binary_pack)
    end
  end

  def binary_encode_one(value, outbuf)
    case encoding
    when :zigzag
      BinaryEncoding.encode_zigzag bit_length, value, outbuf
    when :varint
      BinaryEncoding.encode_varint value, outbuf
    when :fixed
      [value].pack(binary_pack, buffer: outbuf)
    end
  end

  # Accepts a JSON Integer, a Float with no fractional part, or a string of
  # decimal digits with an optional leading minus.
  #
  # @return [Integer, UNSET] UNSET when +value+ is nil
  # @raise [DecodeError] for any other input or a magnitude wider than 64 bits
  def json_decode_one(value, _ignore_unknown_fields, _registry)
    return UNSET if value.nil?

    integer =
      case value
      when Integer
        value
      when /\A-?\d+\z/
        # Explicit base 10: a bare Integer() reads "010" as octal 8 and
        # raises ArgumentError for "09".
        Integer(value, 10)
      when Float
        # Rejects fractions and Infinity/NaN (divmod/to_i would raise
        # FloatDomainError on those); the message reports the original input.
        unless value.finite? && value == value.truncate
          raise DecodeError, "expected integer for #{inspect}, got #{value.inspect}"
        end

        value.to_i
      else
        raise DecodeError, "expected integer for #{inspect}, got #{value.inspect}"
      end
    raise DecodeError, "#{integer.inspect} does not fit in 64 bits" if integer.bit_length > 64

    integer
  end

  # 64-bit integers are emitted as JSON strings; narrower ones as numbers.
  def json_encode_one(value, print_unknown_fields:) # rubocop:disable Lint/UnusedMethodArgument
    if bit_length >= 64
      value.to_s
    else
      value
    end
  end

  # @raise [InvalidValueError] unless +value+ is an Integer within the range
  #   representable by this field's bit_length/signedness
  def validate!(value, message)
    raise InvalidValueError.new(message, self, value, "expected integer") unless value.is_a?(Integer)

    if signed
      min = -2**(bit_length - 1)
      max = 2**(bit_length - 1)
    else
      min = 0
      max = 2**bit_length
    end

    if value < min || value >= max
      raise InvalidValueError.new(message, self, value, "does not fit into [#{min}, #{max})")
    end

    super
  end
end
491
+
492
# Concrete integer field types. Each one fixes the properties that
# IntegerField and Field rely on:
#   encoding   - :varint, :zigzag or :fixed
#   bit_length - 32 or 64
#   signed     - true or false (there is no unsigned zigzag type)
#   wire_type  - 0 for varint/zigzag, 1 for 64-bit fixed, 5 for 32-bit fixed
# Fixed-width types also define binary_pack, the Array#pack directive to use.

# int64: signed 64-bit, varint encoded.
class Int64Field < IntegerField
  def encoding
    :varint
  end

  def bit_length
    64
  end

  def signed
    true
  end

  def wire_type
    0
  end
end
502
+
503
# uint64: unsigned 64-bit, varint encoded.
class UInt64Field < IntegerField
  def encoding
    :varint
  end

  def bit_length
    64
  end

  def signed
    false
  end

  def wire_type
    0
  end
end
509
+
510
# sint64: signed 64-bit, zigzag encoded.
class SInt64Field < IntegerField
  def encoding
    :zigzag
  end

  def bit_length
    64
  end

  def signed
    true
  end

  def wire_type
    0
  end
end
516
+
517
# fixed64: unsigned 64-bit, always 8 bytes on the wire (wire type 1).
class Fixed64Field < IntegerField
  def encoding = :fixed
  def bit_length = 64
  def signed = false
  def wire_type = 1
  # Explicit little-endian ("Q<"): protobuf fixed-width integers are
  # little-endian, whereas a bare "Q" uses the host's native byte order.
  def binary_pack = "Q<"
end
524
+
525
# sfixed64: signed 64-bit, always 8 bytes on the wire (wire type 1).
class SFixed64Field < IntegerField
  def encoding = :fixed
  def bit_length = 64
  def signed = true
  def wire_type = 1
  # Explicit little-endian ("q<"): protobuf fixed-width integers are
  # little-endian, whereas a bare "q" uses the host's native byte order.
  def binary_pack = "q<"
end
532
+
533
# int32: signed 32-bit, varint encoded.
class Int32Field < IntegerField
  def encoding
    :varint
  end

  def bit_length
    32
  end

  def signed
    true
  end

  def wire_type
    0
  end
end
539
+
540
# uint32: unsigned 32-bit, varint encoded.
class UInt32Field < IntegerField
  def encoding
    :varint
  end

  def bit_length
    32
  end

  def signed
    false
  end

  def wire_type
    0
  end
end
546
+
547
# sint32: signed 32-bit, zigzag encoded.
class SInt32Field < IntegerField
  def encoding
    :zigzag
  end

  def bit_length
    32
  end

  def signed
    true
  end

  def wire_type
    0
  end
end
553
+
554
# fixed32: unsigned 32-bit, always 4 bytes on the wire (wire type 5).
class Fixed32Field < IntegerField
  def encoding
    :fixed
  end

  def bit_length
    32
  end

  def signed
    false
  end

  def wire_type
    5
  end

  # Unsigned 32-bit little-endian; "L<" is the same directive as "V".
  def binary_pack
    "L<"
  end
end
561
+
562
# sfixed32: signed 32-bit, always 4 bytes on the wire (wire type 5).
class SFixed32Field < IntegerField
  def encoding = :fixed
  def bit_length = 32
  def signed = true
  def wire_type = 5
  # Explicit little-endian ("l<"), matching Fixed32Field's "V": a bare "l"
  # uses the host's native byte order.
  def binary_pack = "l<"
end
569
+
570
# bool: carried on the wire like a uint64 varint holding 0 or 1; JSON uses
# true/false literals.
class BoolField < UInt64Field
  # Any non-zero decoded integer is true.
  def binary_decode_one(*)
    super != 0
  end

  def binary_encode_one(value, outbuf)
    super(value ? 1 : 0, outbuf)
  end

  # Accepts JSON booleans, the strings "true"/"false", or nil (returns UNSET).
  #
  # @raise [DecodeError] for anything else
  def json_decode_one(value, _ignore_unknown_fields, _registry)
    case value
    when TrueClass, FalseClass
      value
    when "true"
      true
    when "false"
      false
    when NilClass
      UNSET
    else
      raise DecodeError, "expected boolean, got #{value.inspect}"
    end
  end

  # Emits the boolean itself. Without this override IntegerField#json_encode_one
  # would apply, and because bit_length is 64 here it returns value.to_s
  # ("true"/"false" strings instead of JSON booleans).
  def json_encode_one(value, print_unknown_fields:) # rubocop:disable Lint/UnusedMethodArgument
    value
  end

  # @raise [RuntimeError] unless +value+ is true or false
  def validate!(value, message)
    raise "expected boolean, got #{value.inspect}" unless [true, false].include?(value)

    super(value ? 1 : 0, message)
  end

  def default
    return [] if repeated?

    false
  end
end
606
+
607
# An enum-typed field: an int32 varint on the wire. Values are converted to and
# from enum instances through the class returned by +registry.fetch(enum_type)+.
class EnumField < Int32Field
  attr_reader :enum_type

  def initialize(number, name, cardinality:, enum_type:, json_name: name, oneof: nil, packed: false,
                 proto3_optional: cardinality == :optional)
    super(number, name, json_name: json_name, cardinality: cardinality, oneof: oneof,
                        proto3_optional: proto3_optional, packed: packed)
    @enum_type = enum_type
  end

  def default
    # TODO: enum_type.default
    repeated? ? [] : 0
  end

  # Without ignore_unknown_fields this defers to Field#json_decode. With it,
  # values that json_decode_one returns as UNSET are skipped rather than stored.
  def json_decode(value, message, ignore_unknown_fields, registry)
    return super unless ignore_unknown_fields

    unless repeated?
      decoded = json_decode_one(value, ignore_unknown_fields, registry)
      return UNSET == decoded ? nil : message.send(setter, decoded)
    end

    return if value.nil?

    unless value.is_a?(Array)
      raise DecodeError,
            "expected Array for #{inspect}, got #{value.inspect}"
    end

    value.map do |raw|
      decoded = json_decode_one(raw, ignore_unknown_fields, registry)
      message.send(adder, decoded) unless UNSET == decoded
    end.compact
  end

  def binary_encode_one(value, outbuf)
    super(value.value, outbuf)
  end

  def binary_decode_one(io, _message, registry, wire_type)
    number_on_wire = super
    registry.fetch(enum_type).decode(number_on_wire)
  end

  def json_decode_one(value, ignore_unknown_fields, registry)
    registry.fetch(enum_type).decode_json_hash(value, registry: registry, ignore_unknown_fields: ignore_unknown_fields)
  end

  def json_encode_one(value, print_unknown_fields:) # rubocop:disable Lint/UnusedMethodArgument
    value.as_json
  end

  # Unwraps enum instances to their integer before the IntegerField checks.
  def validate!(value, message)
    super(value.is_a?(Enum::InstanceMethods) ? value.value : value, message)
  end
end
670
+
671
# double: IEEE 754 binary64, packed with "E" (little-endian), wire type 1.
class DoubleField < Field
  # Numeric strings accepted when decoding JSON: optional minus, digits, an
  # optional fraction and an optional exponent (e.g. "1", "-2.5", "6.02e23").
  # This is a superset of the former integer-only pattern /\A-?\d+\z/.
  JSON_NUMERIC_STRING = /\A-?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?\z/

  def type = :double
  def binary_pack = "E"
  def wire_type = 1

  def initialize(number, name, cardinality:, json_name: name, oneof: nil, packed: false,
                 proto3_optional: cardinality == :optional)
    super(number, name, json_name: json_name, cardinality: cardinality, oneof: oneof,
                        proto3_optional: proto3_optional, packed: packed)
  end

  def binary_encode_one(value, outbuf)
    [value].pack(binary_pack, buffer: outbuf)
  end

  def binary_decode_one(io, _message, _registry, wire_type)
    value = BinaryEncoding.read_field_value(io, wire_type)
    value.unpack1(binary_pack)
  end

  # Accepts JSON numbers, the strings "Infinity", "-Infinity" and "NaN",
  # numeric strings (including fractions and exponents), or nil (UNSET).
  #
  # @raise [DecodeError] for anything else
  def json_decode_one(value, _ignore_unknown_fields, _registry)
    case value
    when Float
      value
    when Integer
      value.to_f
    when "Infinity"
      Float::INFINITY
    when "-Infinity"
      -Float::INFINITY
    when "NaN"
      Float::NAN
    when JSON_NUMERIC_STRING
      Float(value)
    when NilClass
      UNSET
    else
      raise DecodeError, "expected float for #{inspect}, got #{value.inspect}"
    end
  end

  # Non-finite values become the strings "NaN", "Infinity" and "-Infinity".
  def json_encode_one(value, print_unknown_fields:) # rubocop:disable Lint/UnusedMethodArgument
    if value.nan?
      "NaN"
    elsif (sign = value.infinite?)
      if sign == -1
        "-Infinity"
      else
        "Infinity"
      end
    else
      value
    end
  end

  def default
    return [] if repeated?

    0.0
  end
end
732
+
733
# float: IEEE 754 binary32, packed with "e" (little-endian), wire type 5.
# Everything else is inherited from DoubleField.
class FloatField < DoubleField
  def wire_type
    5
  end

  def binary_pack
    "e"
  end

  def type
    :float
  end
end
738
+
739
# Field for group-typed members (presumably proto2 groups). The signature
# requires group_type:, but the value is not stored; all other arguments are
# forwarded unchanged to Field#initialize.
class GroupField < Field
  def initialize(*positional, group_type:, **options) # rubocop:disable Lint/UnusedMethodArgument
    super(*positional, **options)
  end
end
745
+ end
746
+ end