HNSW graphs serialization

pull/1/head
unknown 3 years ago
parent 8fab18b0f6
commit 0fdda146da

@ -12,6 +12,8 @@ namespace HNSWDemo.Tests
{ {
private const int count = 20000; private const int count = 20000;
private const int dimensionality = 128; private const int dimensionality = 128;
private const string _graphFileCachee = @"lal_test_graph.bin";
private const string _mapFileCachee = @"lal_test_map.bin";
public void Run() public void Run()
{ {
@ -27,52 +29,66 @@ namespace HNSWDemo.Tests
samples[c].Add(p); samples[c].Add(p);
} }
SplittedLALGraph worlds;
HNSWMappers<long> mappers;
var worlds = new SplittedLALGraph(); if (File.Exists(_graphFileCachee) && File.Exists(_mapFileCachee))
var mappers = new HNSWMappers<long>(l => (int)Math.Abs(l.GetHashCode() % moda));
var worlds_dict = new Dictionary<int, SmallWorld<float[]>>();
var maps_dict = new Dictionary<int, HNSWMap<long>>();
foreach (var p in samples)
{ {
var c = p.Key; worlds = new SplittedLALGraph(_graphFileCachee);
if (worlds_dict.ContainsKey(c) == false) mappers = new HNSWMappers<long>(_mapFileCachee, l => (int)Math.Abs(l.GetHashCode() % moda));
{
worlds_dict.Add(c, new SmallWorld<float[]>(options));
}
if (maps_dict.ContainsKey(c) == false)
{
maps_dict.Add(c, new HNSWMap<long>());
}
var w = worlds_dict[c];
var m = maps_dict[c];
var ids = w.AddItems(p.Value.Select(i => i.Item1));
for (int i = 0; i < ids.Length; i++)
{
m.Append(p.Value[i].Item2.Number, ids[i]);
}
} }
else
var name = Guid.NewGuid().ToString();
foreach (var p in samples)
{ {
var c = p.Key;
var w = worlds_dict[c];
var m = maps_dict[c];
using (var s = File.Create(name)) worlds = new SplittedLALGraph();
mappers = new HNSWMappers<long>(l => (int)Math.Abs(l.GetHashCode() % moda));
var worlds_dict = new Dictionary<int, SmallWorld<float[]>>();
var maps_dict = new Dictionary<int, HNSWMap<long>>();
foreach (var p in samples)
{ {
w.Serialize(s); var c = p.Key;
if (worlds_dict.ContainsKey(c) == false)
{
worlds_dict.Add(c, new SmallWorld<float[]>(options));
}
if (maps_dict.ContainsKey(c) == false)
{
maps_dict.Add(c, new HNSWMap<long>());
}
var w = worlds_dict[c];
var m = maps_dict[c];
var ids = w.AddItems(p.Value.Select(i => i.Item1));
for (int i = 0; i < ids.Length; i++)
{
m.Append(p.Value[i].Item2.Number, ids[i]);
}
} }
using (var s = File.OpenRead(name))
var name = Guid.NewGuid().ToString();
foreach (var p in samples)
{ {
var l = LALGraph.FromHNSWGraph<float[]>(s); var c = p.Key;
worlds.Append(l, c); var w = worlds_dict[c];
var m = maps_dict[c];
using (var s = File.Create(name))
{
w.Serialize(s);
}
using (var s = File.OpenRead(name))
{
var l = LALGraph.FromHNSWGraph<float[]>(s);
worlds.Append(l, c);
}
File.Delete(name);
mappers.Append(m, c);
} }
File.Delete(name);
mappers.Append(m, c); worlds.Save(_graphFileCachee);
mappers.Save(_mapFileCachee);
} }
var entries = new long[10]; var entries = new long[10];
@ -80,7 +96,6 @@ namespace HNSWDemo.Tests
{ {
entries[i] = persons[DefaultRandomGenerator.Instance.Next(0, persons.Count - 1)].Item2.Number; entries[i] = persons[DefaultRandomGenerator.Instance.Next(0, persons.Count - 1)].Item2.Number;
} }
var contexts = mappers.CreateContext(null, entries); var contexts = mappers.CreateContext(null, entries);
var result = worlds.KNearest(5000, contexts); var result = worlds.KNearest(5000, contexts);

@ -14,6 +14,8 @@ namespace ZeroLevel.HNSW
private Dictionary<int, TFeature> _reverse_map; private Dictionary<int, TFeature> _reverse_map;
public int this[TFeature feature] => _map.GetValueOrDefault(feature); public int this[TFeature feature] => _map.GetValueOrDefault(feature);
public HNSWMap() { }
public HNSWMap(int capacity = -1) public HNSWMap(int capacity = -1)
{ {
if (capacity > 0) if (capacity > 0)

@ -1,14 +1,48 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO;
using ZeroLevel.Services.Serialization;
namespace ZeroLevel.HNSW namespace ZeroLevel.HNSW
{ {
public class HNSWMappers<TFeature> public class HNSWMappers<TFeature>
: IBinarySerializable
{ {
private readonly IDictionary<int, HNSWMap<TFeature>> _mappers = new Dictionary<int, HNSWMap<TFeature>>(); private IDictionary<int, HNSWMap<TFeature>> _mappers;
private readonly Func<TFeature, int> _bucketFunction; private readonly Func<TFeature, int> _bucketFunction;
/// <summary>
/// Creates the mappers by reading their dictionary from <paramref name="filePath"/>.
/// The file is expected to hold data in the format produced by <see cref="Serialize"/>.
/// </summary>
/// <param name="filePath">Path of the file to read.</param>
/// <param name="bucketFunction">
/// Function stored in <c>_bucketFunction</c>. It is not part of the file
/// (only the dictionary is serialized), so the caller supplies it.
/// </param>
public HNSWMappers(string filePath, Func<TFeature, int> bucketFunction)
{
    _bucketFunction = bucketFunction;

    // Stacked usings: disposed in reverse order (reader, buffer, file), same as nesting.
    using (var file = File.OpenRead(filePath))
    using (var buffered = new BufferedStream(file, 1024 * 1024 * 32))
    using (var input = new MemoryStreamReader(buffered))
    {
        Deserialize(input);
    }
}
/// <summary>
/// Writes the mapper dictionary to <paramref name="filePath"/> using <see cref="Serialize"/>.
/// </summary>
/// <param name="filePath">Path of the file to create or overwrite.</param>
public void Save(string filePath)
{
    // File.Create truncates an existing file. File.OpenWrite would keep the old
    // length, leaving stale bytes after the new data when the previous file was longer.
    using (var fs = File.Create(filePath))
    using (var bs = new BufferedStream(fs, 1024 * 1024 * 32))
    using (var writer = new MemoryStreamWriter(bs))
    {
        Serialize(writer);
    }
}
public HNSWMappers(Func<TFeature, int> bucketFunction) public HNSWMappers(Func<TFeature, int> bucketFunction)
{ {
_mappers = new Dictionary<int, HNSWMap<TFeature>>();
_bucketFunction = bucketFunction; _bucketFunction = bucketFunction;
} }
@ -68,5 +102,15 @@ namespace ZeroLevel.HNSW
} }
return result; return result;
} }
/// <summary>
/// Replaces the current mapper dictionary with the one read from <paramref name="reader"/>.
/// </summary>
/// <param name="reader">Source positioned at data written by <see cref="Serialize"/>.</param>
public void Deserialize(IBinaryReader reader)
    => _mappers = reader.ReadDictionary<int, HNSWMap<TFeature>>();
/// <summary>
/// Writes the mapper dictionary to <paramref name="writer"/>.
/// Only the dictionary is written; the bucket function is not persisted.
/// </summary>
/// <param name="writer">Destination for the serialized dictionary.</param>
public void Serialize(IBinaryWriter writer)
    => writer.WriteDictionary<int, HNSWMap<TFeature>>(_mappers);
} }
} }

