Files
bit-set/src/lib.rs
2018-01-16 10:38:00 +01:00

712 lines
16 KiB
Rust

extern crate rayon;
extern crate serde;
extern crate unreachable;
use std::fmt;
use std::ops;
use std::iter::{FromIterator, IntoIterator, Iterator};
use std::default::Default;
use std::collections::{HashMap, LinkedList};
use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelExtend, ParallelIterator};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde::ser::SerializeSeq;
use serde::de::{SeqAccess, Visitor};
mod hasher;
use hasher::BitBuildHasher;
/// Number of bits held by one storage [`Block`] (64 for a `u64` word).
const BITS: u64 = 8 * std::mem::size_of::<Block>() as u64;
/// One bitmap word: bit `i` of the block stored under key `k` represents the value `k * BITS + i`.
type Block = u64;
/// Sparse backing store mapping a block key (`value / BITS`) to its bitmap word.
type Storage = HashMap<u64, Block, BitBuildHasher>;
/// A `HashMap` keyed by `u64` that uses the same hasher as the `BitSet` storage.
pub type BitHashMap<V> = HashMap<u64, V, BitBuildHasher>;
/// Splits `x` into a `(quotient, remainder)` pair with respect to `d`.
///
/// Called with `d == BITS`, this yields the block key and the bit offset
/// inside that block. Panics if `d` is zero (integer division by zero).
#[inline]
fn block_bit(x: &u64, d: &u64) -> (u64, u64) {
    let (value, divisor) = (*x, *d);
    let quotient = value / divisor;
    let remainder = value % divisor;
    (quotient, remainder)
}
#[derive(PartialEq, Eq, Clone)]
pub struct BitSet {
blocks: Storage,
nbits: usize,
}
impl BitSet {
#[inline]
pub fn new() -> BitSet {
BitSet {
blocks: Storage::default(),
nbits: 0,
}
}
#[inline]
pub fn with_capacity(capacity: usize) -> BitSet {
BitSet {
blocks: Storage::with_capacity_and_hasher(capacity / BITS as usize, Default::default()),
nbits: 0,
}
}
#[inline]
pub fn capacity(&self) -> usize {
self.blocks.capacity() * BITS as usize
}
pub fn reserve(&mut self, additional: usize) {
self.blocks.reserve(additional / BITS as usize)
}
pub fn shrink_to_fit(&mut self) {
self.blocks.retain(|_, block| *block != 0);
self.blocks.shrink_to_fit()
}
pub fn iter(&self) -> Iter {
Iter {
iter: self.blocks.iter(),
block: 0,
bits: 0,
bit: BITS,
}
}
#[inline]
pub fn len(&self) -> usize {
self.nbits
}
#[inline]
pub fn is_empty(&self) -> bool {
self.nbits == 0
}
pub fn clear(&mut self) {
self.blocks.clear();
self.nbits = 0;
}
#[inline]
pub fn contains(&self, value: &u64) -> bool {
let (block, bit) = block_bit(value, &BITS);
match self.blocks.get(&block) {
Some(block) => (block & (1 << bit)) != 0,
None => false,
}
}
#[inline]
pub fn is_subset(&self, other: &BitSet) -> bool {
if self.nbits > other.nbits {
false
} else {
self.blocks
.iter()
.all(|(key, block_a)| match other.blocks.get(key) {
Some(block_b) => block_a | block_b == *block_b,
None => *block_a == 0,
})
}
}
#[inline]
pub fn is_superset(&self, other: &BitSet) -> bool {
other.is_subset(self)
}
#[inline]
pub fn insert(&mut self, value: u64) -> bool {
let (block, bit) = block_bit(&value, &BITS);
let block = self.blocks.entry(block).or_insert(0);
let n = 1 << bit;
if (*block & n) == 0 {
*block |= n;
self.nbits += 1;
true
} else {
false
}
}
#[inline]
pub fn remove(&mut self, value: &u64) -> bool {
let (block, bit) = block_bit(value, &BITS);
let block = self.blocks.entry(block).or_insert(0);
let n = 1 << bit;
if (*block & n) != 0 {
*block &= !n;
self.nbits -= 1;
true
} else {
false
}
}
#[inline]
pub fn union_with(&mut self, other: &Self) {
for (key, value) in &other.blocks {
let block = self.blocks.entry(*key).or_insert(0);
let n = block.count_ones();
*block |= value;
self.nbits += (block.count_ones() - n) as usize;
}
}
}
impl fmt::Debug for BitSet {
    /// Formats the set as `{a, b, ...}`, listing members in iteration order.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut out = f.debug_set();
        for member in self {
            out.entry(&member);
        }
        out.finish()
    }
}
impl FromIterator<u64> for BitSet {
    /// Builds a set from any iterator of values; repeated values are stored once.
    #[inline]
    fn from_iter<I: IntoIterator<Item = u64>>(iter: I) -> BitSet {
        iter.into_iter().fold(BitSet::new(), |mut acc, value| {
            acc.insert(value);
            acc
        })
    }
}
impl Extend<u64> for BitSet {
    /// Inserts every value yielded by `iter`.
    #[inline]
    fn extend<I: IntoIterator<Item = u64>>(&mut self, iter: I) {
        iter.into_iter().for_each(|value| {
            self.insert(value);
        });
    }
}
impl<'a> Extend<&'a u64> for BitSet {
    /// Inserts a copy of every value referenced by `iter`.
    #[inline]
    fn extend<I: IntoIterator<Item = &'a u64>>(&mut self, iter: I) {
        for &value in iter {
            self.insert(value);
        }
    }
}
impl Default for BitSet {
#[inline]
fn default() -> BitSet {
BitSet::new()
}
}
impl ops::BitOr for BitSet {
    type Output = Self;
    /// Returns the union of two owned sets.
    ///
    /// The operand with fewer blocks is merged into the one with more, so the
    /// work is proportional to the smaller set. The merge itself is delegated
    /// to `union_with` instead of duplicating its bookkeeping.
    #[inline]
    fn bitor(mut self, mut rhs: Self) -> Self {
        if rhs.blocks.len() > self.blocks.len() {
            std::mem::swap(&mut self, &mut rhs);
        }
        self.union_with(&rhs);
        self
    }
}
impl<'a> ops::BitOr<&'a Self> for BitSet {
    type Output = Self;
    /// Returns the union of an owned set and a borrowed one, reusing the
    /// owned set's storage. Delegates to `union_with` so the member count
    /// bookkeeping lives in one place.
    #[inline]
    fn bitor(mut self, rhs: &'a Self) -> Self {
        self.union_with(rhs);
        self
    }
}
/// Borrowing iterator over the members of a `BitSet`, created by `BitSet::iter`.
pub struct Iter<'a> {
    /// Key of the block currently being scanned.
    block: u64,
    /// Bitmap word of the current block.
    bits: u64,
    /// Next bit position to examine in `bits`; `BITS` marks the block as exhausted.
    bit: u64,
    /// Underlying iterator over `(block key, bitmap word)` pairs.
    iter: std::collections::hash_map::Iter<'a, u64, u64>,
}
/// Owning iterator over the members of a `BitSet`, created by consuming the
/// set with `into_iter`.
pub struct IntoIter {
    /// Key of the block currently being scanned.
    block: u64,
    /// Bitmap word of the current block.
    bits: u64,
    /// Next bit position to examine in `bits`; `BITS` marks the block as exhausted.
    bit: u64,
    /// Underlying owning iterator over `(block key, bitmap word)` pairs.
    iter: std::collections::hash_map::IntoIter<u64, u64>,
}
impl<'a> IntoIterator for &'a BitSet {
type Item = u64;
type IntoIter = Iter<'a>;
fn into_iter(self) -> Iter<'a> {
self.iter()
}
}
impl IntoIterator for BitSet {
    type Item = u64;
    type IntoIter = IntoIter;
    /// Consumes the set, yielding its members by value.
    fn into_iter(self) -> Self::IntoIter {
        let blocks = self.blocks.into_iter();
        // Start "exhausted" so the first `next` call pulls the first block.
        IntoIter {
            bit: BITS,
            bits: 0,
            block: 0,
            iter: blocks,
        }
    }
}
impl<'a> Iterator for Iter<'a> {
    type Item = u64;
    /// Yields the next member.
    ///
    /// Within a block, members are produced in ascending order. The next set
    /// bit is located with `trailing_zeros` on the word masked to positions
    /// `>= self.bit`, instead of testing each of the 64 positions one by one.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            if self.bit < BITS {
                // Keep only the bits at or above the next unvisited position.
                let pending = self.bits & (!0u64 << self.bit);
                if pending != 0 {
                    let i = u64::from(pending.trailing_zeros());
                    self.bit = i + 1;
                    return Some(self.block * BITS + i);
                }
            }
            // Current block exhausted (or not started): move to the next one.
            let (block, bits) = self.iter.next()?;
            self.block = *block;
            self.bits = *bits;
            self.bit = 0;
        }
    }
}
impl Iterator for IntoIter {
    type Item = u64;
    /// Yields the next member.
    ///
    /// Within a block, members are produced in ascending order. The next set
    /// bit is located with `trailing_zeros` on the word masked to positions
    /// `>= self.bit`, instead of testing each of the 64 positions one by one.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            if self.bit < BITS {
                // Keep only the bits at or above the next unvisited position.
                let pending = self.bits & (!0u64 << self.bit);
                if pending != 0 {
                    let i = u64::from(pending.trailing_zeros());
                    self.bit = i + 1;
                    return Some(self.block * BITS + i);
                }
            }
            // Current block exhausted (or not started): move to the next one.
            let (block, bits) = self.iter.next()?;
            self.block = block;
            self.bits = bits;
            self.bit = 0;
        }
    }
}
impl ParallelExtend<u64> for BitSet {
    /// Extends the set with every value produced by a parallel iterator.
    ///
    /// Values are first gathered by rayon's `fold` into a list of `Vec`s.
    /// They are then inserted sequentially, because insertion needs
    /// `&mut self`.
    fn par_extend<I>(&mut self, par_iter: I)
    where
        I: IntoParallelIterator<Item = u64>,
    {
        let chunks = par_iter
            .into_par_iter()
            .fold(Vec::new, |mut vec, elem| {
                vec.push(elem);
                vec
            })
            .collect::<LinkedList<Vec<u64>>>();
        let total: usize = chunks.iter().map(Vec::len).sum();
        // Reserve on and insert into `self` -- not a temporary set -- so the
        // values actually end up in this set.
        self.reserve(total);
        for chunk in chunks {
            self.extend(chunk);
        }
    }
}
impl FromParallelIterator<u64> for BitSet {
    /// Collects a parallel iterator of values into a new set via `par_extend`.
    fn from_par_iter<I>(par_iter: I) -> Self
    where
        I: IntoParallelIterator<Item = u64>,
    {
        let mut collected = Self::default();
        collected.par_extend(par_iter);
        collected
    }
}
impl Serialize for BitSet {
    /// Serializes the set as a sequence of its members in iteration order,
    /// announcing the member count up front.
    #[inline]
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut seq = serializer.serialize_seq(Some(self.len()))?;
        self.iter()
            .try_for_each(|member| seq.serialize_element(&member))?;
        seq.end()
    }
}
impl<'de> Deserialize<'de> for BitSet {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct SeqVisitor;
impl<'de> Visitor<'de> for SeqVisitor {
type Value = BitSet;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a sequence")
}
#[inline]
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
let mut values = BitSet::new();
while let Some(value) = seq.next_element()? {
BitSet::insert(&mut values, value);
}
Ok(values)
}
}
let visitor = SeqVisitor;
deserializer.deserialize_seq(visitor)
}
fn deserialize_in_place<D>(deserializer: D, place: &mut Self) -> Result<(), D::Error>
where
D: Deserializer<'de>,
{
struct SeqInPlaceVisitor<'a>(&'a mut BitSet);
impl<'a, 'de> Visitor<'de> for SeqInPlaceVisitor<'a> {
type Value = ();
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a sequence")
}
#[inline]
fn visit_seq<A>(mut self, mut seq: A) -> Result<(), A::Error>
where
A: SeqAccess<'de>,
{
BitSet::clear(&mut self.0);
while let Some(value) = try!(seq.next_element()) {
BitSet::insert(&mut self.0, value);
}
Ok(())
}
}
deserializer.deserialize_seq(SeqInPlaceVisitor(place))
}
}
#[cfg(test)]
mod tests {
    extern crate serde_test;
    use self::serde_test::{assert_tokens, Token};
    use super::*;

