protobug 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,746 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "binary_encoding"
4
+
5
+ module Protobug
6
+ class Field
7
# Descriptor attributes. Besides the declared metadata, the names of the
# per-field methods a message class uses (setter, adder, ivar, predicate,
# clearer) are cached here as symbols.
attr_accessor :number, :name, :json_name, :cardinality, :oneof,
              :ivar, :setter, :adder, :haser, :clearer

# Builds a field descriptor.
#
# @param number field number; shifted into the tag as (number << 3) when encoding
# @param name [String, Symbol] field name; stored as a Symbol and used to derive
#   the accessor method names
# @param json_name JSON key; falls back to the field name as a String
# @param cardinality [Symbol] :optional or :repeated (other values are neither); must not be nil
# @param oneof name of the enclosing oneof, or nil
# @param packed whether repeated values are written in packed form
# @param proto3_optional explicit-presence flag; defaults to true exactly when
#   cardinality is :optional
# @raise [ArgumentError] when cardinality is nil/false
def initialize(number, name, json_name: nil, cardinality: :optional, oneof: nil, packed: false,
               proto3_optional: cardinality == :optional)
  @number = number
  @name = name.to_sym
  @json_name = json_name || name.to_s
  raise ArgumentError, "cardinality is required" unless cardinality

  @cardinality = cardinality
  @oneof = oneof
  @setter = "#{name}=".to_sym
  # only repeated fields get an `add_<name>` method
  @adder = "add_#{name}".to_sym if repeated?
  @ivar = "@#{name}".to_sym
  @clearer = "clear_#{name}".to_sym
  @haser = "#{name}?".to_sym
  @packed = packed
  @proto3_optional = proto3_optional
end
25
+
26
# Pretty-prints this field as a constructor-like expression, for example
#   Protobug::Field::Int32Field.new(1, :foo, cardinality: :optional)
# json_name is shown only when it differs from the field name, and oneof only
# when it is set.
#
# @param pp [PP] the pretty printer driving the output
def pretty_print(pp)
  pp.group 0, "#{self.class}.new(", ")" do
    pp.text @number.to_s
    pp.breakable(", ")
    pp.text(@name.inspect)
    # Every argument after the name writes its own leading separator, so no
    # separator is emitted here (doing so produced ", , ").
    if json_name != name.name
      pp.breakable(", ")
      pp.text("json_name: ")
      pp.text(@json_name.inspect)
    end
    pp.breakable(", ")
    pp.text("cardinality: ")
    pp.pp(@cardinality)
    if oneof
      pp.breakable(", ")
      pp.text("oneof: ")
      pp.text(@oneof.inspect)
    end
  end
end
47
+
48
# True when the field holds a list of values.
def repeated? = cardinality == :repeated

# Whether repeated values are written in packed (single length-delimited) form.
def packed? = @packed

# True for fields declared with cardinality :optional.
def optional? = cardinality == :optional

# The explicit-presence flag passed to the constructor.
def proto3_optional? = @proto3_optional
63
+
64
# Defines the `add_<name>` method on +message+. The generated method validates
# the value, swaps the UNSET sentinel in the backing ivar for a fresh default
# collection on first use, and appends the value to that collection.
#
# @param message [Class] the message class receiving the method
def define_adder(message)
  descriptor = self
  message.define_method(adder) do |item|
    descriptor.validate!(item, self)

    list = instance_variable_get(descriptor.ivar)
    list = instance_variable_set(descriptor.ivar, descriptor.default) if UNSET == list
    list << item
  end
end
78
+
79
# Renders +value+ as protobuf text-format lines for this field.
# Scalars become `name: <text>`; messages become a `name { ... }` block whose
# body is the nested to_text output with each line prefixed by one space.
# Returns nil for cardinalities other than :repeated/:optional, or when
# json_scalar? returns something other than true/false.
# NOTE(review): json_scalar? and scalar_to_text are not defined in this file —
# confirm they are provided elsewhere, otherwise this raises NoMethodError.
def to_text(value)
  kind = cardinality
  scalar = json_scalar?
  render =
    if true == scalar
      ->(item) { "#{name}: #{scalar_to_text(item)}" }
    elsif false == scalar
      ->(item) { "#{name} {\n#{item.to_text.gsub(/^/, " ")}\n}" }
    end
  return nil if render.nil?

  case kind
  when :repeated then Array(value).map(&render).join("\n")
  when :optional then render.call(value)
  end