@ -6,10 +6,11 @@ using ZeroLevel.Services.Serialization;
namespace ZeroLevel.HNSW namespace ZeroLevel.HNSW
{ {
public class LALGraph public class LALGraph
: IBinarySerializable
{ {
private readonly LALLinks _links = new LALLinks(); private readonly LALLinks _links = new LALLinks();
private LALGraph() { } public LALGraph() { }
public static LALGraph FromLALGraph(Stream stream) public static LALGraph FromLALGraph(Stream stream)
{ {
var l = new LALGraph(); var l = new LALGraph();
@ -76,7 +77,7 @@ namespace ZeroLevel.HNSW
{ {
using (var reader = new MemoryStreamReader(stream)) using (var reader = new MemoryStreamReader(stream))
{ {
_links.Deserialize(reader); // deserialize only base layer and skip another _links.Deserialize(reader);
} }
} }
@ -105,5 +106,15 @@ namespace ZeroLevel.HNSW
_links.Serialize(writer); _links.Serialize(writer);
} }
} }
/// <summary>
/// Writes this graph to <paramref name="writer"/> by delegating to the link set.
/// </summary>
/// <param name="writer">Destination for the serialized links.</param>
public void Serialize(IBinaryWriter writer) => _links.Serialize(writer);
/// <summary>
/// Reads this graph from <paramref name="reader"/> by delegating to the link set.
/// </summary>
/// <param name="reader">Source of the serialized links.</param>
public void Deserialize(IBinaryReader reader) => _links.Deserialize(reader);
} }
} }