    /// Builds a set from `values`, asserting that each one is newly inserted.
    fn set_of(values: &[u64]) -> BitSet {
        let mut set = BitSet::new();
        for &v in values {
            assert!(set.insert(v));
        }
        set
    }

    /// Every way of obtaining an empty set reports zero capacity, including
    /// a set emptied by `remove` and then shrunk.
    #[test]
    fn test_zero_capacities() {
        assert_eq!(BitSet::new().capacity(), 0);
        assert_eq!(BitSet::default().capacity(), 0);
        assert_eq!(BitSet::with_capacity(0).capacity(), 0);

        let mut drained = BitSet::new();
        for &v in &[1u64, 2] {
            drained.insert(v);
        }
        for v in &[1u64, 2] {
            drained.remove(v);
        }
        drained.shrink_to_fit();
        assert_eq!(drained.capacity(), 0);

        let mut reserved = BitSet::new();
        reserved.reserve(0);
        assert_eq!(reserved.capacity(), 0);
    }

    #[test]
    fn test_subset_and_superset() {
        let a = set_of(&[0, 5, 11, 7]);
        let mut b = set_of(&[0, 7, 19, 250, 11, 200]);
        // `a` has 5, which `b` lacks; `b` has 19/200/250, which `a` lacks.
        assert!(!a.is_subset(&b));
        assert!(!a.is_superset(&b));
        assert!(!b.is_subset(&a));
        assert!(!b.is_superset(&a));

        // Adding 5 makes `a` a proper subset of `b`.
        assert!(b.insert(5));
        assert!(a.is_subset(&b));
        assert!(!a.is_superset(&b));
        assert!(!b.is_subset(&a));
        assert!(b.is_superset(&a));

        // Removing the only foreign member turns `b` into a subset of `a`.
        let a = set_of(&[0, 5, 11, 7]);
        let mut b = set_of(&[0, 7, 250]);
        assert!(!b.is_subset(&a));
        b.remove(&250);
        assert!(b.is_subset(&a));
    }