end
91
+
92
# Appends this field's wire encoding of +value+ to +outbuf+.
# Repeated fields are written packed, or as one tag + value per element.
# Singular fields equal to their default are skipped (returns nil) unless they
# carry explicit presence (optional + proto3_optional) or belong to a oneof.
def binary_encode(value, outbuf)
  if repeated?
    return binary_encode_packed(value, outbuf) if packed?

    value.each do |element|
      BinaryEncoding.encode_varint((number << 3) | wire_type, outbuf)
      binary_encode_one(element, outbuf)
    end
  else
    explicit_presence = optional? && proto3_optional?
    return if !explicit_presence && !oneof && default == value

    BinaryEncoding.encode_varint((number << 3) | wire_type, outbuf)
    binary_encode_one(value, outbuf)
  end
end
109
+
110
# Decodes one wire occurrence of this field from +binary+ into +message+.
# A length-delimited occurrence (wire type 2) of a repeated varint/fixed field
# is read as a packed run and every element is appended through the adder;
# otherwise the wire type must match this field's own.
#
# @raise [DecodeError] when the wire type does not match
def binary_decode(binary, message, registry, wire_type)
  own_wire_type = self.wire_type
  packed_run = repeated? && wire_type == 2 && [0, 1, 5].include?(own_wire_type)

  if packed_run
    payload = StringIO.new(BinaryEncoding.decode_length(binary))
    payload.binmode
    until payload.eof?
      message.send(adder, binary_decode_one(payload, message, registry, own_wire_type))
    end
  elsif wire_type == own_wire_type
    message.send(adder || setter, binary_decode_one(binary, message, registry, wire_type))
  else
    raise DecodeError, "wrong wire type for #{self}: #{wire_type.inspect}"
  end
end
122
+
123
# Converts +value+ to its JSON representation.
# Repeated fields map every element through json_encode_one. A singular field
# equal to its default, with neither explicit presence nor oneof membership,
# returns nil to signal that it should be omitted.
def json_encode(value, print_unknown_fields:)
  if repeated?
    return value.map { |element| json_encode_one(element, print_unknown_fields: print_unknown_fields) }
  end

  has_presence = optional? && proto3_optional?
  return nil if !has_presence && !oneof && default == value

  json_encode_one(value, print_unknown_fields: print_unknown_fields)
end
132
+
133
# Converts a map key to the String form required for JSON object keys.
# Strings pass through unchanged; numbers and booleans are stringified.
#
# @raise [EncodeError] for any other key type
def json_key_encode(value)
  case value
  when String then value
  when Integer, Float, true, false then value.to_s
  else raise EncodeError, "unexpected type for map key: #{value.inspect}"
  end
end
147
+
148
# Applies a decoded JSON +value+ to +message+.
# Repeated fields: nil is ignored, an Array is decoded element-wise through the
# adder, and anything else is rejected. Singular fields: the decoded value is
# set unless json_decode_one reports UNSET (JSON null).
#
# @raise [DecodeError] when a repeated field receives a non-Array
def json_decode(value, message, ignore_unknown_fields, registry)
  unless repeated?
    decoded = json_decode_one(value, ignore_unknown_fields, registry)
    return UNSET == decoded ? nil : message.send(setter, decoded)
  end

  return if value.nil?
  raise DecodeError, "expected Array for #{inspect}, got #{value.inspect}" unless value.is_a?(Array)

  value.map { |element| message.send(adder, json_decode_one(element, ignore_unknown_fields, registry)) }
end
165
+
166
# Checks a value about to be stored in this field on +message+. For oneof
# members, it then clears every other field of the same oneof so that only one
# member stays set.
#
# @raise [DecodeError] when +value+ is the UNSET sentinel
def validate!(value, message)
  raise DecodeError, "nil is invalid for #{name} in #{message}" if UNSET == value
  return unless oneof

  message.class.oneofs[oneof].each do |sibling|
    message.send(sibling.clearer) unless sibling == self
  end
end
177
+
178
private

