Skip to content

Commit

Permalink
Get compress performance to match paper algorithm 4 (#3)
Browse files Browse the repository at this point in the history
This gets us close to 2-3 cycles per byte or so that they reference in
the paper for predicated scalar compression.


![image](https://github.com/user-attachments/assets/5e0c6c24-cb71-435d-ae5c-51f291018f94)

^ the benchmark is compression on string with length 50, so compression
is roughly 1-2ns per byte (roughly 3-5 cycles on my M2)
  • Loading branch information
a10y committed Aug 15, 2024
1 parent 851eb96 commit 31351ca
Show file tree
Hide file tree
Showing 15 changed files with 735 additions and 177 deletions.
7 changes: 0 additions & 7 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,9 +1,2 @@
/target
.idea/


# Added by cargo
#
# already existing elements were commented out

#/target
13 changes: 11 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 14 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
[package]
name = "fsst-rs"
version = "0.0.1"
description = "Pure-Rust implementation of Fast Static Symbol Tables algorithm for string compression"
authors = ["SpiralDB Developers <[email protected]>"]
license = "Apache-2.0"
repository = "https://github.com/spiraldb/fsst"
edition = "2021"

[lints.rust]
Expand All @@ -22,7 +26,16 @@ use_debug = { level = "deny" }
criterion = "0.5"
lz4 = "1"

[[example]]
name = "round_trip"
bench = false
test = false

[[bench]]
name = "compress"
harness = false
bench = true

[[test]]
name = "correctness"
test = true
bench = false
36 changes: 7 additions & 29 deletions benches/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
//!
//! Also contains LZ4 baseline.
#![allow(missing_docs)]
use core::str;
use std::io::{Cursor, Read, Write};

use criterion::{black_box, criterion_group, criterion_main, Criterion};
use lz4::liblz4::BlockChecksum;
use lz4::{BlockSize, ContentChecksum};

use fsst_rs::{train, Code};
use fsst_rs::{train, ESCAPE_CODE};

const CORPUS: &str = include_str!("dracula.txt");
const TEST: &str = "I found my smattering of German very useful here";
Expand All @@ -26,17 +27,17 @@ fn bench_fsst(c: &mut Criterion) {
let plaintext = TEST.as_bytes();

let compressed = table.compress(plaintext);
let escape_count = compressed
.iter()
.filter(|b| **b == Code::ESCAPE_CODE)
.count();
let escape_count = compressed.iter().filter(|b| **b == ESCAPE_CODE).count();
let ratio = (plaintext.len() as f64) / (compressed.len() as f64);
println!(
"Escapes = {escape_count}/{}, compression_ratio = {ratio}",
compressed.len()
);

assert_eq!(table.decompress(&compressed), TEST.as_bytes());
let decompressed = table.decompress(&compressed);
let decompressed = str::from_utf8(&decompressed).unwrap();
println!("DECODED: {}", decompressed);
assert_eq!(decompressed, TEST);

group.bench_function("compress-single", |b| {
b.iter(|| black_box(table.compress(black_box(plaintext))));
Expand All @@ -50,29 +51,6 @@ fn bench_fsst(c: &mut Criterion) {
fn bench_lz4(c: &mut Criterion) {
let mut group = c.benchmark_group("lz4");

// {
// let compressed = Vec::with_capacity(10_000);
// let mut encoder = lz4::EncoderBuilder::new()
// .block_size(BlockSize::Max64KB)
// .build(compressed)
// .unwrap();
//
// encoder.write_all(TEST.as_bytes()).unwrap();
// let (compressed, result) = encoder.finish();
// result.unwrap();
//
// let ratio = (TEST.as_bytes().len() as f64) / (compressed.len() as f64);
// println!("LZ4 compress_ratio = {ratio}");
//
// // ensure decodes cleanly
// let cursor = Cursor::new(compressed);
// let mut decoder = lz4::Decoder::new(cursor).unwrap();
// let mut output = String::new();
//
// decoder.read_to_string(&mut output).unwrap();
// assert_eq!(output.as_str(), TEST);
// }

group.bench_function("compress-single", |b| {
let mut compressed = Vec::with_capacity(100_000_000);
let mut encoder = lz4::EncoderBuilder::new()
Expand Down
70 changes: 70 additions & 0 deletions examples/file_compressor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#![allow(missing_docs, clippy::use_debug)]

//! This is a command line program that expects two input files as arguments.
//!
//! The first is the file to train a symbol table on.
//!
//! The second is the file to compress. The compressor will run and compress
//! in chunks of 16MB, logging the compression ratio for each chunk.
//!
//! Example:
//!
//! ```
//! cargo run --release --example file_compressor -- file1.csv file2.csv
//! ```
use std::{
fs::File,
io::Read,
os::unix::fs::{FileExt, MetadataExt},
path::Path,
};

fn main() {
let args: Vec<_> = std::env::args().skip(1).collect();
assert!(args.len() >= 2, "args TRAINING and FILE must be provided");

let train_path = Path::new(&args[0]);
let input_path = Path::new(&args[1]);

let mut train_bytes = Vec::new();
{
let mut f = File::open(train_path).unwrap();
f.read_to_end(&mut train_bytes).unwrap();
}

println!("building the compressor from {train_path:?}...");
let compressor = fsst_rs::train(&train_bytes);

println!("compressing blocks of {input_path:?} with compressor...");

let f = File::open(input_path).unwrap();
let size_bytes = f.metadata().unwrap().size() as usize;

const CHUNK_SIZE: usize = 16 * 1024 * 1024;

let mut chunk_idx = 1;
let mut pos = 0;
let mut chunk = vec![0u8; CHUNK_SIZE];
while pos + CHUNK_SIZE < size_bytes {
f.read_exact_at(&mut chunk, pos as u64).unwrap();
// Compress the chunk, don't write it anywhere.
let compact = compressor.compress(&chunk);
let compression_ratio = (CHUNK_SIZE as f64) / (compact.len() as f64);
println!("compressed chunk {chunk_idx} with ratio {compression_ratio}");

pos += CHUNK_SIZE;
chunk_idx += 1;
}

// Read last chunk with a new custom-sized buffer.
if pos < size_bytes {
let amount = size_bytes - pos;
chunk = vec![0u8; size_bytes - pos];
f.read_exact_at(&mut chunk, pos as u64).unwrap();
// Compress the chunk, don't write it anywhere.
let compact = compressor.compress(&chunk[0..amount]);
let compression_ratio = (amount as f64) / (compact.len() as f64);
println!("compressed chunk {chunk_idx} with ratio {compression_ratio}");
}
println!("done");
}
19 changes: 19 additions & 0 deletions examples/round_trip.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//! Simple example where we show round-tripping a string through the static symbol table.

use core::str;

fn main() {
// Train on a sample.
let sample = "the quick brown fox jumped over the lazy dog";
let trained = fsst_rs::train(sample.as_bytes());
let compressed = trained.compress(sample.as_bytes());
println!("compressed: {} => {}", sample.len(), compressed.len());
// decompress now
let decode = trained.decompress(&compressed);
let output = str::from_utf8(&decode).unwrap();
println!(
"decoded to the original: len={} text='{}'",
decode.len(),
output
);
}
3 changes: 1 addition & 2 deletions rust-toolchain.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
[toolchain]
channel = "nightly-2024-06-19"
channel = "nightly-2024-08-14"
components = ["rust-src", "rustfmt", "clippy"]
profile = "minimal"

67 changes: 36 additions & 31 deletions src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
use std::cmp::Ordering;
use std::collections::BinaryHeap;

use crate::{Code, Symbol, SymbolTable};
use crate::find_longest::FindLongestSymbol;
use crate::{Symbol, SymbolTable, MAX_CODE};

#[derive(Debug, Clone)]
struct Counter {
Expand All @@ -21,29 +22,29 @@ struct Counter {
impl Counter {
fn new() -> Self {
Self {
counts1: vec![0; Code::CODE_MAX as usize],
counts2: vec![vec![0; Code::CODE_MAX as usize]; Code::CODE_MAX as usize],
counts1: vec![0; MAX_CODE as usize],
counts2: vec![vec![0; MAX_CODE as usize]; MAX_CODE as usize],
}
}

#[inline]
fn record_count1(&mut self, code1: Code) {
self.counts1[code1.0 as usize] += 1;
fn record_count1(&mut self, code1: u16) {
self.counts1[code1 as usize] += 1;
}

#[inline]
fn record_count2(&mut self, code1: Code, code2: Code) {
self.counts2[code1.0 as usize][code2.0 as usize] += 1;
fn record_count2(&mut self, code1: u16, code2: u16) {
self.counts2[code1 as usize][code2 as usize] += 1;
}

#[inline]
fn count1(&self, code: Code) -> usize {
self.counts1[code.0 as usize]
fn count1(&self, code: u16) -> usize {
self.counts1[code as usize]
}

#[inline]
fn count2(&self, code1: Code, code2: Code) -> usize {
self.counts2[code1.0 as usize][code2.0 as usize]
fn count2(&self, code1: u16, code2: u16) -> usize {
self.counts2[code1 as usize][code2 as usize]
}
}

Expand All @@ -65,6 +66,9 @@ pub fn train(corpus: impl AsRef<[u8]>) -> SymbolTable {
let mut table = SymbolTable::default();
// TODO(aduffy): handle truncating/sampling if corpus > requires sample size.
let sample = corpus.as_ref();
if sample.is_empty() {
return table;
}
for _generation in 0..MAX_GENERATIONS {
let counter = table.compress_count(sample);
table = table.optimize(counter);
Expand All @@ -81,13 +85,13 @@ impl SymbolTable {
let len = sample.len();
let mut prev_code = self.find_longest_symbol(sample);
counter.record_count1(prev_code);
let mut pos = self.symbols[prev_code.0 as usize].len();
let mut pos = self.symbols[prev_code as usize].len();

while pos < len {
let code = self.find_longest_symbol(&sample[pos..len]);
counter.record_count1(code);
counter.record_count2(prev_code, code);
pos += self.symbols[code.0 as usize].len();
pos += self.symbols[code as usize].len();
prev_code = code;
}

Expand All @@ -100,17 +104,15 @@ impl SymbolTable {
let mut res = SymbolTable::default();
let mut pqueue = BinaryHeap::new();
for code1 in 0..511 {
let code1 = Code::from_u16(code1);
let symbol1 = self.symbols[code1.0 as usize];
let symbol1 = self.symbols[code1 as usize];
let gain = counters.count1(code1) * symbol1.len();
pqueue.push(Candidate {
symbol: symbol1,
gain,
});

for code2 in 0..511 {
let code2 = Code::from_u16(code2);
let symbol2 = &self.symbols[code2.0 as usize];
let symbol2 = &self.symbols[code2 as usize];
// If either symbol is zero-length, or if merging would yield a symbol of
// length greater than 8, skip.
if symbol1.len() + symbol2.len() >= 8 || symbol1.is_empty() || symbol2.is_empty() {
Expand All @@ -133,10 +135,13 @@ impl SymbolTable {
}

// Pop the 255 best symbols.
pqueue
.iter()
.take(255)
.for_each(|candidate| res.insert(candidate.symbol));
let mut n_symbols = 0;
while !pqueue.is_empty() && n_symbols < 255 {
let candidate = pqueue.pop().unwrap();
if res.insert(candidate.symbol) {
n_symbols += 1;
}
}

res
}
Expand Down Expand Up @@ -181,7 +186,7 @@ impl Ord for Candidate {

#[cfg(test)]
mod test {
use crate::{train, Code};
use crate::{train, ESCAPE_CODE};

#[test]
fn test_builder() {
Expand All @@ -193,26 +198,26 @@ mod test {
let compressed = table.compress(text.as_bytes());

// Ensure that the compressed string has no escape bytes
assert!(compressed.iter().all(|b| *b != Code::ESCAPE_CODE));
assert!(compressed.iter().all(|b| *b != ESCAPE_CODE));

// Ensure that we can compress a string with no values seen at training time.
// Ensure that we can compress a string with no values seen at training time, with escape bytes
let compressed = table.compress("xyz123".as_bytes());
assert_eq!(
compressed,
vec![
Code::ESCAPE_CODE,
ESCAPE_CODE,
b'x',
Code::ESCAPE_CODE,
ESCAPE_CODE,
b'y',
Code::ESCAPE_CODE,
ESCAPE_CODE,
b'z',
Code::ESCAPE_CODE,
ESCAPE_CODE,
b'1',
Code::ESCAPE_CODE,
ESCAPE_CODE,
b'2',
Code::ESCAPE_CODE,
ESCAPE_CODE,
b'3',
]
)
);
}
}
5 changes: 5 additions & 0 deletions src/find_longest/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
mod naive;

pub trait FindLongestSymbol {
fn find_longest_symbol(&self, text: &[u8]) -> u16;
}
Loading

0 comments on commit 31351ca

Please sign in to comment.