    #[test]
    fn test_iterate() {
        let mut a = BitSet::new();
        (0..32).for_each(|i| assert!(a.insert(i)));
        let observed = a.iter().fold(0u32, |acc, k| acc | (1u32 << k));
        assert_eq!(observed, 0xFFFF_FFFF);
    }

    #[test]
    fn test_from_iter() {
        let xs = [1u64, 2, 3, 4, 5, 6, 7, 8, 9];
        let set: BitSet = xs.iter().cloned().collect();
        assert!(xs.iter().all(|x| set.contains(x)));
    }

    #[test]
    fn test_move_iter() {
        let mut hs = BitSet::new();
        hs.insert(1);
        hs.insert(2);
        let v: Vec<u64> = hs.into_iter().collect();
        assert_eq!(v, [1, 2]);
    }

    #[test]
    fn test_eq() {
        let s1: BitSet = [1u64, 2, 3].iter().cloned().collect();
        let mut s2: BitSet = [1u64, 2].iter().cloned().collect();
        assert!(s1 != s2);
        s2.insert(3);
        assert_eq!(s1, s2);
    }

    #[test]
    fn test_show() {
        let mut set = BitSet::new();
        set.insert(1);
        set.insert(2);
        assert_eq!(format!("{:?}", set), "{1, 2}");
        assert_eq!(format!("{:?}", BitSet::new()), "{}");
    }