# Writes +value+ (an array of scalars) as a single packed record: a tag with
# wire type 2 followed by the length-prefixed concatenation of the element
# encodings. Nothing is written for an empty array. Per the protobuf encoding
# spec, a packed repeated field with zero elements does not appear in the
# encoded message. The payload is built before the tag is written, so a
# failing element encoder leaves +outbuf+ untouched.
def binary_encode_packed(value, outbuf)
  return if value.empty?

  payload = value.each_with_object("".b) { |element, buf| binary_encode_one(element, buf) }
  BinaryEncoding.encode_varint((number << 3) | 2, outbuf)
  BinaryEncoding.encode_length(payload, outbuf)
end
187
+
188
# Field holding an embedded message. Values travel as length-delimited nested
# messages (wire type 2); the message class is resolved via the registry using
# message_type as the key.
class MessageField < Field
  attr_reader :message_type

  # @param message_type registry key of the embedded message class
  def initialize(number, name, cardinality:, message_type:, json_name: name, oneof: nil,
                 proto3_optional: cardinality == :optional)
    super(number, name, json_name: json_name, cardinality: cardinality, oneof: oneof,
          proto3_optional: proto3_optional)
    @message_type = message_type
  end

  def wire_type = 2

  # Serializes +value+ with its own class's encoder and writes it length-prefixed.
  def binary_encode_one(value, outbuf)
    encoded = value.class.encode(value)
    BinaryEncoding.encode_length(encoded, outbuf)
  end

  # Decodes one embedded message. For a singular field whose message already
  # has a value, that existing object is passed as object: — presumably so
  # the new bytes are merged into it; confirm in the message decoder.
  def binary_decode_one(io, message, registry, wire_type)
    payload = StringIO.new(BinaryEncoding.read_field_value(io, wire_type))
    options = { registry: registry }
    options[:object] = message.send(name) if !repeated? && message.send(haser)
    type_lookup(registry).decode(payload, **options)
  end

  def json_decode_one(value, ignore_unknown_fields, registry)
    type_lookup(registry).decode_json_hash(value, registry: registry, ignore_unknown_fields: ignore_unknown_fields)
  end

  def json_encode_one(value, print_unknown_fields:)
    value.as_json(print_unknown_fields: print_unknown_fields)
  end

  # Resolves the embedded message class.
  def type_lookup(registry) = registry.fetch(message_type)

  # [] for repeated fields, nil otherwise.
  def default
    # TODO: message_type.default
    repeated? ? [] : nil
  end
end
231
+
232
# map<K, V> field. On the wire it is a repeated field of synthesized entry
# messages (key = field 1, value = field 2); in Ruby the value is a Hash and in
# JSON an object.
class MapField < MessageField
  # Field#initialize (the method MessageField#initialize overrides), captured
  # so MapField can skip MessageField's required message_type: argument.
  SUPER_INITIALIZE = instance_method(:initialize).super_method

  # @param key_type type: passed to the entry class's key declaration
  # @param value_type type: passed to the entry class's value declaration
  # @param enum_type forwarded to the value declaration when non-nil
  # @param message_type forwarded to the value declaration when non-nil
  def initialize(number, name, key_type:, value_type:, json_name: name, oneof: nil, # rubocop:disable Lint/MissingSuper
                 enum_type: nil, message_type: nil)
    SUPER_INITIALIZE.bind_call(
      self, number, name,
      cardinality: :repeated,
      json_name: json_name,
      oneof: oneof
    )

    # Anonymous message class describing one map entry.
    @map_class = Class.new do
      extend Protobug::Message

      optional(1, "key", type: key_type, proto3_optional: false)
      value_type_kwargs = { enum_type: enum_type, message_type: message_type }
      value_type_kwargs.compact!
      optional(2, "value", type: value_type, **value_type_kwargs, proto3_optional: false)
    end
  end

  def repeated = true
  def default = {}
  def repeated? = true

  # Writes each key/value pair as its own tagged, length-delimited entry
  # message, reusing a single entry instance.
  def binary_encode(value, outbuf)
    value.each_with_object(@map_class.new) do |(k, v), entry|
      entry.key = k
      entry.value = v
      BinaryEncoding.encode_varint((number << 3) | wire_type, outbuf)
      BinaryEncoding.encode_length(@map_class.encode(entry), outbuf)
    end
  end

  # Encodes the map as a JSON object. Keys go through json_key_encode. Values
  # go through the value field's json_encode_one rather than json_encode:
  # json_encode returns nil for default values (0, "", false, ...), which
  # turned such entries into JSON null, which #json_decode rejects.
  def json_encode(value, print_unknown_fields:)
    value_field = @map_class.fields_by_name["value"]
    value.to_h do |key, entry_value|
      [json_key_encode(key), value_field.json_encode_one(entry_value, print_unknown_fields: print_unknown_fields)]
    end
  end

  # Decodes a JSON object into map entries on +message+; nil is ignored.
  # @raise [DecodeError] for non-Hash input or null values (except unknown
  #   enum values when ignore_unknown_fields is set, which are skipped)
  def json_decode(value, message, ignore_unknown_fields, registry)
    return if value.nil?

    unless value.is_a?(Hash)
      raise DecodeError,
            "expected Hash for #{inspect}, got #{value.inspect}"
    end

    value.each do |k, v|
      entry = @map_class.decode_json_hash(
        { "key" => k, "value" => v },
        registry: registry,
        ignore_unknown_fields: ignore_unknown_fields
      )
      # can't use haser because default values should also be counted...
      if UNSET == entry.instance_variable_get(:@value)
        next if ignore_unknown_fields && @map_class.fields_by_name.fetch("value").is_a?(EnumField)

        raise DecodeError, "nil values are not allowed in map #{name} in #{message.class}"
      end

      message.send(adder, entry)
    end
  end

  def type_lookup(_registry) = @map_class

  # Defines `add_<name>`, which stores a decoded entry's key/value in the
  # message's Hash, creating the Hash on first use.
  def define_adder(message)
    field = self
    message.define_method(adder) do |msg|
      existing = instance_variable_get(field.ivar)
      if UNSET == existing
        existing = field.default
        instance_variable_set(field.ivar, existing)
      end

      existing[msg.key] = msg.value
    end
  end
