diff --git a/TestHNSW/HNSWDemo/Program.cs b/TestHNSW/HNSWDemo/Program.cs index 39d8b4e..df4d556 100644 --- a/TestHNSW/HNSWDemo/Program.cs +++ b/TestHNSW/HNSWDemo/Program.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Diagnostics; +using System.IO; using System.Linq; using ZeroLevel.HNSW; @@ -98,10 +99,267 @@ namespace HNSWDemo static void Main(string[] args) { - FilterTest(); + TransformToCompactWorldTestWithAccuracity(); Console.ReadKey(); } + static void TransformToCompactWorldTest() + { + var count = 10000; + var dimensionality = 128; + var samples = RandomVectors(dimensionality, count); + var world = new SmallWorld(NSWOptions.Create(6, 15, 200, 200, CosineDistance.ForUnits, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple)); + var ids = world.AddItems(samples.ToArray()); + + Console.WriteLine("Start test"); + + byte[] dump; + using (var ms = new MemoryStream()) + { + world.Serialize(ms); + dump = ms.ToArray(); + } + Console.WriteLine($"Full dump size: {dump.Length} bytes"); + + ReadOnlySmallWorld compactWorld; + using (var ms = new MemoryStream(dump)) + { + compactWorld = SmallWorld.CreateReadOnlyWorldFrom(NSWReadOnlyOption.Create(200, CosineDistance.ForUnits, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple), ms); + } + + // Compare worlds outputs + int K = 200; + var hits = 0; + var miss = 0; + var testCount = 1000; + var sw = new Stopwatch(); + var timewatchesHNSW = new List(); + var timewatchesHNSWCompact = new List(); + var test_vectors = RandomVectors(dimensionality, testCount); + + foreach (var v in test_vectors) + { + sw.Restart(); + var gt = world.Search(v, K).Select(e => e.Item1).ToHashSet(); + sw.Stop(); + timewatchesHNSW.Add(sw.ElapsedMilliseconds); + + sw.Restart(); + var result = compactWorld.Search(v, K).Select(e => e.Item1).ToHashSet(); + sw.Stop(); + timewatchesHNSWCompact.Add(sw.ElapsedMilliseconds); + + foreach (var r in result) + { + if (gt.Contains(r)) + { + 
hits++; + } + else + { + miss++; + } + } + } + + byte[] smallWorldDump; + using (var ms = new MemoryStream()) + { + compactWorld.Serialize(ms); + smallWorldDump = ms.ToArray(); + } + var p = smallWorldDump.Length * 100.0f / dump.Length; + Console.WriteLine($"Compact dump size: {smallWorldDump.Length} bytes. Decrease: {100 - p}%"); + + Console.WriteLine($"HITS: {hits}"); + Console.WriteLine($"MISSES: {miss}"); + + Console.WriteLine($"MIN HNSW TIME: {timewatchesHNSW.Min()} ms"); + Console.WriteLine($"AVG HNSW TIME: {timewatchesHNSW.Average()} ms"); + Console.WriteLine($"MAX HNSW TIME: {timewatchesHNSW.Max()} ms"); + + Console.WriteLine($"MIN HNSWCompact TIME: {timewatchesHNSWCompact.Min()} ms"); + Console.WriteLine($"AVG HNSWCompact TIME: {timewatchesHNSWCompact.Average()} ms"); + Console.WriteLine($"MAX HNSWCompact TIME: {timewatchesHNSWCompact.Max()} ms"); + } + + static void TransformToCompactWorldTestWithAccuracity() + { + var count = 10000; + var dimensionality = 128; + var samples = RandomVectors(dimensionality, count); + + var test = new VectorsDirectCompare(samples, CosineDistance.ForUnits); + var world = new SmallWorld(NSWOptions.Create(6, 15, 200, 200, CosineDistance.ForUnits, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple)); + var ids = world.AddItems(samples.ToArray()); + + Console.WriteLine("Start test"); + + byte[] dump; + using (var ms = new MemoryStream()) + { + world.Serialize(ms); + dump = ms.ToArray(); + } + + ReadOnlySmallWorld compactWorld; + using (var ms = new MemoryStream(dump)) + { + compactWorld = SmallWorld.CreateReadOnlyWorldFrom(NSWReadOnlyOption.Create(200, CosineDistance.ForUnits, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple), ms); + } + + // Compare worlds outputs + int K = 200; + var hits = 0; + var miss = 0; + + var testCount = 2000; + var sw = new Stopwatch(); + var timewatchesNP = new List(); + var timewatchesHNSW = new List(); + var timewatchesHNSWCompact = new List(); + var 
test_vectors = RandomVectors(dimensionality, testCount); + + var totalHitsHNSW = new List(); + var totalHitsHNSWCompact = new List(); + + foreach (var v in test_vectors) + { + var npHitsHNSW = 0; + var npHitsHNSWCompact = 0; + + sw.Restart(); + var gtNP = test.KNearest(v, K).Select(p => p.Item1).ToHashSet(); + sw.Stop(); + timewatchesNP.Add(sw.ElapsedMilliseconds); + + sw.Restart(); + var gt = world.Search(v, K).Select(e => e.Item1).ToHashSet(); + sw.Stop(); + timewatchesHNSW.Add(sw.ElapsedMilliseconds); + + sw.Restart(); + var result = compactWorld.Search(v, K).Select(e => e.Item1).ToHashSet(); + sw.Stop(); + timewatchesHNSWCompact.Add(sw.ElapsedMilliseconds); + + foreach (var r in result) + { + if (gt.Contains(r)) + { + hits++; + } + else + { + miss++; + } + if (gtNP.Contains(r)) + { + npHitsHNSWCompact++; + } + } + + foreach (var r in gt) + { + if (gtNP.Contains(r)) + { + npHitsHNSW++; + } + } + + totalHitsHNSW.Add(npHitsHNSW); + totalHitsHNSWCompact.Add(npHitsHNSWCompact); + } + + byte[] smallWorldDump; + using (var ms = new MemoryStream()) + { + compactWorld.Serialize(ms); + smallWorldDump = ms.ToArray(); + } + var p = smallWorldDump.Length * 100.0f / dump.Length; + Console.WriteLine($"Full dump size: {dump.Length} bytes"); + Console.WriteLine($"Compact dump size: {smallWorldDump.Length} bytes. 
Decrease: {100 - p}%"); + + Console.WriteLine($"HITS: {hits}"); + Console.WriteLine($"MISSES: {miss}"); + + Console.WriteLine($"MIN NP TIME: {timewatchesNP.Min()} ms"); + Console.WriteLine($"AVG NP TIME: {timewatchesNP.Average()} ms"); + Console.WriteLine($"MAX NP TIME: {timewatchesNP.Max()} ms"); + + Console.WriteLine($"MIN HNSW TIME: {timewatchesHNSW.Min()} ms"); + Console.WriteLine($"AVG HNSW TIME: {timewatchesHNSW.Average()} ms"); + Console.WriteLine($"MAX HNSW TIME: {timewatchesHNSW.Max()} ms"); + + Console.WriteLine($"MIN HNSWCompact TIME: {timewatchesHNSWCompact.Min()} ms"); + Console.WriteLine($"AVG HNSWCompact TIME: {timewatchesHNSWCompact.Average()} ms"); + Console.WriteLine($"MAX HNSWCompact TIME: {timewatchesHNSWCompact.Max()} ms"); + + Console.WriteLine($"MIN HNSW Accuracity: {totalHitsHNSW.Min() * 100 / K}%"); + Console.WriteLine($"AVG HNSW Accuracity: {totalHitsHNSW.Average() * 100 / K}%"); + Console.WriteLine($"MAX HNSW Accuracity: {totalHitsHNSW.Max() * 100 / K}%"); + + Console.WriteLine($"MIN HNSWCompact Accuracity: {totalHitsHNSWCompact.Min() * 100 / K}%"); + Console.WriteLine($"AVG HNSWCompact Accuracity: {totalHitsHNSWCompact.Average() * 100 / K}%"); + Console.WriteLine($"MAX HNSWCompact Accuracity: {totalHitsHNSWCompact.Max() * 100 / K}%"); + } + + static void SaveRestoreTest() + { + var count = 1000; + var dimensionality = 128; + var samples = RandomVectors(dimensionality, count); + var world = new SmallWorld(NSWOptions.Create(6, 15, 200, 200, CosineDistance.ForUnits, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple)); + var sw = new Stopwatch(); + sw.Start(); + var ids = world.AddItems(samples.ToArray()); + sw.Stop(); + Console.WriteLine($"Insert {ids.Length} items on {sw.ElapsedMilliseconds} ms"); + Console.WriteLine("Start test"); + + byte[] dump; + using (var ms = new MemoryStream()) + { + world.Serialize(ms); + dump = ms.ToArray(); + } + Console.WriteLine($"Full dump size: {dump.Length} bytes"); + + byte[] 
testDump; + var restoredWorld = new SmallWorld(NSWOptions.Create(6, 15, 200, 200, CosineDistance.ForUnits, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple)); + using (var ms = new MemoryStream(dump)) + { + restoredWorld.Deserialize(ms); + } + + using (var ms = new MemoryStream()) + { + restoredWorld.Serialize(ms); + testDump = ms.ToArray(); + } + if (testDump.Length != dump.Length) + { + Console.WriteLine($"Incorrect restored size. Got {testDump.Length}. Expected: {dump.Length}"); + return; + } + + ReadOnlySmallWorld compactWorld; + using (var ms = new MemoryStream(dump)) + { + compactWorld = SmallWorld.CreateReadOnlyWorldFrom(NSWReadOnlyOption.Create(200, CosineDistance.ForUnits, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple), ms); + } + + byte[] smallWorldDump; + using (var ms = new MemoryStream()) + { + compactWorld.Serialize(ms); + smallWorldDump = ms.ToArray(); + } + var p = smallWorldDump.Length * 100.0f / dump.Length; + Console.WriteLine($"Compact dump size: {smallWorldDump.Length} bytes. 
Decrease: {100 - p}%"); + } + static void FilterTest() { var count = 5000; @@ -110,7 +368,7 @@ namespace HNSWDemo var samples = Person.GenerateRandom(dimensionality, count); var world = new SmallWorld(NSWOptions.Create(6, 15, 200, 200, CosineDistance.ForUnits, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple)); - + var ids = world.AddItems(samples.Select(i => i.Item1).ToArray()); for (int bi = 0; bi < samples.Count; bi++) { @@ -167,7 +425,7 @@ namespace HNSWDemo Console.WriteLine($"Insert {ids.Length} items on {sw.ElapsedMilliseconds} ms"); Console.WriteLine("Start test"); - + var test_vectors = RandomVectors(dimensionality, testCount); foreach (var v in test_vectors) { diff --git a/ZeroLevel.HNSW/Model/NSWOptions.cs b/ZeroLevel.HNSW/Model/NSWOptions.cs index b4af8cc..344b042 100644 --- a/ZeroLevel.HNSW/Model/NSWOptions.cs +++ b/ZeroLevel.HNSW/Model/NSWOptions.cs @@ -2,29 +2,12 @@ namespace ZeroLevel.HNSW { - /// - /// Type of heuristic to select best neighbours for a node. - /// - public enum NeighbourSelectionHeuristic - { - /// - /// Marker for the Algorithm 3 (SELECT-NEIGHBORS-SIMPLE) from the article. Implemented in - /// - SelectSimple, - - /// - /// Marker for the Algorithm 4 (SELECT-NEIGHBORS-HEURISTIC) from the article. 
Implemented in - /// - SelectHeuristic } - public sealed class NSWOptions { /// /// Mox node connections on Layer /// public readonly int M; - /// /// Max search buffer /// diff --git a/ZeroLevel.HNSW/Model/NSWReadOnlyOption.cs b/ZeroLevel.HNSW/Model/NSWReadOnlyOption.cs new file mode 100644 index 0000000..bdc9d97 --- /dev/null +++ b/ZeroLevel.HNSW/Model/NSWReadOnlyOption.cs @@ -0,0 +1,44 @@ +using System; + +namespace ZeroLevel.HNSW +{ + public sealed class NSWReadOnlyOption + { + /// + /// Max search buffer + /// + public readonly int EF; + /// + /// Distance function between vectors + /// + public readonly Func Distance; + + public readonly bool ExpandBestSelection; + + public readonly bool KeepPrunedConnections; + + public readonly NeighbourSelectionHeuristic SelectionHeuristic; + + private NSWReadOnlyOption( + int ef, + Func distance, + bool expandBestSelection, + bool keepPrunedConnections, + NeighbourSelectionHeuristic selectionHeuristic) + { + EF = ef; + Distance = distance; + ExpandBestSelection = expandBestSelection; + KeepPrunedConnections = keepPrunedConnections; + SelectionHeuristic = selectionHeuristic; + } + + public static NSWReadOnlyOption Create( + int EF, + Func distance, + bool expandBestSelection = false, + bool keepPrunedConnections = false, + NeighbourSelectionHeuristic selectionHeuristic = NeighbourSelectionHeuristic.SelectSimple) => + new NSWReadOnlyOption(EF, distance, expandBestSelection, keepPrunedConnections, selectionHeuristic); + } +} diff --git a/ZeroLevel.HNSW/Model/NeighbourSelectionHeuristic.cs b/ZeroLevel.HNSW/Model/NeighbourSelectionHeuristic.cs new file mode 100644 index 0000000..48b88fb --- /dev/null +++ b/ZeroLevel.HNSW/Model/NeighbourSelectionHeuristic.cs @@ -0,0 +1,18 @@ +namespace ZeroLevel.HNSW +{ + /// + /// Type of heuristic to select best neighbours for a node. + /// + public enum NeighbourSelectionHeuristic + { + /// + /// Marker for the Algorithm 3 (SELECT-NEIGHBORS-SIMPLE) from the article. 
Implemented in + /// + SelectSimple, + + /// + /// Marker for the Algorithm 4 (SELECT-NEIGHBORS-HEURISTIC) from the article. Implemented in + /// + SelectHeuristic + } +} diff --git a/ZeroLevel.HNSW/ReadOnlySmallWorld.cs b/ZeroLevel.HNSW/ReadOnlySmallWorld.cs new file mode 100644 index 0000000..7f6929e --- /dev/null +++ b/ZeroLevel.HNSW/ReadOnlySmallWorld.cs @@ -0,0 +1,152 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using ZeroLevel.Services.Serialization; + +namespace ZeroLevel.HNSW +{ + public class ReadOnlySmallWorld + { + private readonly NSWReadOnlyOption _options; + private ReadOnlyVectorSet _vectors; + private ReadOnlyLayer[] _layers; + private int EntryPoint = 0; + private int MaxLayer = 0; + + private ReadOnlySmallWorld() { } + + internal ReadOnlySmallWorld(NSWReadOnlyOption options, Stream stream) + { + _options = options; + Deserialize(stream); + } + + /// + /// Search in the graph K for vectors closest to a given vector + /// + /// Given vector + /// Count of elements for search + /// + /// + public IEnumerable<(int, TItem, float)> Search(TItem vector, int k) + { + foreach (var pair in KNearest(vector, k)) + { + yield return (pair.Item1, _vectors[pair.Item1], pair.Item2); + } + } + + public IEnumerable<(int, TItem, float)> Search(TItem vector, int k, HashSet activeNodes) + { + if (activeNodes == null) + { + foreach (var pair in KNearest(vector, k)) + { + yield return (pair.Item1, _vectors[pair.Item1], pair.Item2); + } + } + else + { + foreach (var pair in KNearest(vector, k, activeNodes)) + { + yield return (pair.Item1, _vectors[pair.Item1], pair.Item2); + } + } + } + + /// + /// Algorithm 5 + /// + private IEnumerable<(int, float)> KNearest(TItem q, int k) + { + if (_vectors.Count == 0) + { + return Enumerable.Empty<(int, float)>(); + } + var distance = new Func(candidate => _options.Distance(q, _vectors[candidate])); + + // W ← ∅ // set for the current nearest elements + var W = new Dictionary(k + 1); + 
// ep ← get enter point for hnsw + var ep = EntryPoint; + // L ← level of ep // top layer for hnsw + var L = MaxLayer; + // for lc ← L … 1 + for (int layer = L; layer > 0; --layer) + { + // W ← SEARCH-LAYER(q, ep, ef = 1, lc) + _layers[layer].KNearestAtLayer(ep, distance, W, 1); + // ep ← get nearest element from W to q + ep = W.OrderBy(p => p.Value).First().Key; + W.Clear(); + } + // W ← SEARCH-LAYER(q, ep, ef, lc =0) + _layers[0].KNearestAtLayer(ep, distance, W, k); + // return K nearest elements from W to q + return W.Select(p => (p.Key, p.Value)); + } + + private IEnumerable<(int, float)> KNearest(TItem q, int k, HashSet activeNodes) + { + if (_vectors.Count == 0) + { + return Enumerable.Empty<(int, float)>(); + } + var distance = new Func(candidate => _options.Distance(q, _vectors[candidate])); + + // W ← ∅ // set for the current nearest elements + var W = new Dictionary(k + 1); + // ep ← get enter point for hnsw + var ep = EntryPoint; + // L ← level of ep // top layer for hnsw + var L = MaxLayer; + // for lc ← L … 1 + for (int layer = L; layer > 0; --layer) + { + // W ← SEARCH-LAYER(q, ep, ef = 1, lc) + _layers[layer].KNearestAtLayer(ep, distance, W, 1); + // ep ← get nearest element from W to q + ep = W.OrderBy(p => p.Value).First().Key; + W.Clear(); + } + // W ← SEARCH-LAYER(q, ep, ef, lc =0) + _layers[0].KNearestAtLayer(ep, distance, W, k, activeNodes); + // return K nearest elements from W to q + return W.Select(p => (p.Key, p.Value)); + } + + public void Serialize(Stream stream) + { + using (var writer = new MemoryStreamWriter(stream)) + { + writer.WriteInt32(EntryPoint); + writer.WriteInt32(MaxLayer); + _vectors.Serialize(writer); + writer.WriteInt32(_layers.Length); + foreach (var l in _layers) + { + l.Serialize(writer); + } + } + } + + public void Deserialize(Stream stream) + { + using (var reader = new MemoryStreamReader(stream)) + { + this.EntryPoint = reader.ReadInt32(); + this.MaxLayer = reader.ReadInt32(); + _vectors = new ReadOnlyVectorSet(); + 
_vectors.Deserialize(reader); + var countLayers = reader.ReadInt32(); + _layers = new ReadOnlyLayer[countLayers]; + for (int i = 0; i < countLayers; i++) + { + _layers[i] = new ReadOnlyLayer(_options, _vectors); + _layers[i].Deserialize(reader); + } + } + } + } +} diff --git a/ZeroLevel.HNSW/Services/CompactBiDirectionalLinksSet.cs b/ZeroLevel.HNSW/Services/CompactBiDirectionalLinksSet.cs index 1ea22db..5c9f81e 100644 --- a/ZeroLevel.HNSW/Services/CompactBiDirectionalLinksSet.cs +++ b/ZeroLevel.HNSW/Services/CompactBiDirectionalLinksSet.cs @@ -2,11 +2,12 @@ using System.Collections.Generic; using System.Linq; using System.Threading; +using ZeroLevel.Services.Serialization; namespace ZeroLevel.HNSW { internal sealed class CompactBiDirectionalLinksSet - : IDisposable + : IBinarySerializable, IDisposable { private readonly ReaderWriterLockSlim _rwLock = new ReaderWriterLockSlim(); @@ -14,7 +15,7 @@ namespace ZeroLevel.HNSW private SortedList _set = new SortedList(); - public (int, int) this[int index] + internal (int, int) this[int index] { get { @@ -25,12 +26,12 @@ namespace ZeroLevel.HNSW } } - public int Count => _set.Count; + internal int Count => _set.Count; /// /// Разрывает связи id1 - id2 и id2 - id1, и строит новые id1 - id, id - id1 /// - public void Relink(int id1, int id2, int id, float distance) + internal void Relink(int id1, int id2, int id, float distance) { long k1old = (((long)(id1)) << HALF_LONG_BITS) + id2; long k2old = (((long)(id2)) << HALF_LONG_BITS) + id1; @@ -57,7 +58,7 @@ namespace ZeroLevel.HNSW /// /// Разрывает связи id1 - id2 и id2 - id1, и строит новые id1 - id, id - id1, id2 - id, id - id2 /// - public void Relink(int id1, int id2, int id, float distanceToId1, float distanceToId2) + internal void Relink(int id1, int id2, int id, float distanceToId1, float distanceToId2) { long k_id1_id2 = (((long)(id1)) << HALF_LONG_BITS) + id2; long k_id2_id1 = (((long)(id2)) << HALF_LONG_BITS) + id1; @@ -96,7 +97,7 @@ namespace ZeroLevel.HNSW } } - 
public IEnumerable<(int, int, float)> FindLinksForId(int id) + internal IEnumerable<(int, int, float)> FindLinksForId(int id) { _rwLock.EnterReadLock(); try @@ -114,7 +115,7 @@ namespace ZeroLevel.HNSW } } - public IEnumerable<(int, int, float)> Items() + internal IEnumerable<(int, int, float)> Items() { _rwLock.EnterReadLock(); try @@ -132,7 +133,7 @@ namespace ZeroLevel.HNSW } } - public void RemoveIndex(int id) + internal void RemoveIndex(int id) { long[] forward; long[] backward; @@ -169,7 +170,7 @@ namespace ZeroLevel.HNSW } } - public bool Add(int id1, int id2, float distance) + internal bool Add(int id1, int id2, float distance) { _rwLock.EnterWriteLock(); try @@ -193,7 +194,7 @@ namespace ZeroLevel.HNSW return false; } - static IEnumerable<(long, float)> Search(SortedList set, int index) + private static IEnumerable<(long, float)> Search(SortedList set, int index) { long k = ((long)index) << HALF_LONG_BITS; int left = 0; @@ -232,7 +233,7 @@ namespace ZeroLevel.HNSW return Enumerable.Empty<(long, float)>(); } - static IEnumerable<(long, float)> SearchByPosition(SortedList set, long k, int position) + private static IEnumerable<(long, float)> SearchByPosition(SortedList set, long k, int position) { var start = position; var end = position; @@ -259,5 +260,34 @@ namespace ZeroLevel.HNSW _set.Clear(); _set = null; } + + public void Serialize(IBinaryWriter writer) + { + writer.WriteBoolean(true); // true - set with weights + writer.WriteInt32(_set.Count); + foreach (var record in _set) + { + writer.WriteLong(record.Key); + writer.WriteFloat(record.Value); + } + } + + public void Deserialize(IBinaryReader reader) + { + if (reader.ReadBoolean() == false) + { + throw new InvalidOperationException("Incompatible data format. 
The set does not contain weights."); + } + _set.Clear(); + _set = null; + var count = reader.ReadInt32(); + _set = new SortedList(count + 1); + for (int i = 0; i < count; i++) + { + var key = reader.ReadLong(); + var value = reader.ReadFloat(); + _set.Add(key, value); + } + } } } diff --git a/ZeroLevel.HNSW/Layer.cs b/ZeroLevel.HNSW/Services/Layer.cs similarity index 98% rename from ZeroLevel.HNSW/Layer.cs rename to ZeroLevel.HNSW/Services/Layer.cs index d7ec5f3..65d90e2 100644 --- a/ZeroLevel.HNSW/Layer.cs +++ b/ZeroLevel.HNSW/Services/Layer.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; +using ZeroLevel.Services.Serialization; namespace ZeroLevel.HNSW { @@ -8,6 +9,7 @@ namespace ZeroLevel.HNSW /// NSW graph /// internal sealed class Layer + : IBinarySerializable { private readonly NSWOptions _options; private readonly VectorSet _vectors; @@ -352,5 +354,15 @@ namespace ZeroLevel.HNSW #endregion private IEnumerable GetNeighbors(int id) => _links.FindLinksForId(id).Select(d => d.Item2); + + public void Serialize(IBinaryWriter writer) + { + _links.Serialize(writer); + } + + public void Deserialize(IBinaryReader reader) + { + _links.Deserialize(reader); + } } } \ No newline at end of file diff --git a/ZeroLevel.HNSW/Services/ReadOnly/ReadOnlyCompactBiDirectionalLinksSet.cs b/ZeroLevel.HNSW/Services/ReadOnly/ReadOnlyCompactBiDirectionalLinksSet.cs new file mode 100644 index 0000000..2b190ff --- /dev/null +++ b/ZeroLevel.HNSW/Services/ReadOnly/ReadOnlyCompactBiDirectionalLinksSet.cs @@ -0,0 +1,107 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using ZeroLevel.Services.Serialization; + +namespace ZeroLevel.HNSW +{ + internal sealed class ReadOnlyCompactBiDirectionalLinksSet + : IBinarySerializable, IDisposable + { + private const int HALF_LONG_BITS = 32; + + private Dictionary _set = new Dictionary(); + + internal int Count => _set.Count; + + internal IEnumerable FindLinksForId(int id) + { + if 
(_set.ContainsKey(id)) + { + return _set[id]; + } + return Enumerable.Empty(); + } + + public void Dispose() + { + _set.Clear(); + _set = null; + } + + public void Serialize(IBinaryWriter writer) + { + writer.WriteBoolean(false); // false - set without weights + writer.WriteInt32(_set.Count); + foreach (var record in _set) + { + writer.WriteInt32(record.Key); + writer.WriteCollection(record.Value); + } + } + + public void Deserialize(IBinaryReader reader) + { + _set.Clear(); + _set = null; + if (reader.ReadBoolean() == false) + { + var count = reader.ReadInt32(); + _set = new Dictionary(count); + for (int i = 0; i < count; i++) + { + var key = reader.ReadInt32(); + var value = reader.ReadInt32Array(); + _set.Add(key, value); + } + } + else + { + var count = reader.ReadInt32(); + _set = new Dictionary(count); + + // hack, We know that an sortedset has been saved + long key; + int id1, id2; + var prevId = -1; + var set = new HashSet(); + + for (int i = 0; i < count; i++) + { + key = reader.ReadLong(); + id1 = (int)(key >> HALF_LONG_BITS); + id2 = (int)(key - (((long)id1) << HALF_LONG_BITS)); + + reader.ReadFloat(); // SKIP + + if (prevId == -1) + { + prevId = id1; + if (id1 != id2) + { + set.Add(id2); + } + } + else if (prevId != id1) + { + _set.Add(prevId, set.ToArray()); + set.Clear(); + prevId = id1; + } + else + { + if (id1 != id2) + { + set.Add(id2); + } + } + } + if (set.Count > 0) + { + _set.Add(prevId, set.ToArray()); + set.Clear(); + } + } + } + } +} diff --git a/ZeroLevel.HNSW/Services/ReadOnly/ReadOnlyLayer.cs b/ZeroLevel.HNSW/Services/ReadOnly/ReadOnlyLayer.cs new file mode 100644 index 0000000..f35905e --- /dev/null +++ b/ZeroLevel.HNSW/Services/ReadOnly/ReadOnlyLayer.cs @@ -0,0 +1,316 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using ZeroLevel.Services.Serialization; + +namespace ZeroLevel.HNSW +{ + /// + /// NSW graph + /// + internal sealed class ReadOnlyLayer + : IBinarySerializable + { + private readonly 
NSWReadOnlyOption _options; + private readonly ReadOnlyVectorSet _vectors; + private readonly ReadOnlyCompactBiDirectionalLinksSet _links; + + /// + /// HNSW layer + /// + /// HNSW graph options + /// General vector set + internal ReadOnlyLayer(NSWReadOnlyOption options, ReadOnlyVectorSet vectors) + { + _options = options; + _vectors = vectors; + _links = new ReadOnlyCompactBiDirectionalLinksSet(); + } + + #region Implementation of https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf + /// + /// Algorithm 2 + /// + /// query element + /// enter points ep + /// Output: ef closest neighbors to q + internal void KNearestAtLayer(int entryPointId, Func targetCosts, IDictionary W, int ef) + { + /* + * v ← ep // set of visited elements + * C ← ep // set of candidates + * W ← ep // dynamic list of found nearest neighbors + * while │C│ > 0 + * c ← extract nearest element from C to q + * f ← get furthest element from W to q + * if distance(c, q) > distance(f, q) + * break // all elements in W are evaluated + * for each e ∈ neighbourhood(c) at layer lc // update C and W + * if e ∉ v + * v ← v ⋃ e + * f ← get furthest element from W to q + * if distance(e, q) < distance(f, q) or │W│ < ef + * C ← C ⋃ e + * W ← W ⋃ e + * if │W│ > ef + * remove furthest element from W to q + * return W + */ + var v = new VisitedBitSet(_vectors.Count, 1); + // v ← ep // set of visited elements + v.Add(entryPointId); + // C ← ep // set of candidates + var C = new Dictionary(); + C.Add(entryPointId, targetCosts(entryPointId)); + // W ← ep // dynamic list of found nearest neighbors + W.Add(entryPointId, C[entryPointId]); + + var popCandidate = new Func<(int, float)>(() => { var pair = C.OrderBy(e => e.Value).First(); C.Remove(pair.Key); return (pair.Key, pair.Value); }); + var fartherFromResult = new Func<(int, float)>(() => { var pair = W.OrderByDescending(e => e.Value).First(); return (pair.Key, pair.Value); }); + var fartherPopFromResult = new Action(() => { var pair = W.OrderByDescending(e => 
e.Value).First(); W.Remove(pair.Key); }); + // run bfs + while (C.Count > 0) + { + // get next candidate to check and expand + var toExpand = popCandidate(); + var farthestResult = fartherFromResult(); + if (toExpand.Item2 > farthestResult.Item2) + { + // the closest candidate is farther than farthest result + break; + } + + // expand candidate + var neighboursIds = GetNeighbors(toExpand.Item1).ToArray(); + for (int i = 0; i < neighboursIds.Length; ++i) + { + int neighbourId = neighboursIds[i]; + if (!v.Contains(neighbourId)) + { + // enqueue perspective neighbours to expansion list + farthestResult = fartherFromResult(); + + var neighbourDistance = targetCosts(neighbourId); + if (W.Count < ef || neighbourDistance < farthestResult.Item2) + { + C.Add(neighbourId, neighbourDistance); + W.Add(neighbourId, neighbourDistance); + if (W.Count > ef) + { + fartherPopFromResult(); + } + } + v.Add(neighbourId); + } + } + } + C.Clear(); + v.Clear(); + } + + /// + /// Algorithm 2 + /// + /// query element + /// enter points ep + /// Output: ef closest neighbors to q + internal void KNearestAtLayer(int entryPointId, Func targetCosts, IDictionary W, int ef, HashSet activeNodes) + { + /* + * v ← ep // set of visited elements + * C ← ep // set of candidates + * W ← ep // dynamic list of found nearest neighbors + * while │C│ > 0 + * c ← extract nearest element from C to q + * f ← get furthest element from W to q + * if distance(c, q) > distance(f, q) + * break // all elements in W are evaluated + * for each e ∈ neighbourhood(c) at layer lc // update C and W + * if e ∉ v + * v ← v ⋃ e + * f ← get furthest element from W to q + * if distance(e, q) < distance(f, q) or │W│ < ef + * C ← C ⋃ e + * W ← W ⋃ e + * if │W│ > ef + * remove furthest element from W to q + * return W + */ + var v = new VisitedBitSet(_vectors.Count, 1); + // v ← ep // set of visited elements + v.Add(entryPointId); + // C ← ep // set of candidates + var C = new Dictionary(); + C.Add(entryPointId, 
targetCosts(entryPointId)); + // W ← ep // dynamic list of found nearest neighbors + if (activeNodes.Contains(entryPointId)) + { + W.Add(entryPointId, C[entryPointId]); + } + var popCandidate = new Func<(int, float)>(() => { var pair = C.OrderBy(e => e.Value).First(); C.Remove(pair.Key); return (pair.Key, pair.Value); }); + var farthestDistance = new Func(() => { var pair = W.OrderByDescending(e => e.Value).First(); return pair.Value; }); + var fartherPopFromResult = new Action(() => { var pair = W.OrderByDescending(e => e.Value).First(); W.Remove(pair.Key); }); + // run bfs + while (C.Count > 0) + { + // get next candidate to check and expand + var toExpand = popCandidate(); + if (W.Count > 0) + { + if (toExpand.Item2 > farthestDistance()) + { + // the closest candidate is farther than farthest result + break; + } + } + + // expand candidate + var neighboursIds = GetNeighbors(toExpand.Item1).ToArray(); + for (int i = 0; i < neighboursIds.Length; ++i) + { + int neighbourId = neighboursIds[i]; + if (!v.Contains(neighbourId)) + { + // enqueue perspective neighbours to expansion list + var neighbourDistance = targetCosts(neighbourId); + if (activeNodes.Contains(neighbourId)) + { + if (W.Count < ef || (W.Count > 0 && neighbourDistance < farthestDistance())) + { + W.Add(neighbourId, neighbourDistance); + if (W.Count > ef) + { + fartherPopFromResult(); + } + } + } + if (W.Count < ef) + { + C.Add(neighbourId, neighbourDistance); + } + v.Add(neighbourId); + } + } + } + C.Clear(); + v.Clear(); + } + + /// + /// Algorithm 3 + /// + internal IDictionary SELECT_NEIGHBORS_SIMPLE(Func distance, IDictionary candidates, int M) + { + var bestN = M; + var W = new Dictionary(candidates); + if (W.Count > bestN) + { + var popFarther = new Action(() => { var pair = W.OrderByDescending(e => e.Value).First(); W.Remove(pair.Key); }); + while (W.Count > bestN) + { + popFarther(); + } + } + // return M nearest elements from C to q + return W; + } + + + + /// + /// Algorithm 4 + /// + /// 
base element + /// candidate elements + /// flag indicating whether or not to extend candidate list + /// flag indicating whether or not to add discarded elements + /// Output: M elements selected by the heuristic + internal IDictionary SELECT_NEIGHBORS_HEURISTIC(Func distance, IDictionary candidates, int M) + { + // R ← ∅ + var R = new Dictionary(); + // W ← C // working queue for the candidates + var W = new Dictionary(candidates); + // if extendCandidates // extend candidates by their neighbors + if (_options.ExpandBestSelection) + { + var extendBuffer = new HashSet(); + // for each e ∈ C + foreach (var e in W) + { + var neighbors = GetNeighbors(e.Key); + // for each e_adj ∈ neighbourhood(e) at layer lc + foreach (var e_adj in neighbors) + { + // if eadj ∉ W + if (extendBuffer.Contains(e_adj) == false) + { + extendBuffer.Add(e_adj); + } + } + } + // W ← W ⋃ eadj + foreach (var id in extendBuffer) + { + W[id] = distance(id); + } + } + + // Wd ← ∅ // queue for the discarded candidates + var Wd = new Dictionary(); + + + var popCandidate = new Func<(int, float)>(() => { var pair = W.OrderBy(e => e.Value).First(); W.Remove(pair.Key); return (pair.Key, pair.Value); }); + var fartherFromResult = new Func<(int, float)>(() => { if (R.Count == 0) return (-1, 0f); var pair = R.OrderByDescending(e => e.Value).First(); return (pair.Key, pair.Value); }); + var popNearestDiscarded = new Func<(int, float)>(() => { var pair = Wd.OrderBy(e => e.Value).First(); Wd.Remove(pair.Key); return (pair.Key, pair.Value); }); + + + // while │W│ > 0 and │R│< M + while (W.Count > 0 && R.Count < M) + { + // e ← extract nearest element from W to q + var (e, ed) = popCandidate(); + var (fe, fd) = fartherFromResult(); + + // if e is closer to q compared to any element from R + if (R.Count == 0 || + ed < fd) + { + // R ← R ⋃ e + R.Add(e, ed); + } + else + { + // Wd ← Wd ⋃ e + Wd.Add(e, ed); + } + } + // if keepPrunedConnections // add some of the discarded // connections from Wd + if 
(_options.KeepPrunedConnections) + { + // while │Wd│> 0 and │R│< M + while (Wd.Count > 0 && R.Count < M) + { + // R ← R ⋃ extract nearest element from Wd to q + var nearest = popNearestDiscarded(); + R[nearest.Item1] = nearest.Item2; + } + } + // return R + return R; + } + #endregion + + private IEnumerable GetNeighbors(int id) => _links.FindLinksForId(id); + + public void Serialize(IBinaryWriter writer) + { + _links.Serialize(writer); + } + + public void Deserialize(IBinaryReader reader) + { + _links.Deserialize(reader); + } + } +} diff --git a/ZeroLevel.HNSW/Services/ReadOnly/ReadOnlyVectorSet.cs b/ZeroLevel.HNSW/Services/ReadOnly/ReadOnlyVectorSet.cs new file mode 100644 index 0000000..50801f2 --- /dev/null +++ b/ZeroLevel.HNSW/Services/ReadOnly/ReadOnlyVectorSet.cs @@ -0,0 +1,33 @@ +using System.Collections.Generic; +using ZeroLevel.Services.Serialization; + +namespace ZeroLevel.HNSW +{ + internal sealed class ReadOnlyVectorSet + : IBinarySerializable + { + private List _set = new List(); + + internal T this[int index] => _set[index]; + internal int Count => _set.Count; + + public void Deserialize(IBinaryReader reader) + { + int count = reader.ReadInt32(); + _set = new List(count + 1); + for (int i = 0; i < count; i++) + { + _set.Add(reader.ReadCompatible()); + } + } + + public void Serialize(IBinaryWriter writer) + { + writer.WriteInt32(_set.Count); + foreach (var r in _set) + { + writer.WriteCompatible(r); + } + } + } +} diff --git a/ZeroLevel.HNSW/Services/VectorSet.cs b/ZeroLevel.HNSW/Services/VectorSet.cs index 227a2c8..761973c 100644 --- a/ZeroLevel.HNSW/Services/VectorSet.cs +++ b/ZeroLevel.HNSW/Services/VectorSet.cs @@ -1,18 +1,19 @@ -using System; -using System.Collections.Generic; +using System.Collections.Generic; using System.Threading; +using ZeroLevel.Services.Serialization; namespace ZeroLevel.HNSW { - public class VectorSet + internal sealed class VectorSet + : IBinarySerializable { private List _set = new List(); private SpinLock _lock = new 
SpinLock(); - public T this[int index] => _set[index]; - public int Count => _set.Count; + internal T this[int index] => _set[index]; + internal int Count => _set.Count; - public int Append(T vector) + internal int Append(T vector) { bool gotLock = false; gotLock = false; @@ -29,7 +30,7 @@ namespace ZeroLevel.HNSW } } - public int[] Append(IEnumerable vectors) + internal int[] Append(IEnumerable vectors) { bool gotLock = false; int startIndex, endIndex; @@ -53,5 +54,24 @@ namespace ZeroLevel.HNSW } return ids; } + + public void Deserialize(IBinaryReader reader) + { + int count = reader.ReadInt32(); + _set = new List(count + 1); + for (int i = 0; i < count; i++) + { + _set.Add(reader.ReadCompatible()); + } + } + + public void Serialize(IBinaryWriter writer) + { + writer.WriteInt32(_set.Count); + foreach (var r in _set) + { + writer.WriteCompatible(r); + } + } } } diff --git a/ZeroLevel.HNSW/SmallWorld.cs b/ZeroLevel.HNSW/SmallWorld.cs index 8349817..fe51c9e 100644 --- a/ZeroLevel.HNSW/SmallWorld.cs +++ b/ZeroLevel.HNSW/SmallWorld.cs @@ -1,43 +1,18 @@ using System; using System.Collections.Generic; +using System.IO; using System.Linq; using System.Threading; +using ZeroLevel.HNSW.Services; +using ZeroLevel.Services.Serialization; namespace ZeroLevel.HNSW { - public class ProbabilityLayerNumberGenerator - { - private const float DIVIDER = 3.361f; - private readonly float[] _probabilities; - - public ProbabilityLayerNumberGenerator(int maxLayers, int M) - { - _probabilities = new float[maxLayers]; - var probability = 1.0f / DIVIDER; - for (int i = 0; i < maxLayers; i++) - { - _probabilities[i] = probability; - probability /= DIVIDER; - } - } - - public int GetRandomLayer() - { - var probability = DefaultRandomGenerator.Instance.NextFloat(); - for (int i = 0; i < _probabilities.Length; i++) - { - if (probability > _probabilities[i]) - return i; - } - return 0; - } - } - public class SmallWorld { private readonly NSWOptions _options; private readonly VectorSet _vectors; 
- private readonly Layer[] _layers; + private Layer[] _layers; private int EntryPoint = 0; private int MaxLayer = 0; private readonly ProbabilityLayerNumberGenerator _layerLevelGenerator; @@ -55,6 +30,12 @@ namespace ZeroLevel.HNSW } } + internal SmallWorld(NSWOptions options, Stream stream) + { + _options = options; + Deserialize(stream); + } + /// /// Search in the graph K for vectors closest to a given vector /// @@ -62,14 +43,32 @@ namespace ZeroLevel.HNSW /// Count of elements for search /// /// - public IEnumerable<(int, TItem, float)> Search(TItem vector, int k, HashSet activeNodes = null) + public IEnumerable<(int, TItem, float)> Search(TItem vector, int k) { - foreach (var pair in KNearest(vector, k, activeNodes)) + foreach (var pair in KNearest(vector, k)) { yield return (pair.Item1, _vectors[pair.Item1], pair.Item2); } } + public IEnumerable<(int, TItem, float)> Search(TItem vector, int k, HashSet activeNodes) + { + if (activeNodes == null) + { + foreach (var pair in KNearest(vector, k)) + { + yield return (pair.Item1, _vectors[pair.Item1], pair.Item2); + } + } + else + { + foreach (var pair in KNearest(vector, k, activeNodes)) + { + yield return (pair.Item1, _vectors[pair.Item1], pair.Item2); + } + } + } + /// /// Adding vectors batch /// @@ -108,7 +107,7 @@ namespace ZeroLevel.HNSW // L ← level of ep // top layer for hnsw var L = MaxLayer; // l ← ⌊-ln(unif(0..1))∙mL⌋ // new element’s level - int l = _layerLevelGenerator.GetRandomLayer(); + int l = _layerLevelGenerator.GetRandomLayer(); // for lc ← L … l+1 // Проход с верхнего уровня до уровня где появляется элемент, для нахождения точки входа for (int lc = L; lc > l; --lc) @@ -201,7 +200,43 @@ namespace ZeroLevel.HNSW /// /// Algorithm 5 /// - private IEnumerable<(int, float)> KNearest(TItem q, int k, HashSet activeNodes = null) + private IEnumerable<(int, float)> KNearest(TItem q, int k) + { + _lockGraph.EnterReadLock(); + try + { + if (_vectors.Count == 0) + { + return Enumerable.Empty<(int, 
float)>(); + } + var distance = new Func(candidate => _options.Distance(q, _vectors[candidate])); + + // W ← ∅ // set for the current nearest elements + var W = new Dictionary(k + 1); + // ep ← get enter point for hnsw + var ep = EntryPoint; + // L ← level of ep // top layer for hnsw + var L = MaxLayer; + // for lc ← L … 1 + for (int layer = L; layer > 0; --layer) + { + // W ← SEARCH-LAYER(q, ep, ef = 1, lc) + _layers[layer].KNearestAtLayer(ep, distance, W, 1); + // ep ← get nearest element from W to q + ep = W.OrderBy(p => p.Value).First().Key; + W.Clear(); + } + // W ← SEARCH-LAYER(q, ep, ef, lc =0) + _layers[0].KNearestAtLayer(ep, distance, W, k); + // return K nearest elements from W to q + return W.Select(p => (p.Key, p.Value)); + } + finally + { + _lockGraph.ExitReadLock(); + } + } + private IEnumerable<(int, float)> KNearest(TItem q, int k, HashSet activeNodes) { _lockGraph.EnterReadLock(); try @@ -238,5 +273,37 @@ namespace ZeroLevel.HNSW } } #endregion + + public void Serialize(Stream stream) + { + using (var writer = new MemoryStreamWriter(stream)) + { + writer.WriteInt32(EntryPoint); + writer.WriteInt32(MaxLayer); + _vectors.Serialize(writer); + writer.WriteInt32(_layers.Length); + foreach (var l in _layers) + { + l.Serialize(writer); + } + } + } + + public void Deserialize(Stream stream) + { + using (var reader = new MemoryStreamReader(stream)) + { + this.EntryPoint = reader.ReadInt32(); + this.MaxLayer = reader.ReadInt32(); + _vectors.Deserialize(reader); + var countLayers = reader.ReadInt32(); + _layers = new Layer[countLayers]; + for (int i = 0; i < countLayers; i++) + { + _layers[i] = new Layer(_options, _vectors); + _layers[i].Deserialize(reader); + } + } + } } } diff --git a/ZeroLevel.HNSW/SmallWorldFactory.cs b/ZeroLevel.HNSW/SmallWorldFactory.cs new file mode 100644 index 0000000..cbbdd1a --- /dev/null +++ b/ZeroLevel.HNSW/SmallWorldFactory.cs @@ -0,0 +1,11 @@ +using System.IO; + +namespace ZeroLevel.HNSW +{ + public static class SmallWorld + { 
+ public static SmallWorld CreateWorld(NSWOptions options) => new SmallWorld(options); + public static SmallWorld CreateWorldFrom(NSWOptions options, Stream stream) => new SmallWorld(options, stream); + public static ReadOnlySmallWorld CreateReadOnlyWorldFrom(NSWReadOnlyOption options, Stream stream) => new ReadOnlySmallWorld(options, stream); + } +} diff --git a/ZeroLevel.HNSW/Services/CosineDistance.cs b/ZeroLevel.HNSW/Utils/CosineDistance.cs similarity index 100% rename from ZeroLevel.HNSW/Services/CosineDistance.cs rename to ZeroLevel.HNSW/Utils/CosineDistance.cs diff --git a/ZeroLevel.HNSW/Services/FastRandom.cs b/ZeroLevel.HNSW/Utils/FastRandom.cs similarity index 100% rename from ZeroLevel.HNSW/Services/FastRandom.cs rename to ZeroLevel.HNSW/Utils/FastRandom.cs diff --git a/ZeroLevel.HNSW/Utils/ProbabilityLayerNumberGenerator.cs b/ZeroLevel.HNSW/Utils/ProbabilityLayerNumberGenerator.cs new file mode 100644 index 0000000..0560f21 --- /dev/null +++ b/ZeroLevel.HNSW/Utils/ProbabilityLayerNumberGenerator.cs @@ -0,0 +1,30 @@ +namespace ZeroLevel.HNSW.Services +{ + internal sealed class ProbabilityLayerNumberGenerator + { + private const float DIVIDER = 3.361f; + private readonly float[] _probabilities; + + internal ProbabilityLayerNumberGenerator(int maxLayers, int M) + { + _probabilities = new float[maxLayers]; + var probability = 1.0f / DIVIDER; + for (int i = 0; i < maxLayers; i++) + { + _probabilities[i] = probability; + probability /= DIVIDER; + } + } + + internal int GetRandomLayer() + { + var probability = DefaultRandomGenerator.Instance.NextFloat(); + for (int i = 0; i < _probabilities.Length; i++) + { + if (probability > _probabilities[i]) + return i; + } + return 0; + } + } +} diff --git a/ZeroLevel.HNSW/Services/VectorUtils.cs b/ZeroLevel.HNSW/Utils/VectorUtils.cs similarity index 100% rename from ZeroLevel.HNSW/Services/VectorUtils.cs rename to ZeroLevel.HNSW/Utils/VectorUtils.cs diff --git a/ZeroLevel.HNSW/Services/VisitedBitSet.cs 
b/ZeroLevel.HNSW/Utils/VisitedBitSet.cs similarity index 100% rename from ZeroLevel.HNSW/Services/VisitedBitSet.cs rename to ZeroLevel.HNSW/Utils/VisitedBitSet.cs diff --git a/ZeroLevel.HNSW/ZeroLevel.HNSW.csproj b/ZeroLevel.HNSW/ZeroLevel.HNSW.csproj index 09920fe..3c39f1c 100644 --- a/ZeroLevel.HNSW/ZeroLevel.HNSW.csproj +++ b/ZeroLevel.HNSW/ZeroLevel.HNSW.csproj @@ -6,6 +6,7 @@ +