@ -6,6 +6,7 @@ using ZeroLevel.Services.Serialization;
namespace ZeroLevel.HNSW namespace ZeroLevel.HNSW
{ {
internal class LALLinks internal class LALLinks
: IBinarySerializable
{ {
private ConcurrentDictionary<int, int[]> _set = new ConcurrentDictionary<int, int[]>(); private ConcurrentDictionary<int, int[]> _set = new ConcurrentDictionary<int, int[]>();
internal IDictionary<int, int[]> Links => _set; internal IDictionary<int, int[]> Links => _set;
@ -47,6 +48,7 @@ namespace ZeroLevel.HNSW
_set.Clear(); _set.Clear();
_set = null; _set = null;
} }
public void Serialize(IBinaryWriter writer) public void Serialize(IBinaryWriter writer)
{ {
writer.WriteInt32(_set.Count); writer.WriteInt32(_set.Count);

@ -1,10 +1,47 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.IO;
using System.Linq;
using ZeroLevel.Services.Serialization;
namespace ZeroLevel.HNSW namespace ZeroLevel.HNSW
{ {
public class SplittedLALGraph public class SplittedLALGraph
: IBinarySerializable
{ {
private readonly IDictionary<int, LALGraph> _graphs = new Dictionary<int, LALGraph>(); private IDictionary<int, LALGraph> _graphs;
/// <summary>
/// Creates an empty graph set with no per-key graphs.
/// </summary>
public SplittedLALGraph() => _graphs = new Dictionary<int, LALGraph>();
/// <summary>
/// Creates the graph set by reading its dictionary from <paramref name="filePath"/>.
/// The file is expected to hold data in the format produced by <see cref="Serialize"/>.
/// </summary>
/// <param name="filePath">Path of the file to read.</param>
public SplittedLALGraph(string filePath)
{
    // Stacked usings: disposed in reverse order (reader, buffer, file), same as nesting.
    using (var file = File.OpenRead(filePath))
    using (var buffered = new BufferedStream(file, 1024 * 1024 * 32))
    using (var input = new MemoryStreamReader(buffered))
    {
        Deserialize(input);
    }
}
/// <summary>
/// Writes the graph dictionary to <paramref name="filePath"/> using <see cref="Serialize"/>.
/// </summary>
/// <param name="filePath">Path of the file to create or overwrite.</param>
public void Save(string filePath)
{
    // File.Create truncates an existing file. File.OpenWrite would keep the old
    // length, leaving stale bytes after the new data when the previous file was longer.
    using (var fs = File.Create(filePath))
    using (var bs = new BufferedStream(fs, 1024 * 1024 * 32))
    using (var writer = new MemoryStreamWriter(bs))
    {
        Serialize(writer);
    }
}
public void Append(LALGraph graph, int c) public void Append(LALGraph graph, int c)
{ {
@ -34,5 +71,14 @@ namespace ZeroLevel.HNSW
} }
return result; return result;
} }
/// <summary>
/// Writes the dictionary of graphs to <paramref name="writer"/>.
/// </summary>
/// <param name="writer">Destination for the serialized dictionary.</param>
public void Serialize(IBinaryWriter writer)
    => writer.WriteDictionary<int, LALGraph>(_graphs);
/// <summary>
/// Replaces the current dictionary of graphs with the one read from <paramref name="reader"/>.
/// </summary>
/// <param name="reader">Source positioned at data written by <see cref="Serialize"/>.</param>
public void Deserialize(IBinaryReader reader)
    => _graphs = reader.ReadDictionary<int, LALGraph>();
} }
} }

@ -96,7 +96,7 @@ namespace ZeroLevel.Services.Serialization
void WriteCollection(IEnumerable<IPAddress> collection); void WriteCollection(IEnumerable<IPAddress> collection);
#endregion #endregion
void WriteDictionary<TKey, TValue>(Dictionary<TKey, TValue> collection); void WriteDictionary<TKey, TValue>(IDictionary<TKey, TValue> collection);
void WriteDictionary<TKey, TValue>(ConcurrentDictionary<TKey, TValue> collection); void WriteDictionary<TKey, TValue>(ConcurrentDictionary<TKey, TValue> collection);
void Write<T>(T item) void Write<T>(T item)

@ -777,7 +777,7 @@ namespace ZeroLevel.Services.Serialization
} }
} }
public void WriteDictionary<TKey, TValue>(Dictionary<TKey, TValue> collection) public void WriteDictionary<TKey, TValue>(IDictionary<TKey, TValue> collection)
{ {
WriteInt32(collection?.Count() ?? 0); WriteInt32(collection?.Count() ?? 0);
if (collection != null) if (collection != null)

Loading…
Cancel
Save

Powered by TurnKey Linux.