end
313
+
314
# Raw bytes field: length-delimited on the wire (wire type 2), base64 in JSON.
class BytesField < Field
  def self.type = :bytes

  def initialize(number, name, cardinality:, json_name: name, oneof: nil,
                 proto3_optional: cardinality == :optional)
    super(number, name, json_name: json_name, cardinality: cardinality, oneof: oneof,
          proto3_optional: proto3_optional)
  end

  def binary_encode_one(value, outbuf)
    BinaryEncoding.encode_length value.b, outbuf
  end

  def binary_decode_one(io, _message, _registry, wire_type)
    BinaryEncoding.read_field_value(io, wire_type)
  end

  # Decodes a base64 string (standard or URL-safe alphabet) into binary bytes.
  # Returns UNSET for JSON null.
  #
  # @raise [DecodeError] when +value+ is not a String or cannot be decoded
  def json_decode_one(value, _ignore_unknown_fields, _registry)
    return UNSET if value.nil?
    raise DecodeError, "expected string, got #{value.inspect}" unless value.is_a?(String)

    # Map the URL-safe alphabet onto the standard one without mutating the
    # caller's string, which may be frozen or shared with other data.
    normalized = value.tr("-_", "+/")
    begin
      # NOTE(review): the non-strict "m" directive is lenient about invalid
      # characters, so this rescue may never fire; "m0" is strict but also
      # requires padding, which JSON input may omit.
      normalized.unpack1("m").force_encoding(Encoding::BINARY)
    rescue ArgumentError => e
      raise DecodeError, "Invalid URL-encoded base64 #{normalized.inspect} for #{inspect}: #{e}"
    end
  end

  # Encodes bytes as padded standard base64 with no line breaks.
  def json_encode_one(value, print_unknown_fields:) # rubocop:disable Lint/UnusedMethodArgument
    [value].pack("m0")
  end

  def default
    return [] if repeated?

    "".b
  end

  def wire_type = 2
