Fix corruption in ZWinderBuffer when using Stream.ReadByte / Stream.WriteByte overloads (#2630)

The ZWinderBuffer implementations of Stream.ReadByte and Stream.WriteByte could process data incorrectly in certain circumstances.  This had been broken since f4e98fd.

ReadByte: When the first read from a state stream was a ReadByte, the underlying buffer would be in the wrong place
WriteByte: If a state was evicted and then the eviction was immediately followed by a WriteByte, the underlying buffer would be in the wrong place.

This impacts pretty heavily the rewinder and tasstatemanager for any core whose save and/or loadstate methods happened to use those methods.
This commit is contained in:
nattthebear 2021-02-20 12:21:56 -05:00 committed by GitHub
parent 7923b4c8ef
commit 13b7b43db6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 134 additions and 5 deletions

View File

@ -439,10 +439,9 @@ namespace BizHawk.Client.Common
long requestedSize = _position + 1;
while (requestedSize > _notifySize)
_notifySize = _notifySizeReached();
_backingStore.Position = (_position + _offset) & _mask;
_backingStore.WriteByte(value);
_position++;
if (_position + _offset == BufferLength)
_backingStore.Position = 0;
}
}
@ -455,6 +454,7 @@ namespace BizHawk.Client.Common
_offset = offset;
_size = size;
_mask = mask;
_backingStore.Position = _offset;
}
private readonly Stream _backingStore;
@ -489,7 +489,6 @@ namespace BizHawk.Client.Common
{
var start = (_position + _offset) & _mask;
var end = (start + n) & _mask;
_backingStore.Position = start;
if (end < start)
{
long m = BufferLength - start;

View File

@ -1,4 +1,6 @@
namespace BizHawk.Common
using System;
namespace BizHawk.Common
{
// we could get a little list of crcs from here and make it clear which crc this class was for, and expose others
// http://www.ross.net/crc/download/crc_v3.txt
@ -30,7 +32,7 @@
}
}
public static int Calculate(byte[] data)
public static int Calculate(ReadOnlySpan<byte> data)
{
uint result = 0xFFFFFFFF;
foreach (var b in data)

View File

@ -1,6 +1,8 @@
using System;
using System.IO;
using System.Linq;
using BizHawk.Client.Common;
using BizHawk.Common;
using BizHawk.Emulation.Common;
using Microsoft.VisualStudio.TestTools.UnitTesting;
@ -443,6 +445,132 @@ namespace BizHawk.Tests.Client.Common.Movie
zw.SaveStateBinary(new BinaryWriter(new MemoryStream()));
}
[TestMethod]
public void TestReadByteCorruption()
{
using var zw = new ZwinderBuffer(new RewindConfig
{
BufferSize = 1,
TargetFrameLength = 1
});
zw.Capture(0, s =>
{
s.Write(new byte[] { 1, 2, 3, 4 });
});
zw.Capture(1, s =>
{
s.Write(new byte[] { 5, 6, 7, 8 });
});
var state = zw.GetState(0);
Assert.AreEqual(0, state.Frame);
Assert.AreEqual(4, state.Size);
Assert.AreEqual(1, state.GetReadStream().ReadByte());
}
[TestMethod]
public void TestReadBytesCorruption()
{
using var zw = new ZwinderBuffer(new RewindConfig
{
BufferSize = 1,
TargetFrameLength = 1
});
zw.Capture(0, s =>
{
s.Write(new byte[] { 1, 2, 3, 4 });
});
zw.Capture(1, s =>
{
s.Write(new byte[] { 5, 6, 7, 8 });
});
var state = zw.GetState(0);
Assert.AreEqual(0, state.Frame);
Assert.AreEqual(4, state.Size);
var bb = new byte[2];
state.GetReadStream().Read(bb, 0, 2);
Assert.AreEqual(1, bb[0]);
Assert.AreEqual(2, bb[1]);
}
[TestMethod]
public void TestWriteByteCorruption()
{
using var zw = new ZwinderBuffer(new RewindConfig
{
BufferSize = 1,
TargetFrameLength = 1
});
zw.Capture(0, s =>
{
s.WriteByte(1);
s.WriteByte(2);
s.WriteByte(3);
s.WriteByte(4);
});
zw.Capture(1, s =>
{
s.WriteByte(5);
s.WriteByte(6);
s.WriteByte(7);
s.WriteByte(8);
});
zw.GetState(0).GetReadStream(); // Rewinds the backing store
zw.Capture(2, s =>
{
s.WriteByte(9);
s.WriteByte(10);
s.WriteByte(11);
s.WriteByte(12);
});
var state = zw.GetState(0);
Assert.AreEqual(0, state.Frame);
Assert.AreEqual(4, state.Size);
Assert.AreEqual(1, state.GetReadStream().ReadByte());
}
[TestMethod]
public void BufferStressTest()
{
var r = new Random(8675309);
using var zw = new ZwinderBuffer(new RewindConfig
{
BufferSize = 1,
TargetFrameLength = 1
});
var buff = new byte[40000];
for (int round = 0; round < 10; round++)
{
for (int i = 0; i < 500; i++)
{
zw.Capture(i, s =>
{
var length = r.Next(40000);
var bw = new BinaryWriter(s);
Span<byte> bytes = buff[0..length];
r.NextBytes(bytes);
bw.Write(length);
bw.Write(bytes);
bw.Write(CRC32.Calculate(bytes));
});
}
for (int i = 0; i < zw.Count; i++)
{
var info = zw.GetState(i);
var s = info.GetReadStream();
var br = new BinaryReader(s);
var length = info.Size;
if (length != br.ReadInt32() + 8)
throw new Exception("Length field corrupted");
Span<byte> bytes = buff[0..(length - 8)];
br.Read(bytes);
if (br.ReadInt32() != CRC32.Calculate(bytes))
throw new Exception("Data or CRC field corrupted");
}
}
}
private class StateSource : IStatable
{
public int Frame { get; set; }