From 0fdda146da5bbc9805ce5129cef4359ea1050e9b Mon Sep 17 00:00:00 2001 From: unknown Date: Sat, 22 Jan 2022 19:52:54 +0300 Subject: [PATCH] HNSW graphs serialization --- TestHNSW/HNSWDemo/Tests/LALTest.cs | 91 +++++++++++-------- ZeroLevel.HNSW/Services/HNSWMap.cs | 2 + ZeroLevel.HNSW/Services/HNSWMappers.cs | 46 +++++++++- ZeroLevel.HNSW/Services/LAL/LALGraph.cs | 15 ++- ZeroLevel.HNSW/Services/LAL/LALLinks.cs | 2 + .../Services/LAL/SplittedLALGraph.cs | 48 +++++++++- .../Services/Serialization/IBinaryWriter.cs | 2 +- .../Serialization/MemoryStreamWriter.cs | 2 +- 8 files changed, 164 insertions(+), 44 deletions(-) diff --git a/TestHNSW/HNSWDemo/Tests/LALTest.cs b/TestHNSW/HNSWDemo/Tests/LALTest.cs index b0f85e0..8f21151 100644 --- a/TestHNSW/HNSWDemo/Tests/LALTest.cs +++ b/TestHNSW/HNSWDemo/Tests/LALTest.cs @@ -12,6 +12,8 @@ namespace HNSWDemo.Tests { private const int count = 20000; private const int dimensionality = 128; + private const string _graphFileCachee = @"lal_test_graph.bin"; + private const string _mapFileCachee = @"lal_test_map.bin"; public void Run() { @@ -27,52 +29,66 @@ namespace HNSWDemo.Tests samples[c].Add(p); } + SplittedLALGraph worlds; + HNSWMappers mappers; - var worlds = new SplittedLALGraph(); - var mappers = new HNSWMappers(l => (int)Math.Abs(l.GetHashCode() % moda)); - - var worlds_dict = new Dictionary>(); - var maps_dict = new Dictionary>(); - - foreach (var p in samples) + if (File.Exists(_graphFileCachee) && File.Exists(_mapFileCachee)) { - var c = p.Key; - if (worlds_dict.ContainsKey(c) == false) - { - worlds_dict.Add(c, new SmallWorld(options)); - } - if (maps_dict.ContainsKey(c) == false) - { - maps_dict.Add(c, new HNSWMap()); - } - var w = worlds_dict[c]; - var m = maps_dict[c]; - var ids = w.AddItems(p.Value.Select(i => i.Item1)); - - for (int i = 0; i < ids.Length; i++) - { - m.Append(p.Value[i].Item2.Number, ids[i]); - } + worlds = new SplittedLALGraph(_graphFileCachee); + mappers = new HNSWMappers(_mapFileCachee, l => (int)Math.Abs(l.GetHashCode() % moda)); } - - var name = Guid.NewGuid().ToString(); - foreach (var p in samples) + else { - var c = p.Key; - var w = worlds_dict[c]; - var m = maps_dict[c]; - using (var s = File.Create(name)) + worlds = new SplittedLALGraph(); + mappers = new HNSWMappers(l => (int)Math.Abs(l.GetHashCode() % moda)); + + var worlds_dict = new Dictionary>(); + var maps_dict = new Dictionary>(); + + foreach (var p in samples) { - w.Serialize(s); + var c = p.Key; + if (worlds_dict.ContainsKey(c) == false) + { + worlds_dict.Add(c, new SmallWorld(options)); + } + if (maps_dict.ContainsKey(c) == false) + { + maps_dict.Add(c, new HNSWMap()); + } + var w = worlds_dict[c]; + var m = maps_dict[c]; + var ids = w.AddItems(p.Value.Select(i => i.Item1)); + + for (int i = 0; i < ids.Length; i++) + { + m.Append(p.Value[i].Item2.Number, ids[i]); + } } - using (var s = File.OpenRead(name)) + + var name = Guid.NewGuid().ToString(); + foreach (var p in samples) { - var l = LALGraph.FromHNSWGraph(s); - worlds.Append(l, c); + var c = p.Key; + var w = worlds_dict[c]; + var m = maps_dict[c]; + + using (var s = File.Create(name)) + { + w.Serialize(s); + } + using (var s = File.OpenRead(name)) + { + var l = LALGraph.FromHNSWGraph(s); + worlds.Append(l, c); + } + File.Delete(name); + mappers.Append(m, c); } - File.Delete(name); - mappers.Append(m, c); + + worlds.Save(_graphFileCachee); + mappers.Save(_mapFileCachee); } var entries = new long[10]; @@ -80,7 +96,6 @@ namespace HNSWDemo.Tests { entries[i] = persons[DefaultRandomGenerator.Instance.Next(0, persons.Count - 1)].Item2.Number; } - var contexts = mappers.CreateContext(null, entries); var result = worlds.KNearest(5000, contexts); diff --git a/ZeroLevel.HNSW/Services/HNSWMap.cs b/ZeroLevel.HNSW/Services/HNSWMap.cs index 1574e6d..0f21705 100644 --- a/ZeroLevel.HNSW/Services/HNSWMap.cs +++ b/ZeroLevel.HNSW/Services/HNSWMap.cs @@ -14,6 +14,8 @@ namespace ZeroLevel.HNSW private Dictionary _reverse_map; public int this[TFeature feature] => _map.GetValueOrDefault(feature); + + public HNSWMap() { } public HNSWMap(int capacity = -1) { if (capacity > 0) diff --git a/ZeroLevel.HNSW/Services/HNSWMappers.cs b/ZeroLevel.HNSW/Services/HNSWMappers.cs index 5db34e9..1f124c0 100644 --- a/ZeroLevel.HNSW/Services/HNSWMappers.cs +++ b/ZeroLevel.HNSW/Services/HNSWMappers.cs @@ -1,14 +1,48 @@ using System; using System.Collections.Generic; +using System.IO; +using ZeroLevel.Services.Serialization; namespace ZeroLevel.HNSW { public class HNSWMappers + : IBinarySerializable { - private readonly IDictionary> _mappers = new Dictionary>(); + private IDictionary> _mappers; private readonly Func _bucketFunction; + + public HNSWMappers(string filePath, Func bucketFunction) + { + _bucketFunction = bucketFunction; + using (var fs = File.OpenRead(filePath)) + { + using (var bs = new BufferedStream(fs, 1024 * 1024 * 32)) + { + using (var reader = new MemoryStreamReader(bs)) + { + Deserialize(reader); + } + } + } + } + + public void Save(string filePath) + { + using (var fs = File.OpenWrite(filePath)) + { + using (var bs = new BufferedStream(fs, 1024 * 1024 * 32)) + { + using (var writer = new MemoryStreamWriter(bs)) + { + Serialize(writer); + } + } + } + } + public HNSWMappers(Func bucketFunction) { + _mappers = new Dictionary>(); _bucketFunction = bucketFunction; } @@ -68,5 +102,15 @@ namespace ZeroLevel.HNSW } return result; } + + public void Deserialize(IBinaryReader reader) + { + this._mappers = reader.ReadDictionary>(); + } + + public void Serialize(IBinaryWriter writer) + { + writer.WriteDictionary>(this._mappers); + } } } diff --git a/ZeroLevel.HNSW/Services/LAL/LALGraph.cs b/ZeroLevel.HNSW/Services/LAL/LALGraph.cs index 588e062..ea0a0e1 100644 --- a/ZeroLevel.HNSW/Services/LAL/LALGraph.cs +++ b/ZeroLevel.HNSW/Services/LAL/LALGraph.cs @@ -6,10 +6,11 @@ using ZeroLevel.Services.Serialization; namespace ZeroLevel.HNSW { public class LALGraph + : IBinarySerializable { private readonly LALLinks _links = new LALLinks(); - private LALGraph() { } + public LALGraph() { } public static LALGraph FromLALGraph(Stream stream) { var l = new LALGraph(); @@ -76,7 +77,7 @@ namespace ZeroLevel.HNSW { using (var reader = new MemoryStreamReader(stream)) { - _links.Deserialize(reader); // deserialize only base layer and skip another + _links.Deserialize(reader); } } @@ -105,5 +106,15 @@ namespace ZeroLevel.HNSW _links.Serialize(writer); } } + + public void Serialize(IBinaryWriter writer) + { + _links.Serialize(writer); + } + + public void Deserialize(IBinaryReader reader) + { + _links.Deserialize(reader); + } } } diff --git a/ZeroLevel.HNSW/Services/LAL/LALLinks.cs b/ZeroLevel.HNSW/Services/LAL/LALLinks.cs index dd9bdce..3a3ef5d 100644 --- a/ZeroLevel.HNSW/Services/LAL/LALLinks.cs +++ b/ZeroLevel.HNSW/Services/LAL/LALLinks.cs @@ -6,6 +6,7 @@ using ZeroLevel.Services.Serialization; namespace ZeroLevel.HNSW { internal class LALLinks + : IBinarySerializable { private ConcurrentDictionary _set = new ConcurrentDictionary(); internal IDictionary Links => _set; @@ -47,6 +48,7 @@ namespace ZeroLevel.HNSW _set.Clear(); _set = null; } + public void Serialize(IBinaryWriter writer) { writer.WriteInt32(_set.Count); diff --git a/ZeroLevel.HNSW/Services/LAL/SplittedLALGraph.cs b/ZeroLevel.HNSW/Services/LAL/SplittedLALGraph.cs index fcdd4ec..243cf99 100644 --- a/ZeroLevel.HNSW/Services/LAL/SplittedLALGraph.cs +++ b/ZeroLevel.HNSW/Services/LAL/SplittedLALGraph.cs @@ -1,10 +1,47 @@ using System.Collections.Generic; +using System.IO; +using System.Linq; +using ZeroLevel.Services.Serialization; namespace ZeroLevel.HNSW { public class SplittedLALGraph + : IBinarySerializable { - private readonly IDictionary _graphs = new Dictionary(); + private IDictionary _graphs; + + public SplittedLALGraph() + { + _graphs = new Dictionary(); + } + + public SplittedLALGraph(string filePath) + { + using (var fs = File.OpenRead(filePath)) + { + using (var bs = new BufferedStream(fs, 1024 * 1024 * 32)) + { + using (var reader = new MemoryStreamReader(bs)) + { + Deserialize(reader); + } + } + } + } + + public void Save(string filePath) + { + using (var fs = File.OpenWrite(filePath)) + { + using (var bs = new BufferedStream(fs, 1024 * 1024 * 32)) + { + using (var writer = new MemoryStreamWriter(bs)) + { + Serialize(writer); + } + } + } + } public void Append(LALGraph graph, int c) { @@ -34,5 +71,14 @@ namespace ZeroLevel.HNSW } return result; } + + public void Serialize(IBinaryWriter writer) + { + writer.WriteDictionary(this._graphs); + } + public void Deserialize(IBinaryReader reader) + { + this._graphs = reader.ReadDictionary(); + } } } diff --git a/ZeroLevel/Services/Serialization/IBinaryWriter.cs b/ZeroLevel/Services/Serialization/IBinaryWriter.cs index 2e405e8..212fdc2 100644 --- a/ZeroLevel/Services/Serialization/IBinaryWriter.cs +++ b/ZeroLevel/Services/Serialization/IBinaryWriter.cs @@ -96,7 +96,7 @@ namespace ZeroLevel.Services.Serialization void WriteCollection(IEnumerable collection); #endregion - void WriteDictionary(Dictionary collection); + void WriteDictionary(IDictionary collection); void WriteDictionary(ConcurrentDictionary collection); void Write(T item) diff --git a/ZeroLevel/Services/Serialization/MemoryStreamWriter.cs b/ZeroLevel/Services/Serialization/MemoryStreamWriter.cs index b6bd01a..18abaea 100644 --- a/ZeroLevel/Services/Serialization/MemoryStreamWriter.cs +++ b/ZeroLevel/Services/Serialization/MemoryStreamWriter.cs @@ -777,7 +777,7 @@ namespace ZeroLevel.Services.Serialization } } - public void WriteDictionary(Dictionary collection) + public void WriteDictionary(IDictionary collection) { WriteInt32(collection?.Count() ?? 0); if (collection != null)