end
357
+
358
# UTF-8 string field. Reuses BytesField's length-delimited wire handling but
# enforces valid UTF-8, and uses plain JSON strings instead of base64.
class StringField < BytesField
  def self.type = :string

  def initialize(number, name, cardinality:, json_name: name, oneof: nil,
                 proto3_optional: cardinality == :optional)
    super(number, name, json_name: json_name, cardinality: cardinality, oneof: oneof,
          proto3_optional: proto3_optional)
  end

  # Transcodes non-UTF-8 strings to UTF-8 before writing their bytes.
  def binary_encode_one(value, outbuf)
    value = value.encode("utf-8") if value.encoding != Encoding::UTF_8
    super(value, outbuf)
  end

  # Reads the bytes, tags them as UTF-8, and rejects invalid sequences.
  # @raise [DecodeError] on invalid UTF-8
  def binary_decode_one(io, _message, _registry, wire_type)
    value = super

    value.force_encoding("utf-8") if value.encoding != Encoding::UTF_8
    raise DecodeError, "invalid utf-8 for string" unless value.valid_encoding?

    value
  end

  # Returns UNSET for JSON null, otherwise the string as valid UTF-8.
  # @raise [DecodeError] for non-strings or invalid UTF-8
  def json_decode_one(value, _ignore_unknown_fields, _registry)
    return UNSET if value.nil?
    raise DecodeError, "expected string, got #{value.inspect}" unless value.is_a?(String)

    # Re-tag a copy so the caller's (possibly frozen or shared) string is not mutated.
    value = value.dup.force_encoding(Encoding::UTF_8) if value.encoding != Encoding::UTF_8
    raise DecodeError, "invalid utf-8 for string" unless value.valid_encoding?

    value
  end

  def json_encode_one(value, print_unknown_fields:) # rubocop:disable Lint/UnusedMethodArgument
    value.encode("utf-8")
  end

  def default
    return [] if repeated?

    +""
  end
end
401
+
402
# Base class for integer-valued fields. Subclasses supply encoding (:varint,
# :zigzag or :fixed), bit_length, signed, wire_type and, for :fixed,
# binary_pack.
class IntegerField < Field
  # 0 for singular fields, [] for repeated ones.
  def default
    return [] if repeated?

    0
  end

  # Reads one value from +io+ and converts it to a Ruby Integer.
  def binary_decode_one(io, _message, _registry, wire_type)
    value = BinaryEncoding.read_field_value(io, wire_type)
    case encoding
    when :zigzag
      BinaryEncoding.decode_zigzag bit_length, value
    when :varint
      # Keep only the low bit_length bits. Protobuf writes negative int32
      # values as sign-extended 64-bit varints, and uint32 must ignore any
      # higher bits.
      truncated = value & ((1 << bit_length) - 1)
      if signed && truncated[bit_length - 1] == 1
        # two's complement: the top bit carries weight -2**(bit_length - 1)
        truncated - (1 << bit_length)
      else
        truncated
      end
    when :fixed
      value.unpack1(binary_pack)
    end
  end

  def binary_encode_one(value, outbuf)
    case encoding
    when :zigzag
      BinaryEncoding.encode_zigzag bit_length, value, outbuf
    when :varint
      BinaryEncoding.encode_varint value, outbuf
    when :fixed
      [value].pack(binary_pack, buffer: outbuf)
    end
  end

  # Accepts JSON integers, decimal integer strings, and finite floats with no
  # fractional part. Returns UNSET for JSON null.
  #
  # @raise [DecodeError] for any other input, or values wider than 64 bits
  def json_decode_one(value, _ignore_unknown_fields, _registry)
    return UNSET if value.nil?

    integer =
      case value
      when Integer
        value
      when /\A-?\d+\z/
        # Base 10 explicitly: Integer("010") would parse as octal 8 and
        # Integer("09") would raise ArgumentError instead of DecodeError.
        Integer(value, 10)
      when Float
        unless value.finite? && (value % 1).zero?
          raise DecodeError, "expected integer for #{inspect}, got #{value.inspect}"
        end

        value.to_i
      else
        raise DecodeError, "expected integer for #{inspect}, got #{value.inspect}"
      end
    raise DecodeError, "#{integer.inspect} does not fit in 64 bits" if integer.bit_length > 64

    integer
  end

  # 64-bit integers are emitted as JSON strings; narrower ones as numbers.
  def json_encode_one(value, print_unknown_fields:) # rubocop:disable Lint/UnusedMethodArgument
    if bit_length >= 64
      value.to_s
    else
      value
    end
  end

  # Requires an Integer within this type's range, then runs the generic checks.
  # @raise [InvalidValueError] for non-integers or out-of-range values
  def validate!(value, message)
    raise InvalidValueError.new(message, self, value, "expected integer") unless value.is_a?(Integer)

    if signed
      min = -2**(bit_length - 1)
      max = 2**(bit_length - 1)
    else
      min = 0
      max = 2**bit_length
    end

    if value < min || value >= max
      raise InvalidValueError.new(message, self, value, "does not fit into [#{min}, #{max})")
    end

    super
  end
