ratomic 0.2.0 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,17 +7,17 @@ mod sem;
7
7
 
8
8
  use counter::AtomicCounter;
9
9
  use fixed_size_object_pool::FixedSizeObjectPool;
10
- use hashmap::ConcurrentHashMap;
10
+ use hashmap::MapStore;
11
11
  use magnus::{
12
- data_type_builder, method, IntoValue,
12
+ data_type_builder, method,
13
13
  prelude::*,
14
14
  typed_data::{DataType, DataTypeFunctions},
15
15
  value::Lazy,
16
- Error, RClass, Ruby, TryConvert, TypedData, Value,
16
+ Error, IntoValue, RClass, Ruby, TryConvert, TypedData, Value,
17
17
  };
18
18
  use mpmc_queue::MpmcQueue;
19
- use rb_sys::{rb_ext_ractor_safe, rb_thread_call_without_gvl, ruby_special_consts, VALUE};
20
19
  use parking_lot::Mutex;
20
+ use rb_sys::{rb_ext_ractor_safe, rb_thread_call_without_gvl, ruby_special_consts, VALUE};
21
21
  use std::{ffi::c_void, mem::transmute};
22
22
 
23
23
  fn value_to_raw(value: Value) -> VALUE {
@@ -64,7 +64,8 @@ impl DataTypeFunctions for Counter {}
64
64
  unsafe impl TypedData for Counter {
65
65
  fn class(ruby: &Ruby) -> RClass {
66
66
  static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
67
- let class = ruby.define_module("Ratomic")
67
+ let class = ruby
68
+ .define_module("Ratomic")
68
69
  .unwrap()
69
70
  .define_class("Counter", ruby.class_object())
70
71
  .unwrap();
@@ -75,17 +76,18 @@ unsafe impl TypedData for Counter {
75
76
  }
76
77
 
77
78
  fn data_type() -> &'static DataType {
78
- static DATA_TYPE: DataType =
79
- data_type_builder!(Counter, "ratomic/counter").frozen_shareable().build();
79
+ static DATA_TYPE: DataType = data_type_builder!(Counter, "ratomic/counter")
80
+ .frozen_shareable()
81
+ .build();
80
82
  &DATA_TYPE
81
83
  }
82
84
  }
83
85
 
84
- struct HashMap(ConcurrentHashMap);
86
+ struct HashMap(MapStore);
85
87
 
86
88
  impl HashMap {
87
89
  fn new(ruby: &Ruby, class: RClass) -> Result<Value, Error> {
88
- let value = ruby.wrap_as(Self(ConcurrentHashMap::new()), class).as_value();
90
+ let value = ruby.wrap_as(Self(MapStore::new()), class).as_value();
89
91
  make_shareable(ruby, value)
90
92
  }
91
93
 
@@ -94,10 +96,19 @@ impl HashMap {
94
96
  unsafe { value_from_raw(raw) }.into_value_with(ruby)
95
97
  }
96
98
 
99
+ fn contains_key(&self, key: Value) -> bool {
100
+ self.0.contains_key(value_to_raw(key))
101
+ }
102
+
97
103
  fn set(&self, key: Value, value: Value) {
98
104
  self.0.set(value_to_raw(key), value_to_raw(value));
99
105
  }
100
106
 
107
+ fn delete(ruby: &Ruby, rb_self: &Self, key: Value) -> Value {
108
+ let raw = rb_self.0.delete(value_to_raw(key)).unwrap_or_else(qnil_raw);
109
+ unsafe { value_from_raw(raw) }.into_value_with(ruby)
110
+ }
111
+
101
112
  fn clear(&self) {
102
113
  self.0.clear();
103
114
  }
@@ -115,26 +126,130 @@ impl HashMap {
115
126
  }
116
127
 
117
128
  let proc = ruby.block_proc()?;
129
+ let mut error = None;
118
130
  rb_self.0.fetch_and_modify(value_to_raw(key), |value| {
119
- let result: Value = proc.call((unsafe { value_from_raw(value) },)).unwrap();
120
- value_to_raw(result)
131
+ match proc.call::<_, Value>((unsafe { value_from_raw(value) },)) {
132
+ Ok(result) => value_to_raw(result),
133
+ Err(err) => {
134
+ error = Some(err);
135
+ value
136
+ }
137
+ }
121
138
  });
122
- Ok(())
139
+
140
+ if let Some(err) = error {
141
+ Err(err)
142
+ } else {
143
+ Ok(())
144
+ }
145
+ }
146
+
147
+ fn compute(ruby: &Ruby, rb_self: &Self, key: Value) -> Result<Value, Error> {
148
+ if !ruby.block_given() {
149
+ return Err(Error::new(
150
+ ruby.exception_local_jump_error(),
151
+ "no block given",
152
+ ));
153
+ }
154
+
155
+ let proc = ruby.block_proc()?;
156
+ let raw = rb_self.0.compute(value_to_raw(key), qnil_raw(), |value| {
157
+ proc.call::<_, Value>((unsafe { value_from_raw(value) },))
158
+ .map(value_to_raw)
159
+ })?;
160
+
161
+ Ok(unsafe { value_from_raw(raw) }.into_value_with(ruby))
162
+ }
163
+
164
+ fn fetch_or_store(ruby: &Ruby, rb_self: &Self, key: Value) -> Result<Value, Error> {
165
+ if !ruby.block_given() {
166
+ return Err(Error::new(
167
+ ruby.exception_local_jump_error(),
168
+ "no block given",
169
+ ));
170
+ }
171
+
172
+ let proc = ruby.block_proc()?;
173
+ let raw = rb_self.0.fetch_or_store(value_to_raw(key), || {
174
+ proc.call::<_, Value>(()).map(value_to_raw)
175
+ })?;
176
+
177
+ Ok(unsafe { value_from_raw(raw) }.into_value_with(ruby))
178
+ }
179
+
180
+ fn upsert(ruby: &Ruby, rb_self: &Self, key: Value, initial: Value) -> Result<Value, Error> {
181
+ if !ruby.block_given() {
182
+ return Err(Error::new(
183
+ ruby.exception_local_jump_error(),
184
+ "no block given",
185
+ ));
186
+ }
187
+
188
+ let proc = ruby.block_proc()?;
189
+ let raw = rb_self
190
+ .0
191
+ .upsert(value_to_raw(key), value_to_raw(initial), |value| {
192
+ proc.call::<_, Value>((unsafe { value_from_raw(value) },))
193
+ .map(value_to_raw)
194
+ })?;
195
+
196
+ Ok(unsafe { value_from_raw(raw) }.into_value_with(ruby))
197
+ }
198
+
199
+ fn increment_numeric(
200
+ ruby: &Ruby,
201
+ rb_self: &Self,
202
+ key: Value,
203
+ by: Value,
204
+ ) -> Result<Value, Error> {
205
+ let key_inspect = key.funcall::<_, _, String>("inspect", ())?;
206
+ let numeric_class: RClass = ruby.class_object().const_get("Numeric")?;
207
+ let raw = rb_self.0.update(value_to_raw(key), |current| {
208
+ let next = match current {
209
+ Some(value) if value == qnil_raw() => {
210
+ return Err(Error::new(
211
+ ruby.exception_type_error(),
212
+ format!("existing value for {key_inspect} must be numeric: nil"),
213
+ ));
214
+ }
215
+ Some(value) => {
216
+ let old_value = unsafe { value_from_raw(value) };
217
+ if !old_value.funcall::<_, _, bool>("is_a?", (numeric_class,))? {
218
+ let old_value_inspect = old_value.funcall::<_, _, String>("inspect", ())?;
219
+ return Err(Error::new(
220
+ ruby.exception_type_error(),
221
+ format!(
222
+ "existing value for {key_inspect} must be numeric: {old_value_inspect}"
223
+ ),
224
+ ));
225
+ }
226
+
227
+ old_value.funcall::<_, _, Value>("+", (by,))?
228
+ }
229
+ None => by,
230
+ };
231
+
232
+ Ok(value_to_raw(next))
233
+ })?;
234
+
235
+ Ok(unsafe { value_from_raw(raw) }.into_value_with(ruby))
123
236
  }
124
237
  }
