diff --git a/src/reader.rs b/src/reader.rs index d60b5b16..e9bf312e 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -319,7 +319,7 @@ impl<'a, R: Read + Seek> Reader<'a, R> { } let bits_read = amt * 8; - self.last_bits_read_amt += bits_read; + self.last_bits_read_amt = bits_read; self.bits_read += bits_read; #[cfg(feature = "logging")] @@ -403,7 +403,7 @@ impl<'a, R: Read + Seek> Reader<'a, R> { return Err(DekuError::Io(e.kind())); } - self.last_bits_read_amt += N * 8; + self.last_bits_read_amt = N * 8; self.bits_read += N * 8; #[cfg(feature = "logging")] @@ -518,4 +518,52 @@ mod tests { let _ = reader.read_bytes(1, &mut buf); assert_eq!([0xaa], buf); } + + #[test] + fn test_seek_last_read_bytes() { + // bytes + let input = hex!("aa"); + let mut cursor = Cursor::new(input); + let mut reader = Reader::new(&mut cursor); + let mut buf = [0; 1]; + let _ = reader.read_bytes(1, &mut buf); + assert_eq!([0xaa], buf); + reader.seek_last_read().unwrap(); + let _ = reader.read_bytes(1, &mut buf); + assert_eq!([0xaa], buf); + + // 2 bytes (and const) + let input = hex!("aabb"); + let mut cursor = Cursor::new(input); + let mut reader = Reader::new(&mut cursor); + let mut buf = [0; 2]; + let _ = reader.read_bytes_const::<2>(&mut buf); + assert_eq!([0xaa, 0xbb], buf); + reader.seek_last_read().unwrap(); + let _ = reader.read_bytes_const::<2>(&mut buf); + assert_eq!([0xaa, 0xbb], buf); + } + + #[cfg(feature = "bits")] + #[test] + fn test_seek_last_read_bits() { + let input = hex!("ab"); + let mut cursor = Cursor::new(input); + let mut reader = Reader::new(&mut cursor); + let bits = reader.read_bits(4).unwrap(); + assert_eq!(bits, Some(bitvec![u8, Msb0; 1, 0, 1, 0])); + reader.seek_last_read().unwrap(); + let bits = reader.read_bits(4).unwrap(); + assert_eq!(bits, Some(bitvec![u8, Msb0; 1, 0, 1, 0])); + + // more than byte + let input = hex!("abd0"); + let mut cursor = Cursor::new(input); + let mut reader = Reader::new(&mut cursor); + let bits = reader.read_bits(9).unwrap(); + assert_eq!(bits, Some(bitvec![u8, Msb0; 1, 0, 1, 0, 1, 0, 1, 1, 1])); + reader.seek_last_read().unwrap(); + let bits = reader.read_bits(9).unwrap(); + assert_eq!(bits, Some(bitvec![u8, Msb0; 1, 0, 1, 0, 1, 0, 1, 1, 1])); + } }