end
491
+
492
# Concrete integer field types. Each one is described by:
#   encoding:   :fixed, :varint or :zigzag
#   bit_length: 32 or 64
#   signed:     true or false
# There is no unsigned zigzag type. wire_type is 0 for varint/zigzag, 1 for
# 64-bit fixed, and 5 for 32-bit fixed. Fixed-width pack directives are
# explicitly little-endian ("<" suffix or "V"), as the protobuf wire format
# requires, regardless of host byte order.
class Int64Field < IntegerField
  def encoding = :varint
  def bit_length = 64
  def signed = true
  def wire_type = 0
end

class UInt64Field < IntegerField
  def encoding = :varint
  def bit_length = 64
  def signed = false
  def wire_type = 0
end

class SInt64Field < IntegerField
  def encoding = :zigzag
  def bit_length = 64
  def signed = true
  def wire_type = 0
end

class Fixed64Field < IntegerField
  def encoding = :fixed
  def bit_length = 64
  def signed = false
  def wire_type = 1
  def binary_pack = "Q<"
end

class SFixed64Field < IntegerField
  def encoding = :fixed
  def bit_length = 64
  def signed = true
  def wire_type = 1
  def binary_pack = "q<"
end

class Int32Field < IntegerField
  def encoding = :varint
  def bit_length = 32
  def signed = true
  def wire_type = 0
end

class UInt32Field < IntegerField
  def encoding = :varint
  def bit_length = 32
  def signed = false
  def wire_type = 0
end

class SInt32Field < IntegerField
  def encoding = :zigzag
  def bit_length = 32
  def signed = true
  def wire_type = 0
end

class Fixed32Field < IntegerField
  def encoding = :fixed
  def bit_length = 32
  def signed = false
  def wire_type = 5
  def binary_pack = "V"
end

class SFixed32Field < IntegerField
  def encoding = :fixed
  def bit_length = 32
  def signed = true
  def wire_type = 5
  def binary_pack = "l<"
end
569
+
570
# Boolean field, carried on the wire as a uint64 varint of 0 or 1.
class BoolField < UInt64Field
  # Any non-zero varint decodes as true.
  def binary_decode_one(*args)
    decoded = super(*args)
    decoded != 0
  end

  def binary_encode_one(value, outbuf)
    flag = value ? 1 : 0
    super(flag, outbuf)
  end

  # Accepts JSON booleans and the strings "true"/"false" (the form map keys
  # take, see json_key_encode); returns UNSET for JSON null.
  # @raise [DecodeError] for anything else
  def json_decode_one(value, _ignore_unknown_fields, _registry)
    case value
    when true, false then value
    when "true" then true
    when "false" then false
    when nil then UNSET
    else raise DecodeError, "expected boolean, got #{value.inspect}"
    end
  end

  # Only true/false are accepted; the remaining checks run on 1/0 through the
  # integer validation chain.
  def validate!(value, message)
    raise "expected boolean, got #{value.inspect}" unless value.equal?(true) || value.equal?(false)

    super(value ? 1 : 0, message)
  end

  def default = repeated? ? [] : false