    #[test]
    fn test_extend_ref() {
        let mut a = BitSet::new();
        a.insert(1);
        a.extend(&[2, 3, 4]);
        assert_eq!(a.len(), 4);
        for x in 1..5 {
            assert!(a.contains(&x));
        }

        let mut b = BitSet::new();
        b.insert(5);
        b.insert(6);
        a.extend(&b);
        assert_eq!(a.len(), 6);
        for x in 1..7 {
            assert!(a.contains(&x));
        }
    }

    #[test]
    fn test_insert() {
        let mut set = BitSet::with_capacity(10);
        assert!(!set.contains(&0));
        assert!(!set.contains(&10));
        set.insert(0);
        set.insert(10);
        assert!(set.contains(&0));
        assert!(set.contains(&10));
        assert!(!set.contains(&100));
        set.insert(100);
        assert!(set.contains(&100));
    }

    #[test]
    fn test_bitor() {
        let set_a: BitSet = [1u64, 2, 3].iter().cloned().collect();
        let set_b: BitSet = [3u64, 4, 5].iter().cloned().collect();
        let set = set_a | set_b;
        assert_eq!(set.len(), 5);
        assert!((1..6).all(|x| set.contains(&x)));
    }

    #[test]
    fn test_union_with() {
        let mut set: BitSet = [1u64, 2, 3].iter().cloned().collect();
        let other: BitSet = [3u64, 4, 5].iter().cloned().collect();
        set.union_with(&other);
        assert_eq!(set.len(), 5);
        assert!((1..6).all(|x| set.contains(&x)));
    }

    #[test]
    fn test_serde_serialize() {
        let mut set = BitSet::new();
        set.insert(20);
        set.insert(10);
        set.insert(30);
        assert_tokens(
            &set,
            &[
                Token::Seq { len: Some(3) },
                Token::U64(10),
                Token::U64(20),
                Token::U64(30),
                Token::SeqEnd,
            ],
        );
    }
}