diff --git a/src/BizHawk.Client.Common/movie/tasproj/ZwinderStateManager.cs b/src/BizHawk.Client.Common/movie/tasproj/ZwinderStateManager.cs index 0a8009a49f..67bc5afd36 100644 --- a/src/BizHawk.Client.Common/movie/tasproj/ZwinderStateManager.cs +++ b/src/BizHawk.Client.Common/movie/tasproj/ZwinderStateManager.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; using System.IO; using System.Linq; +using BizHawk.Common; using BizHawk.Emulation.Common; namespace BizHawk.Client.Common @@ -11,7 +12,7 @@ namespace BizHawk.Client.Common private static readonly byte[] NonState = new byte[0]; private readonly Func _reserveCallback; - internal readonly SortedSet StateCache = new SortedSet(); + internal readonly SortedList StateCache = new SortedList(); private ZwinderBuffer _current; private ZwinderBuffer _recent; @@ -45,7 +46,7 @@ namespace BizHawk.Client.Common if (!_reserved.ContainsKey(0)) { _reserved.Add(0, frameZeroState); - StateCache.Add(0); + AddStateCache(0); } } @@ -149,7 +150,9 @@ namespace BizHawk.Client.Common { StateCache.Clear(); foreach (StateInfo state in AllStates()) - StateCache.Add(state.Frame); + { + AddStateCache(state.Frame); + } } public int Count => _current.Count + _recent.Count + _gapFiller.Count + _reserved.Count; @@ -227,7 +230,7 @@ namespace BizHawk.Client.Common var ms = new MemoryStream(); source.SaveStateBinary(new BinaryWriter(ms)); _reserved.Add(frame, ms.ToArray()); - StateCache.Add(frame); + AddStateCache(frame); } private void AddToReserved(ZwinderBuffer.StateInformation state) @@ -241,7 +244,15 @@ namespace BizHawk.Client.Common var ms = new MemoryStream(bb); state.GetReadStream().CopyTo(ms); _reserved.Add(state.Frame, bb); - StateCache.Add(state.Frame); + AddStateCache(state.Frame); + } + + private void AddStateCache(int frame) + { + if (!StateCache.Contains(frame)) + { + StateCache.Add(frame); + } } public void EvictReserved(int frame) @@ -285,7 +296,7 @@ namespace BizHawk.Client.Common s => { source.SaveStateBinary(new BinaryWriter(s)); - StateCache.Add(frame); + AddStateCache(frame); }, index => { @@ -303,7 +314,7 @@ namespace BizHawk.Client.Common s => { state.GetReadStream().CopyTo(s); - StateCache.Add(state.Frame); + AddStateCache(state.Frame); }, index2 => { @@ -369,7 +380,7 @@ namespace BizHawk.Client.Common _gapFiller.Capture( frame, s => { - StateCache.Add(frame); + AddStateCache(frame); source.SaveStateBinary(new BinaryWriter(s)); }, index => StateCache.Remove(index)); @@ -381,7 +392,7 @@ namespace BizHawk.Client.Common _recent.InvalidateEnd(0); _gapFiller.InvalidateEnd(0); StateCache.Clear(); - StateCache.Add(0); + AddStateCache(0); _reserved = _reserved .Where(kvp => kvp.Key == 0) .ToDictionary(kvp => kvp.Key, kvp => kvp.Value); @@ -409,7 +420,7 @@ namespace BizHawk.Client.Common if (state.Frame > frame) { var last = GapStates().First(); - StateCache.RemoveWhere(s => s >= state.Frame && s <= last.Frame); // TODO: be consistent, other invalidate methods do not touch cache and it is addressed in the public InvalidateAfter + StateCache.RemoveAll(s => s >= state.Frame && s <= last.Frame); // TODO: be consistent, other invalidate methods do not touch cache and it is addressed in the public InvalidateAfter _gapFiller.InvalidateEnd(i); return true; @@ -458,7 +469,7 @@ namespace BizHawk.Client.Common var b1 = InvalidateNormal(frame); var b2 = InvalidateGaps(frame); var b3 = InvalidateReserved(frame); - StateCache.RemoveWhere(s => s > frame); + StateCache.RemoveAfter(frame); return b1 || b2 || b3; } diff --git a/src/BizHawk.Common/CustomCollections.cs b/src/BizHawk.Common/CustomCollections.cs index 6367488f52..265f197778 100644 --- a/src/BizHawk.Common/CustomCollections.cs +++ b/src/BizHawk.Common/CustomCollections.cs @@ -36,6 +36,89 @@ namespace BizHawk.Common public IEnumerator>> GetKVPEnumerator() => dictionary.GetEnumerator(); } + public class SortedList : ICollection + where T : IComparable + { + protected readonly List _list; + + public virtual int Count => _list.Count; + + public virtual bool IsReadOnly { get; } = false; + + public SortedList() => _list = new List(); + + public SortedList(IEnumerable collection) + { + _list = new List(collection); + _list.Sort(); + } + + public virtual T this[int index] => _list[index]; + + public virtual void Add(T item) + { + var i = _list.BinarySearch(item); + _list.Insert(i < 0 ? ~i : i, item); + } + + public virtual int BinarySearch(T item) => _list.BinarySearch(item); + + public virtual void Clear() => _list.Clear(); + + public virtual bool Contains(T item) => !(_list.BinarySearch(item) < 0); // can't use `!= -1`, BinarySearch can return multiple negative values + + public virtual void CopyTo(T[] array, int arrayIndex) => _list.CopyTo(array, arrayIndex); + + public virtual IEnumerator GetEnumerator() => _list.GetEnumerator(); + + public virtual int IndexOf(T item) + { + var i = _list.BinarySearch(item); + return i < 0 ? -1 : i; + } + + public virtual bool Remove(T item) + { +#if true + var i = _list.BinarySearch(item); + if (i < 0) return false; + _list.RemoveAt(i); + return true; +#else //TODO is this any slower? + return _list.Remove(item); +#endif + } + + + public virtual int RemoveAll(Predicate match) => _list.RemoveAll(match); + + public virtual void RemoveAt(int index) => _list.RemoveAt(index); + + /// Remove all items after the specific item (but not the given item). + public virtual void RemoveAfter(T item) + { + var startIndex = _list.BinarySearch(item); + if (startIndex < 0) + { + // If BinarySearch doesn't find the item, + // it returns the bitwise complement of the index of the next element + // that is larger than item + startIndex = ~startIndex; + } + else + { + // All items *after* the item + startIndex = startIndex + 1; + } + if (startIndex < _list.Count) + { + _list.RemoveRange(startIndex, _list.Count - startIndex); + } + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } + /// A dictionary whose index getter creates an entry if the requested key isn't part of the collection, making it always safe to use the returned value. The new entry's value will be the result of the default constructor of . [Serializable] public class WorkingDictionary : Dictionary diff --git a/src/BizHawk.Tests/Common/CustomCollections/CustomCollectionTests.cs b/src/BizHawk.Tests/Common/CustomCollections/CustomCollectionTests.cs new file mode 100644 index 0000000000..37d79783f9 --- /dev/null +++ b/src/BizHawk.Tests/Common/CustomCollections/CustomCollectionTests.cs @@ -0,0 +1,43 @@ +using System.Linq; + +using BizHawk.Common; + +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace BizHawk.Tests.Common.CustomCollections +{ + [TestClass] + public class CustomCollectionTests + { + [TestMethod] + public void TestSortedListAddRemove() + { + var list = new SortedList(new[] { 1, 3, 4, 7, 8, 9, 11 }); // this causes one sort, collection initializer syntax wouldn't + list.Add(5); // `Insert` when `BinarySearch` returns negative + list.Add(8); // `Insert` when `BinarySearch` returns non-negative + list.Remove(3); // `Remove` when `BinarySearch` returns non-negative + Assert.IsTrue(list.ToArray().SequenceEqual(new[] { 1, 4, 5, 7, 8, 8, 9, 11 })); + Assert.IsFalse(list.Remove(10)); // `Remove` when `BinarySearch` returns negative + } + + [TestMethod] + public void TestSortedListContains() + { + var list = new SortedList(new[] { 1, 3, 4, 7, 8, 9, 11 }); + Assert.IsFalse(list.Contains(6)); // `Contains` when `BinarySearch` returns negative + Assert.IsTrue(list.Contains(11)); // `Contains` when `BinarySearch` returns non-negative + } + + + [TestMethod] + [DataRow(new[] {1, 5, 9, 10, 11, 12}, new[] {1, 5, 9}, 9)] + [DataRow(new[] { 2, 3 }, new[] { 2, 3 }, 5)] + [DataRow(new[] { 4, 7 }, new int[] { }, 0)] + public void TestSortedListRemoveAfter(int[] before, int[] after, int removeItem) + { + var sortlist = new SortedList(before); + sortlist.RemoveAfter(removeItem); + Assert.IsTrue(sortlist.ToArray().SequenceEqual(after)); + } + } +}