125
238
 
126
239
  impl DataTypeFunctions for HashMap {
127
240
  fn mark(&self, marker: &magnus::gc::Marker) {
128
- self.0.mark(|value| marker.mark(unsafe { value_from_raw(value) }));
241
+ self.0
242
+ .mark(|value| marker.mark(unsafe { value_from_raw(value) }));
129
243
  }
130
244
  }
131
245
 
132
246
  unsafe impl TypedData for HashMap {
133
247
  fn class(ruby: &Ruby) -> RClass {
134
248
  static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
135
- let class = ruby.define_module("Ratomic")
249
+ let class = ruby
250
+ .define_module("Ratomic")
136
251
  .unwrap()
137
- .define_class("ConcurrentHashMap", ruby.class_object())
252
+ .define_class("Map", ruby.class_object())
138
253
  .unwrap();
139
254
  class.undef_default_alloc_func();
140
255
  class
@@ -234,14 +349,16 @@ impl Queue {
234
349
 
235
350
  impl DataTypeFunctions for Queue {
236
351
  fn mark(&self, marker: &magnus::gc::Marker) {
237
- self.0.mark(|value| marker.mark(unsafe { value_from_raw(value) }));
352
+ self.0
353
+ .mark(|value| marker.mark(unsafe { value_from_raw(value) }));
238
354
  }
239
355
  }
240
356
 
241
357
  unsafe impl TypedData for Queue {
242
358
  fn class(ruby: &Ruby) -> RClass {
243
359
  static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
244
- let class = ruby.define_module("Ratomic")
360
+ let class = ruby
361
+ .define_module("Ratomic")
245
362
  .unwrap()
246
363
  .define_class("Queue", ruby.class_object())
247
364
  .unwrap();
@@ -267,7 +384,10 @@ impl Pool {
267
384
  if args.len() > 2 {
268
385
  return Err(Error::new(
269
386
  ruby.exception_arg_error(),
270
- format!("wrong number of arguments (given {}, expected 0..2)", args.len()),
387
+ format!(
388
+ "wrong number of arguments (given {}, expected 0..2)",
389
+ args.len()
390
+ ),
271
391
  ));
272
392
  }
273
393
  let size = args
@@ -285,7 +405,10 @@ impl Pool {
285
405
  .unwrap_or(1000);
286
406
 
287
407
  if size == 0 {
288
- return Err(Error::new(ruby.exception_arg_error(), "pool size must be positive"));
408
+ return Err(Error::new(
409
+ ruby.exception_arg_error(),
410
+ "pool size must be positive",
411
+ ));
289
412
  }
290
413
  if !ruby.block_given() {
291
414
  return Err(Error::new(
@@ -340,7 +463,8 @@ impl DataTypeFunctions for Pool {
340
463
  unsafe impl TypedData for Pool {
341
464
  fn class(ruby: &Ruby) -> RClass {
342
465
  static CLASS: Lazy<RClass> = Lazy::new(|ruby| {
343
- let class = ruby.define_module("Ratomic")
466
+ let class = ruby
467
+ .define_module("Ratomic")
344
468
  .unwrap()
345
469
  .define_class("FixedSizeObjectPool", ruby.class_object())
346
470
  .unwrap();
@@ -372,14 +496,23 @@ fn init(ruby: &Ruby) -> Result<(), Error> {
372
496
  counter.define_method("decrement", method!(Counter::decrement, 1))?;
373
497
  counter.define_method("read", method!(Counter::read, 0))?;
374
498
 
375
- let hashmap = root.define_class("ConcurrentHashMap", ruby.class_object())?;
499
+ let hashmap = root.define_class("Map", ruby.class_object())?;
376
500
  hashmap.undef_default_alloc_func();
377
501
  hashmap.define_singleton_method("new", method!(HashMap::new, 0))?;
378
502
  hashmap.define_method("get", method!(HashMap::get, 1))?;
503
+ hashmap.define_method("key?", method!(HashMap::contains_key, 1))?;
379
504
  hashmap.define_method("set", method!(HashMap::set, 2))?;
505
+ hashmap.define_method("delete", method!(HashMap::delete, 1))?;
380
506
  hashmap.define_method("clear", method!(HashMap::clear, 0))?;
381
507
  hashmap.define_method("size", method!(HashMap::size, 0))?;
382
508
  hashmap.define_method("fetch_and_modify", method!(HashMap::fetch_and_modify, 1))?;
509
+ hashmap.define_method("compute", method!(HashMap::compute, 1))?;
510
+ hashmap.define_method("fetch_or_store", method!(HashMap::fetch_or_store, 1))?;
511
+ hashmap.define_method("upsert", method!(HashMap::upsert, 2))?;
512
+ hashmap.define_private_method(
513
+ "__increment_numeric",
514
+ method!(HashMap::increment_numeric, 2),
515
+ )?;
383
516
 
384
517
  let queue = root.define_class("Queue", ruby.class_object())?;
385
518
  queue.undef_default_alloc_func();
@@ -15,7 +15,7 @@ unsafe impl Sync for QueueElement {}
15
15
 
16
16
  pub struct MpmcQueue {
17
17
  buffer: Vec<QueueElement>,
18
- buffer_mask: usize,
18
+ buffer_size: usize,
19
19
  enqueue_pos: AtomicUsize,
20
20
  dequeue_pos: AtomicUsize,
21
21
  gc_guard: GcGuard,
@@ -25,7 +25,7 @@ pub struct MpmcQueue {
25
25
 
26
26
  impl MpmcQueue {
27
27
  pub fn new(buffer_size: usize, default: VALUE) -> Self {
28
- let mut buffer = Vec::with_capacity(buffer_size.next_power_of_two());
28
+ let mut buffer = Vec::with_capacity(buffer_size);
29
29
  for i in 0..buffer_size {
30
30
  buffer.push(QueueElement {
31
31
  sequence: AtomicUsize::new(i),
@@ -42,7 +42,7 @@ impl MpmcQueue {
42
42
 
43
43
  Self {
44
44
  buffer,
45
- buffer_mask: buffer_size - 1,
45
+ buffer_size,
46
46
  enqueue_pos: AtomicUsize::new(0),
47
47
  dequeue_pos: AtomicUsize::new(0),
48
48
  gc_guard,
@@ -55,7 +55,7 @@ impl MpmcQueue {
55
55
  let mut cell;
56
56
  let mut pos = self.enqueue_pos.load(Ordering::Relaxed);
57
57
  loop {
58
- cell = &self.buffer[pos & self.buffer_mask];
58
+ cell = &self.buffer[pos % self.buffer_size];
59
59
  let seq = cell.sequence.load(Ordering::Acquire);
60
60
  let diff = seq as isize - pos as isize;
61
61
  if diff == 0 {
@@ -82,7 +82,7 @@ impl MpmcQueue {
82
82
  let mut cell;
83
83
  let mut pos = self.dequeue_pos.load(Ordering::Relaxed);
84
84
  loop {
85
- cell = &self.buffer[pos & self.buffer_mask];
85
+ cell = &self.buffer[pos % self.buffer_size];
86
86
  let seq = cell.sequence.load(Ordering::Acquire);
87
87
  let diff = seq as isize - (pos + 1) as isize;
88
88
  if diff == 0 {
@@ -102,7 +102,7 @@ impl MpmcQueue {
102
102
 
103
103
  let data = cell.data.get();
104
104
  cell.sequence
105
- .store(pos + self.buffer_mask + 1, Ordering::Release);
105
+ .store(pos + self.buffer_size, Ordering::Release);
106
106
  self.write_sem.post();
107
107
 
108
108
  #[cfg(feature = "simulation")]
@@ -131,7 +131,7 @@ impl MpmcQueue {
131
131
 
132
132
  pub fn peek(&self) -> Option<VALUE> {
133
133
  let pos = self.dequeue_pos.load(Ordering::Relaxed);
134
- let cell = &self.buffer[pos & self.buffer_mask];
134
+ let cell = &self.buffer[pos % self.buffer_size];
135
135
  let seq = cell.sequence.load(Ordering::Acquire);
136
136
  let diff = seq as isize - (pos + 1) as isize;
137
137
  if diff == 0 {
@@ -1,8 +1,38 @@
1
1
  # frozen_string_literal: true
2
2
 
3
3
  module Ratomic
4
- # Ruby convenience methods for {Counter}.
5
- module CounterMethods
4
+ # A Ractor-shareable atomic counter.
5
+ #
6
+ # Counter stores an unsigned integer in native Rust atomics and can be shared
7
+ # safely across Ractors.
8
+ #
9
+ # @example Count work across Ractors
10
+ # counter = Ratomic::Counter.new
11
+ # counter.increment(1)
12
+ # counter.read # => 1
13
+ #
14
+ # @!method self.new
15
+ # Create a counter initialized to zero.
16
+ #
17
+ # @return [Ratomic::Counter]
18
+ #
19
+ # @!method read
20
+ # Read the current counter value.
21
+ #
22
+ # @return [Integer]
23
+ #
24
+ # @!method increment(amt)
25
+ # Increment the counter by +amt+.
26
+ #
27
+ # @param amt [Integer]
28
+ # @return [void]
29
+ #
30
+ # @!method decrement(amt)
31
+ # Decrement the counter by +amt+.
32
+ #
33
+ # @param amt [Integer]
34
+ # @return [void]
35
+ class Counter
6
36
  # Read the current counter value.
7
37
  #
8
38
  # @return [Integer]
@@ -23,29 +53,5 @@ module Ratomic
23
53
  def zero?
24
54
  read.zero?
25
55
  end
26
-
27
- # Increment the counter.
28
- #
29
- # @param amt [Integer] amount to add
30
- # @raise [ArgumentError] if +amt+ is negative
31
- # @return [void]
32
- def inc(amt = 1)
33
- raise ArgumentError, "amount must be positive: #{amt}" if amt.negative?
34
-
35
- increment(amt)
36
- end
37
-
38
- # Decrement the counter.
39
- #
40
- # @param amt [Integer] amount to subtract
41
- # @raise [ArgumentError] if +amt+ is negative
42
- # @return [void]
43
- def dec(amt = 1)
44
- raise ArgumentError, "amount must be positive: #{amt}" if amt.negative?
45
-
46
- decrement(amt)
47
- end
48
56
  end
49
-
50
- Counter.prepend(CounterMethods)
51
57
  end