end
606
+
607
# Enum field: an int32 on the wire. Decoded values are converted through the
# enum class looked up in the registry by enum_type.
class EnumField < Int32Field
  attr_reader :enum_type

  def initialize(number, name, cardinality:, enum_type:, json_name: name, oneof: nil, packed: false,
                 proto3_optional: cardinality == :optional)
    super(number, name, json_name: json_name, cardinality: cardinality, oneof: oneof,
          proto3_optional: proto3_optional, packed: packed)
    @enum_type = enum_type
  end

  # With ignore_unknown_fields, values that json_decode_one reports as UNSET
  # are skipped (for repeated fields) or left unset, instead of being stored.
  def json_decode(value, message, ignore_unknown_fields, registry)
    return super unless ignore_unknown_fields

    if repeated?
      return if value.nil?

      unless value.is_a?(Array)
        raise DecodeError,
              "expected Array for #{inspect}, got #{value.inspect}"
      end

      value.map do |v|
        v = json_decode_one(v, ignore_unknown_fields, registry)
        next if UNSET == v

        message.send(adder, v)
      end.tap(&:compact!)
    else
      value = json_decode_one(value, ignore_unknown_fields, registry)
      message.send(setter, value) unless UNSET == value
    end
  end

  # Writes the numeric enum value. Raw Integers are accepted as-is, since
  # #validate! accepts them too; enum instances contribute their #value.
  def binary_encode_one(value, outbuf)
    super(value.is_a?(Integer) ? value : value.value, outbuf)
  end

  def binary_decode_one(io, _message, registry, wire_type)
    value = super
    registry.fetch(enum_type).decode(value)
  end

  def json_decode_one(value, ignore_unknown_fields, registry)
    klass = registry.fetch(enum_type)
    klass.decode_json_hash(value, registry: registry, ignore_unknown_fields: ignore_unknown_fields)
  end

  # Enum instances render via #as_json; raw Integers are emitted as numbers,
  # which proto3 JSON parsers accept for enums.
  def json_encode_one(value, print_unknown_fields:) # rubocop:disable Lint/UnusedMethodArgument
    value.is_a?(Integer) ? value : value.as_json
  end

  def default
    return [] if repeated?

    # TODO: enum_type.default
    0
  end

  # Unwraps enum instances to their integer value before the int32 range check.
  def validate!(value, message)
    value = value.value if value.is_a?(Enum::InstanceMethods)
    super
  end
end
670
+
671
# 64-bit IEEE-754 double field: little-endian ("E") on the wire, wire type 1.
class DoubleField < Field
  # Quoted JSON numbers accepted by json_decode_one: optional sign, digits,
  # optional fraction, optional exponent (e.g. "-2", "1.5", "1e10").
  QUOTED_NUMBER = /\A-?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?\z/

  def type = :double
  def binary_pack = "E"
  def wire_type = 1

  def initialize(number, name, cardinality:, json_name: name, oneof: nil, packed: false,
                 proto3_optional: cardinality == :optional)
    super(number, name, json_name: json_name, cardinality: cardinality, oneof: oneof,
          proto3_optional: proto3_optional, packed: packed)
  end

  def binary_encode_one(value, outbuf)
    [value].pack(binary_pack, buffer: outbuf)
  end

  def binary_decode_one(io, _message, _registry, wire_type)
    value = BinaryEncoding.read_field_value(io, wire_type)
    value.unpack1(binary_pack)
  end

  # Accepts JSON numbers, the special strings "NaN"/"Infinity"/"-Infinity",
  # and quoted decimal numbers. Returns UNSET for JSON null.
  # @raise [DecodeError] for anything else
  def json_decode_one(value, _ignore_unknown_fields, _registry)
    case value
    when Float
      value
    when Integer
      value.to_f
    when "Infinity"
      Float::INFINITY
    when "-Infinity"
      -Float::INFINITY
    when "NaN"
      Float::NAN
    when QUOTED_NUMBER
      Float(value)
    when NilClass
      UNSET
    else
      raise DecodeError, "expected float for #{inspect}, got #{value.inspect}"
    end
  end

  # NaN and infinities become their JSON string forms. Non-Float numerics
  # (e.g. an Integer, which Field#validate! does not reject) are returned as-is.
  def json_encode_one(value, print_unknown_fields:) # rubocop:disable Lint/UnusedMethodArgument
    return value unless value.is_a?(Float)

    if value.nan?
      "NaN"
    elsif (sign = value.infinite?)
      sign == -1 ? "-Infinity" : "Infinity"
    else
      value
    end
  end

  def default
    return [] if repeated?

    0.0
  end
end
732
+
733
# 32-bit IEEE-754 float field: little-endian single precision ("e") on the
# wire, wire type 5. All JSON handling is inherited from DoubleField.
class FloatField < DoubleField
  def wire_type = 5
  def binary_pack = "e"
  def type = :float
end
738
+
739
# Proto2 group field. group_type is required by the signature but not stored;
# all other arguments are forwarded to Field.
# NOTE(review): no wire_type or *_one codec methods are defined here, so as far
# as this file shows, encoding or decoding a group field would fail — confirm
# whether groups are intentionally unsupported.
class GroupField < Field
  def initialize(*args, group_type:, **kwargs)
    _group_type = group_type # accepted but currently unused
    super(*args, **kwargs)
  end
end
745
+ end
746
+ end