diff --git a/src/BizHawk.Client.Common/rewind/ZwinderBuffer.cs b/src/BizHawk.Client.Common/rewind/ZwinderBuffer.cs index 2dd914000f..7496d65d96 100644 --- a/src/BizHawk.Client.Common/rewind/ZwinderBuffer.cs +++ b/src/BizHawk.Client.Common/rewind/ZwinderBuffer.cs @@ -439,10 +439,9 @@ namespace BizHawk.Client.Common long requestedSize = _position + 1; while (requestedSize > _notifySize) _notifySize = _notifySizeReached(); + _backingStore.Position = (_position + _offset) & _mask; _backingStore.WriteByte(value); _position++; - if (_position + _offset == BufferLength) - _backingStore.Position = 0; } } @@ -455,6 +454,7 @@ namespace BizHawk.Client.Common _offset = offset; _size = size; _mask = mask; + _backingStore.Position = _offset; } private readonly Stream _backingStore; @@ -489,7 +489,6 @@ namespace BizHawk.Client.Common { var start = (_position + _offset) & _mask; var end = (start + n) & _mask; - _backingStore.Position = start; if (end < start) { long m = BufferLength - start; diff --git a/src/BizHawk.Common/CRC32.cs b/src/BizHawk.Common/CRC32.cs index 564f63fe7b..1be0360aa0 100644 --- a/src/BizHawk.Common/CRC32.cs +++ b/src/BizHawk.Common/CRC32.cs @@ -1,4 +1,6 @@ -namespace BizHawk.Common +using System; + +namespace BizHawk.Common { // we could get a little list of crcs from here and make it clear which crc this class was for, and expose others // http://www.ross.net/crc/download/crc_v3.txt @@ -30,7 +32,7 @@ } } - public static int Calculate(byte[] data) + public static int Calculate(ReadOnlySpan data) { uint result = 0xFFFFFFFF; foreach (var b in data) diff --git a/src/BizHawk.Tests/Client.Common/Movie/ZwinderStateManagerTests.cs b/src/BizHawk.Tests/Client.Common/Movie/ZwinderStateManagerTests.cs index 92d75ee07e..032b18b670 100644 --- a/src/BizHawk.Tests/Client.Common/Movie/ZwinderStateManagerTests.cs +++ b/src/BizHawk.Tests/Client.Common/Movie/ZwinderStateManagerTests.cs @@ -1,6 +1,8 @@ +using System; using System.IO; using System.Linq; using BizHawk.Client.Common; +using BizHawk.Common; using BizHawk.Emulation.Common; using Microsoft.VisualStudio.TestTools.UnitTesting; @@ -443,6 +445,132 @@ namespace BizHawk.Tests.Client.Common.Movie zw.SaveStateBinary(new BinaryWriter(new MemoryStream())); } + [TestMethod] + public void TestReadByteCorruption() + { + using var zw = new ZwinderBuffer(new RewindConfig + { + BufferSize = 1, + TargetFrameLength = 1 + }); + zw.Capture(0, s => + { + s.Write(new byte[] { 1, 2, 3, 4 }); + }); + zw.Capture(1, s => + { + s.Write(new byte[] { 5, 6, 7, 8 }); + }); + var state = zw.GetState(0); + Assert.AreEqual(0, state.Frame); + Assert.AreEqual(4, state.Size); + Assert.AreEqual(1, state.GetReadStream().ReadByte()); + } + + [TestMethod] + public void TestReadBytesCorruption() + { + using var zw = new ZwinderBuffer(new RewindConfig + { + BufferSize = 1, + TargetFrameLength = 1 + }); + zw.Capture(0, s => + { + s.Write(new byte[] { 1, 2, 3, 4 }); + }); + zw.Capture(1, s => + { + s.Write(new byte[] { 5, 6, 7, 8 }); + }); + var state = zw.GetState(0); + Assert.AreEqual(0, state.Frame); + Assert.AreEqual(4, state.Size); + var bb = new byte[2]; + state.GetReadStream().Read(bb, 0, 2); + Assert.AreEqual(1, bb[0]); + Assert.AreEqual(2, bb[1]); + } + + [TestMethod] + public void TestWriteByteCorruption() + { + using var zw = new ZwinderBuffer(new RewindConfig + { + BufferSize = 1, + TargetFrameLength = 1 + }); + zw.Capture(0, s => + { + s.WriteByte(1); + s.WriteByte(2); + s.WriteByte(3); + s.WriteByte(4); + }); + zw.Capture(1, s => + { + s.WriteByte(5); + s.WriteByte(6); + s.WriteByte(7); + s.WriteByte(8); + }); + zw.GetState(0).GetReadStream(); // Rewinds the backing store + zw.Capture(2, s => + { + s.WriteByte(9); + s.WriteByte(10); + s.WriteByte(11); + s.WriteByte(12); + }); + + var state = zw.GetState(0); + Assert.AreEqual(0, state.Frame); + Assert.AreEqual(4, state.Size); + Assert.AreEqual(1, state.GetReadStream().ReadByte()); + } + + [TestMethod] + public void BufferStressTest() + { + var r = new Random(8675309); + using var zw = new ZwinderBuffer(new RewindConfig + { + BufferSize = 1, + TargetFrameLength = 1 + }); + var buff = new byte[40000]; + + for (int round = 0; round < 10; round++) + { + for (int i = 0; i < 500; i++) + { + zw.Capture(i, s => + { + var length = r.Next(40000); + var bw = new BinaryWriter(s); + Span bytes = buff[0..length]; + r.NextBytes(bytes); + bw.Write(length); + bw.Write(bytes); + bw.Write(CRC32.Calculate(bytes)); + }); + } + for (int i = 0; i < zw.Count; i++) + { + var info = zw.GetState(i); + var s = info.GetReadStream(); + var br = new BinaryReader(s); + var length = info.Size; + if (length != br.ReadInt32() + 8) + throw new Exception("Length field corrupted"); + Span bytes = buff[0..(length - 8)]; + br.Read(bytes); + if (br.ReadInt32() != CRC32.Calculate(bytes)) + throw new Exception("Data or CRC field corrupted"); + } + } + } + private class StateSource : IStatable { public int Frame { get; set; }