ReadOnlySmallWorld

The ability to create a compact representation from a graph. Only for reading. Memory costs in bytes are reduced by the formula:
MemCompact = MemFull - AllLinks * 4
pull/1/head
unknown 3 years ago
parent 4132147657
commit f8b72e38e3

@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using ZeroLevel.HNSW;
@ -98,10 +99,267 @@ namespace HNSWDemo
static void Main(string[] args)
{
FilterTest();
TransformToCompactWorldTestWithAccuracity();
Console.ReadKey();
}
static void TransformToCompactWorldTest()
{
var count = 10000;
var dimensionality = 128;
var samples = RandomVectors(dimensionality, count);
var world = new SmallWorld<float[]>(NSWOptions<float[]>.Create(6, 15, 200, 200, CosineDistance.ForUnits, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple));
var ids = world.AddItems(samples.ToArray());
Console.WriteLine("Start test");
byte[] dump;
using (var ms = new MemoryStream())
{
world.Serialize(ms);
dump = ms.ToArray();
}
Console.WriteLine($"Full dump size: {dump.Length} bytes");
ReadOnlySmallWorld<float[]> compactWorld;
using (var ms = new MemoryStream(dump))
{
compactWorld = SmallWorld.CreateReadOnlyWorldFrom<float[]>(NSWReadOnlyOption<float[]>.Create(200, CosineDistance.ForUnits, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple), ms);
}
// Compare worlds outputs
int K = 200;
var hits = 0;
var miss = 0;
var testCount = 1000;
var sw = new Stopwatch();
var timewatchesHNSW = new List<float>();
var timewatchesHNSWCompact = new List<float>();
var test_vectors = RandomVectors(dimensionality, testCount);
foreach (var v in test_vectors)
{
sw.Restart();
var gt = world.Search(v, K).Select(e => e.Item1).ToHashSet();
sw.Stop();
timewatchesHNSW.Add(sw.ElapsedMilliseconds);
sw.Restart();
var result = compactWorld.Search(v, K).Select(e => e.Item1).ToHashSet();
sw.Stop();
timewatchesHNSWCompact.Add(sw.ElapsedMilliseconds);
foreach (var r in result)
{
if (gt.Contains(r))
{
hits++;
}
else
{
miss++;
}
}
}
byte[] smallWorldDump;
using (var ms = new MemoryStream())
{
compactWorld.Serialize(ms);
smallWorldDump = ms.ToArray();
}
var p = smallWorldDump.Length * 100.0f / dump.Length;
Console.WriteLine($"Compact dump size: {smallWorldDump.Length} bytes. Decrease: {100 - p}%");
Console.WriteLine($"HITS: {hits}");
Console.WriteLine($"MISSES: {miss}");
Console.WriteLine($"MIN HNSW TIME: {timewatchesHNSW.Min()} ms");
Console.WriteLine($"AVG HNSW TIME: {timewatchesHNSW.Average()} ms");
Console.WriteLine($"MAX HNSW TIME: {timewatchesHNSW.Max()} ms");
Console.WriteLine($"MIN HNSWCompact TIME: {timewatchesHNSWCompact.Min()} ms");
Console.WriteLine($"AVG HNSWCompact TIME: {timewatchesHNSWCompact.Average()} ms");
Console.WriteLine($"MAX HNSWCompact TIME: {timewatchesHNSWCompact.Max()} ms");
}
static void TransformToCompactWorldTestWithAccuracity()
{
var count = 10000;
var dimensionality = 128;
var samples = RandomVectors(dimensionality, count);
var test = new VectorsDirectCompare(samples, CosineDistance.ForUnits);
var world = new SmallWorld<float[]>(NSWOptions<float[]>.Create(6, 15, 200, 200, CosineDistance.ForUnits, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple));
var ids = world.AddItems(samples.ToArray());
Console.WriteLine("Start test");
byte[] dump;
using (var ms = new MemoryStream())
{
world.Serialize(ms);
dump = ms.ToArray();
}
ReadOnlySmallWorld<float[]> compactWorld;
using (var ms = new MemoryStream(dump))
{
compactWorld = SmallWorld.CreateReadOnlyWorldFrom<float[]>(NSWReadOnlyOption<float[]>.Create(200, CosineDistance.ForUnits, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple), ms);
}
// Compare worlds outputs
int K = 200;
var hits = 0;
var miss = 0;
var testCount = 2000;
var sw = new Stopwatch();
var timewatchesNP = new List<float>();
var timewatchesHNSW = new List<float>();
var timewatchesHNSWCompact = new List<float>();
var test_vectors = RandomVectors(dimensionality, testCount);
var totalHitsHNSW = new List<int>();
var totalHitsHNSWCompact = new List<int>();
foreach (var v in test_vectors)
{
var npHitsHNSW = 0;
var npHitsHNSWCompact = 0;
sw.Restart();
var gtNP = test.KNearest(v, K).Select(p => p.Item1).ToHashSet();
sw.Stop();
timewatchesNP.Add(sw.ElapsedMilliseconds);
sw.Restart();
var gt = world.Search(v, K).Select(e => e.Item1).ToHashSet();
sw.Stop();
timewatchesHNSW.Add(sw.ElapsedMilliseconds);
sw.Restart();
var result = compactWorld.Search(v, K).Select(e => e.Item1).ToHashSet();
sw.Stop();
timewatchesHNSWCompact.Add(sw.ElapsedMilliseconds);
foreach (var r in result)
{
if (gt.Contains(r))
{
hits++;
}
else
{
miss++;
}
if (gtNP.Contains(r))
{
npHitsHNSWCompact++;
}
}
foreach (var r in gt)
{
if (gtNP.Contains(r))
{
npHitsHNSW++;
}
}
totalHitsHNSW.Add(npHitsHNSW);
totalHitsHNSWCompact.Add(npHitsHNSWCompact);
}
byte[] smallWorldDump;
using (var ms = new MemoryStream())
{
compactWorld.Serialize(ms);
smallWorldDump = ms.ToArray();
}
var p = smallWorldDump.Length * 100.0f / dump.Length;
Console.WriteLine($"Full dump size: {dump.Length} bytes");
Console.WriteLine($"Compact dump size: {smallWorldDump.Length} bytes. Decrease: {100 - p}%");
Console.WriteLine($"HITS: {hits}");
Console.WriteLine($"MISSES: {miss}");
Console.WriteLine($"MIN NP TIME: {timewatchesNP.Min()} ms");
Console.WriteLine($"AVG NP TIME: {timewatchesNP.Average()} ms");
Console.WriteLine($"MAX NP TIME: {timewatchesNP.Max()} ms");
Console.WriteLine($"MIN HNSW TIME: {timewatchesHNSW.Min()} ms");
Console.WriteLine($"AVG HNSW TIME: {timewatchesHNSW.Average()} ms");
Console.WriteLine($"MAX HNSW TIME: {timewatchesHNSW.Max()} ms");
Console.WriteLine($"MIN HNSWCompact TIME: {timewatchesHNSWCompact.Min()} ms");
Console.WriteLine($"AVG HNSWCompact TIME: {timewatchesHNSWCompact.Average()} ms");
Console.WriteLine($"MAX HNSWCompact TIME: {timewatchesHNSWCompact.Max()} ms");
Console.WriteLine($"MIN HNSW Accuracity: {totalHitsHNSW.Min() * 100 / K}%");
Console.WriteLine($"AVG HNSW Accuracity: {totalHitsHNSW.Average() * 100 / K}%");
Console.WriteLine($"MAX HNSW Accuracity: {totalHitsHNSW.Max() * 100 / K}%");
Console.WriteLine($"MIN HNSWCompact Accuracity: {totalHitsHNSWCompact.Min() * 100 / K}%");
Console.WriteLine($"AVG HNSWCompact Accuracity: {totalHitsHNSWCompact.Average() * 100 / K}%");
Console.WriteLine($"MAX HNSWCompact Accuracity: {totalHitsHNSWCompact.Max() * 100 / K}%");
}
static void SaveRestoreTest()
{
var count = 1000;
var dimensionality = 128;
var samples = RandomVectors(dimensionality, count);
var world = new SmallWorld<float[]>(NSWOptions<float[]>.Create(6, 15, 200, 200, CosineDistance.ForUnits, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple));
var sw = new Stopwatch();
sw.Start();
var ids = world.AddItems(samples.ToArray());
sw.Stop();
Console.WriteLine($"Insert {ids.Length} items on {sw.ElapsedMilliseconds} ms");
Console.WriteLine("Start test");
byte[] dump;
using (var ms = new MemoryStream())
{
world.Serialize(ms);
dump = ms.ToArray();
}
Console.WriteLine($"Full dump size: {dump.Length} bytes");
byte[] testDump;
var restoredWorld = new SmallWorld<float[]>(NSWOptions<float[]>.Create(6, 15, 200, 200, CosineDistance.ForUnits, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple));
using (var ms = new MemoryStream(dump))
{
restoredWorld.Deserialize(ms);
}
using (var ms = new MemoryStream())
{
restoredWorld.Serialize(ms);
testDump = ms.ToArray();
}
if (testDump.Length != dump.Length)
{
Console.WriteLine($"Incorrect restored size. Got {testDump.Length}. Expected: {dump.Length}");
return;
}
ReadOnlySmallWorld<float[]> compactWorld;
using (var ms = new MemoryStream(dump))
{
compactWorld = SmallWorld.CreateReadOnlyWorldFrom<float[]>(NSWReadOnlyOption<float[]>.Create(200, CosineDistance.ForUnits, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple), ms);
}
byte[] smallWorldDump;
using (var ms = new MemoryStream())
{
compactWorld.Serialize(ms);
smallWorldDump = ms.ToArray();
}
var p = smallWorldDump.Length * 100.0f / dump.Length;
Console.WriteLine($"Compact dump size: {smallWorldDump.Length} bytes. Decrease: {100 - p}%");
}
static void FilterTest()
{
var count = 5000;

@ -2,29 +2,12 @@
namespace ZeroLevel.HNSW
{
/// <summary>
/// Type of heuristic to select best neighbours for a node.
/// </summary>
public enum NeighbourSelectionHeuristic
{
/// <summary>
/// Marker for the Algorithm 3 (SELECT-NEIGHBORS-SIMPLE) from the article. Implemented in <see cref="Algorithms.Algorithm3{TItem, TDistance}"/>
/// </summary>
SelectSimple,
/// <summary>
/// Marker for the Algorithm 4 (SELECT-NEIGHBORS-HEURISTIC) from the article. Implemented in <see cref="Algorithms.Algorithm4{TItem, TDistance}"/>
/// </summary>
SelectHeuristic
}
public sealed class NSWOptions<TItem>
{
/// <summary>
/// Mox node connections on Layer
/// </summary>
public readonly int M;
/// <summary>
/// Max search buffer
/// </summary>

@ -0,0 +1,44 @@
using System;
namespace ZeroLevel.HNSW
{
public sealed class NSWReadOnlyOption<TItem>
{
/// <summary>
/// Max search buffer
/// </summary>
public readonly int EF;
/// <summary>
/// Distance function beetween vectors
/// </summary>
public readonly Func<TItem, TItem, float> Distance;
public readonly bool ExpandBestSelection;
public readonly bool KeepPrunedConnections;
public readonly NeighbourSelectionHeuristic SelectionHeuristic;
private NSWReadOnlyOption(
int ef,
Func<TItem, TItem, float> distance,
bool expandBestSelection,
bool keepPrunedConnections,
NeighbourSelectionHeuristic selectionHeuristic)
{
EF = ef;
Distance = distance;
ExpandBestSelection = expandBestSelection;
KeepPrunedConnections = keepPrunedConnections;
SelectionHeuristic = selectionHeuristic;
}
public static NSWReadOnlyOption<TItem> Create(
int EF,
Func<TItem, TItem, float> distance,
bool expandBestSelection = false,
bool keepPrunedConnections = false,
NeighbourSelectionHeuristic selectionHeuristic = NeighbourSelectionHeuristic.SelectSimple) =>
new NSWReadOnlyOption<TItem>(EF, distance, expandBestSelection, keepPrunedConnections, selectionHeuristic);
}
}

@ -0,0 +1,18 @@
namespace ZeroLevel.HNSW
{
/// <summary>
/// Type of heuristic to select best neighbours for a node.
/// </summary>
public enum NeighbourSelectionHeuristic
{
/// <summary>
/// Marker for the Algorithm 3 (SELECT-NEIGHBORS-SIMPLE) from the article. Implemented in <see cref="Algorithms.Algorithm3{TItem, TDistance}"/>
/// </summary>
SelectSimple,
/// <summary>
/// Marker for the Algorithm 4 (SELECT-NEIGHBORS-HEURISTIC) from the article. Implemented in <see cref="Algorithms.Algorithm4{TItem, TDistance}"/>
/// </summary>
SelectHeuristic
}
}

@ -0,0 +1,152 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using ZeroLevel.Services.Serialization;
namespace ZeroLevel.HNSW
{
public class ReadOnlySmallWorld<TItem>
{
private readonly NSWReadOnlyOption<TItem> _options;
private ReadOnlyVectorSet<TItem> _vectors;
private ReadOnlyLayer<TItem>[] _layers;
private int EntryPoint = 0;
private int MaxLayer = 0;
private ReadOnlySmallWorld() { }
internal ReadOnlySmallWorld(NSWReadOnlyOption<TItem> options, Stream stream)
{
_options = options;
Deserialize(stream);
}
/// <summary>
/// Search in the graph K for vectors closest to a given vector
/// </summary>
/// <param name="vector">Given vector</param>
/// <param name="k">Count of elements for search</param>
/// <param name="activeNodes"></param>
/// <returns></returns>
public IEnumerable<(int, TItem, float)> Search(TItem vector, int k)
{
foreach (var pair in KNearest(vector, k))
{
yield return (pair.Item1, _vectors[pair.Item1], pair.Item2);
}
}
public IEnumerable<(int, TItem, float)> Search(TItem vector, int k, HashSet<int> activeNodes)
{
if (activeNodes == null)
{
foreach (var pair in KNearest(vector, k))
{
yield return (pair.Item1, _vectors[pair.Item1], pair.Item2);
}
}
else
{
foreach (var pair in KNearest(vector, k, activeNodes))
{
yield return (pair.Item1, _vectors[pair.Item1], pair.Item2);
}
}
}
/// <summary>
/// Algorithm 5
/// </summary>
private IEnumerable<(int, float)> KNearest(TItem q, int k)
{
if (_vectors.Count == 0)
{
return Enumerable.Empty<(int, float)>();
}
var distance = new Func<int, float>(candidate => _options.Distance(q, _vectors[candidate]));
// W ← ∅ // set for the current nearest elements
var W = new Dictionary<int, float>(k + 1);
// ep ← get enter point for hnsw
var ep = EntryPoint;
// L ← level of ep // top layer for hnsw
var L = MaxLayer;
// for lc ← L … 1
for (int layer = L; layer > 0; --layer)
{
// W ← SEARCH-LAYER(q, ep, ef = 1, lc)
_layers[layer].KNearestAtLayer(ep, distance, W, 1);
// ep ← get nearest element from W to q
ep = W.OrderBy(p => p.Value).First().Key;
W.Clear();
}
// W ← SEARCH-LAYER(q, ep, ef, lc =0)
_layers[0].KNearestAtLayer(ep, distance, W, k);
// return K nearest elements from W to q
return W.Select(p => (p.Key, p.Value));
}
private IEnumerable<(int, float)> KNearest(TItem q, int k, HashSet<int> activeNodes)
{
if (_vectors.Count == 0)
{
return Enumerable.Empty<(int, float)>();
}
var distance = new Func<int, float>(candidate => _options.Distance(q, _vectors[candidate]));
// W ← ∅ // set for the current nearest elements
var W = new Dictionary<int, float>(k + 1);
// ep ← get enter point for hnsw
var ep = EntryPoint;
// L ← level of ep // top layer for hnsw
var L = MaxLayer;
// for lc ← L … 1
for (int layer = L; layer > 0; --layer)
{
// W ← SEARCH-LAYER(q, ep, ef = 1, lc)
_layers[layer].KNearestAtLayer(ep, distance, W, 1);
// ep ← get nearest element from W to q
ep = W.OrderBy(p => p.Value).First().Key;
W.Clear();
}
// W ← SEARCH-LAYER(q, ep, ef, lc =0)
_layers[0].KNearestAtLayer(ep, distance, W, k, activeNodes);
// return K nearest elements from W to q
return W.Select(p => (p.Key, p.Value));
}
public void Serialize(Stream stream)
{
using (var writer = new MemoryStreamWriter(stream))
{
writer.WriteInt32(EntryPoint);
writer.WriteInt32(MaxLayer);
_vectors.Serialize(writer);
writer.WriteInt32(_layers.Length);
foreach (var l in _layers)
{
l.Serialize(writer);
}
}
}
public void Deserialize(Stream stream)
{
using (var reader = new MemoryStreamReader(stream))
{
this.EntryPoint = reader.ReadInt32();
this.MaxLayer = reader.ReadInt32();
_vectors = new ReadOnlyVectorSet<TItem>();
_vectors.Deserialize(reader);
var countLayers = reader.ReadInt32();
_layers = new ReadOnlyLayer<TItem>[countLayers];
for (int i = 0; i < countLayers; i++)
{
_layers[i] = new ReadOnlyLayer<TItem>(_options, _vectors);
_layers[i].Deserialize(reader);
}
}
}
}
}

@ -2,11 +2,12 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using ZeroLevel.Services.Serialization;
namespace ZeroLevel.HNSW
{
internal sealed class CompactBiDirectionalLinksSet
: IDisposable
: IBinarySerializable, IDisposable
{
private readonly ReaderWriterLockSlim _rwLock = new ReaderWriterLockSlim();
@ -14,7 +15,7 @@ namespace ZeroLevel.HNSW
private SortedList<long, float> _set = new SortedList<long, float>();
public (int, int) this[int index]
internal (int, int) this[int index]
{
get
{
@ -25,12 +26,12 @@ namespace ZeroLevel.HNSW
}
}
public int Count => _set.Count;
internal int Count => _set.Count;
/// <summary>
/// Разрывает связи id1 - id2 и id2 - id1, и строит новые id1 - id, id - id1
/// </summary>
public void Relink(int id1, int id2, int id, float distance)
internal void Relink(int id1, int id2, int id, float distance)
{
long k1old = (((long)(id1)) << HALF_LONG_BITS) + id2;
long k2old = (((long)(id2)) << HALF_LONG_BITS) + id1;
@ -57,7 +58,7 @@ namespace ZeroLevel.HNSW
/// <summary>
/// Разрывает связи id1 - id2 и id2 - id1, и строит новые id1 - id, id - id1, id2 - id, id - id2
/// </summary>
public void Relink(int id1, int id2, int id, float distanceToId1, float distanceToId2)
internal void Relink(int id1, int id2, int id, float distanceToId1, float distanceToId2)
{
long k_id1_id2 = (((long)(id1)) << HALF_LONG_BITS) + id2;
long k_id2_id1 = (((long)(id2)) << HALF_LONG_BITS) + id1;
@ -96,7 +97,7 @@ namespace ZeroLevel.HNSW
}
}
public IEnumerable<(int, int, float)> FindLinksForId(int id)
internal IEnumerable<(int, int, float)> FindLinksForId(int id)
{
_rwLock.EnterReadLock();
try
@ -114,7 +115,7 @@ namespace ZeroLevel.HNSW
}
}
public IEnumerable<(int, int, float)> Items()
internal IEnumerable<(int, int, float)> Items()
{
_rwLock.EnterReadLock();
try
@ -132,7 +133,7 @@ namespace ZeroLevel.HNSW
}
}
public void RemoveIndex(int id)
internal void RemoveIndex(int id)
{
long[] forward;
long[] backward;
@ -169,7 +170,7 @@ namespace ZeroLevel.HNSW
}
}
public bool Add(int id1, int id2, float distance)
internal bool Add(int id1, int id2, float distance)
{
_rwLock.EnterWriteLock();
try
@ -193,7 +194,7 @@ namespace ZeroLevel.HNSW
return false;
}
static IEnumerable<(long, float)> Search(SortedList<long, float> set, int index)
private static IEnumerable<(long, float)> Search(SortedList<long, float> set, int index)
{
long k = ((long)index) << HALF_LONG_BITS;
int left = 0;
@ -232,7 +233,7 @@ namespace ZeroLevel.HNSW
return Enumerable.Empty<(long, float)>();
}
static IEnumerable<(long, float)> SearchByPosition(SortedList<long, float> set, long k, int position)
private static IEnumerable<(long, float)> SearchByPosition(SortedList<long, float> set, long k, int position)
{
var start = position;
var end = position;
@ -259,5 +260,34 @@ namespace ZeroLevel.HNSW
_set.Clear();
_set = null;
}
public void Serialize(IBinaryWriter writer)
{
writer.WriteBoolean(true); // true - set with weights
writer.WriteInt32(_set.Count);
foreach (var record in _set)
{
writer.WriteLong(record.Key);
writer.WriteFloat(record.Value);
}
}
public void Deserialize(IBinaryReader reader)
{
if (reader.ReadBoolean() == false)
{
throw new InvalidOperationException("Incompatible data format. The set does not contain weights.");
}
_set.Clear();
_set = null;
var count = reader.ReadInt32();
_set = new SortedList<long, float>(count + 1);
for (int i = 0; i < count; i++)
{
var key = reader.ReadLong();
var value = reader.ReadFloat();
_set.Add(key, value);
}
}
}
}

@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using ZeroLevel.Services.Serialization;
namespace ZeroLevel.HNSW
{
@ -8,6 +9,7 @@ namespace ZeroLevel.HNSW
/// NSW graph
/// </summary>
internal sealed class Layer<TItem>
: IBinarySerializable
{
private readonly NSWOptions<TItem> _options;
private readonly VectorSet<TItem> _vectors;
@ -352,5 +354,15 @@ namespace ZeroLevel.HNSW
#endregion
private IEnumerable<int> GetNeighbors(int id) => _links.FindLinksForId(id).Select(d => d.Item2);
public void Serialize(IBinaryWriter writer)
{
_links.Serialize(writer);
}
public void Deserialize(IBinaryReader reader)
{
_links.Deserialize(reader);
}
}
}

@ -0,0 +1,107 @@
using System;
using System.Collections.Generic;
using System.Linq;
using ZeroLevel.Services.Serialization;
namespace ZeroLevel.HNSW
{
internal sealed class ReadOnlyCompactBiDirectionalLinksSet
: IBinarySerializable, IDisposable
{
private const int HALF_LONG_BITS = 32;
private Dictionary<int, int[]> _set = new Dictionary<int, int[]>();
internal int Count => _set.Count;
internal IEnumerable<int> FindLinksForId(int id)
{
if (_set.ContainsKey(id))
{
return _set[id];
}
return Enumerable.Empty<int>();
}
public void Dispose()
{
_set.Clear();
_set = null;
}
public void Serialize(IBinaryWriter writer)
{
writer.WriteBoolean(false); // false - set without weights
writer.WriteInt32(_set.Count);
foreach (var record in _set)
{
writer.WriteInt32(record.Key);
writer.WriteCollection(record.Value);
}
}
public void Deserialize(IBinaryReader reader)
{
_set.Clear();
_set = null;
if (reader.ReadBoolean() == false)
{
var count = reader.ReadInt32();
_set = new Dictionary<int, int[]>(count);
for (int i = 0; i < count; i++)
{
var key = reader.ReadInt32();
var value = reader.ReadInt32Array();
_set.Add(key, value);
}
}
else
{
var count = reader.ReadInt32();
_set = new Dictionary<int, int[]>(count);
// hack, We know that an sortedset has been saved
long key;
int id1, id2;
var prevId = -1;
var set = new HashSet<int>();
for (int i = 0; i < count; i++)
{
key = reader.ReadLong();
id1 = (int)(key >> HALF_LONG_BITS);
id2 = (int)(key - (((long)id1) << HALF_LONG_BITS));
reader.ReadFloat(); // SKIP
if (prevId == -1)
{
prevId = id1;
if (id1 != id2)
{
set.Add(id2);
}
}
else if (prevId != id1)
{
_set.Add(prevId, set.ToArray());
set.Clear();
prevId = id1;
}
else
{
if (id1 != id2)
{
set.Add(id2);
}
}
}
if (set.Count > 0)
{
_set.Add(prevId, set.ToArray());
set.Clear();
}
}
}
}
}

@ -0,0 +1,316 @@
using System;
using System.Collections.Generic;
using System.Linq;
using ZeroLevel.Services.Serialization;
namespace ZeroLevel.HNSW
{
/// <summary>
/// NSW graph
/// </summary>
internal sealed class ReadOnlyLayer<TItem>
: IBinarySerializable
{
private readonly NSWReadOnlyOption<TItem> _options;
private readonly ReadOnlyVectorSet<TItem> _vectors;
private readonly ReadOnlyCompactBiDirectionalLinksSet _links;
/// <summary>
/// HNSW layer
/// </summary>
/// <param name="options">HNSW graph options</param>
/// <param name="vectors">General vector set</param>
internal ReadOnlyLayer(NSWReadOnlyOption<TItem> options, ReadOnlyVectorSet<TItem> vectors)
{
_options = options;
_vectors = vectors;
_links = new ReadOnlyCompactBiDirectionalLinksSet();
}
#region Implementation of https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf
/// <summary>
/// Algorithm 2
/// </summary>
/// <param name="q">query element</param>
/// <param name="ep">enter points ep</param>
/// <returns>Output: ef closest neighbors to q</returns>
internal void KNearestAtLayer(int entryPointId, Func<int, float> targetCosts, IDictionary<int, float> W, int ef)
{
/*
* v ep // set of visited elements
* C ep // set of candidates
* W ep // dynamic list of found nearest neighbors
* while C > 0
* c extract nearest element from C to q
* f get furthest element from W to q
* if distance(c, q) > distance(f, q)
* break // all elements in W are evaluated
* for each e neighbourhood(c) at layer lc // update C and W
* if e v
* v v e
* f get furthest element from W to q
* if distance(e, q) < distance(f, q) or W < ef
* C C e
* W W e
* if W > ef
* remove furthest element from W to q
* return W
*/
var v = new VisitedBitSet(_vectors.Count, 1);
// v ← ep // set of visited elements
v.Add(entryPointId);
// C ← ep // set of candidates
var C = new Dictionary<int, float>();
C.Add(entryPointId, targetCosts(entryPointId));
// W ← ep // dynamic list of found nearest neighbors
W.Add(entryPointId, C[entryPointId]);
var popCandidate = new Func<(int, float)>(() => { var pair = C.OrderBy(e => e.Value).First(); C.Remove(pair.Key); return (pair.Key, pair.Value); });
var fartherFromResult = new Func<(int, float)>(() => { var pair = W.OrderByDescending(e => e.Value).First(); return (pair.Key, pair.Value); });
var fartherPopFromResult = new Action(() => { var pair = W.OrderByDescending(e => e.Value).First(); W.Remove(pair.Key); });
// run bfs
while (C.Count > 0)
{
// get next candidate to check and expand
var toExpand = popCandidate();
var farthestResult = fartherFromResult();
if (toExpand.Item2 > farthestResult.Item2)
{
// the closest candidate is farther than farthest result
break;
}
// expand candidate
var neighboursIds = GetNeighbors(toExpand.Item1).ToArray();
for (int i = 0; i < neighboursIds.Length; ++i)
{
int neighbourId = neighboursIds[i];
if (!v.Contains(neighbourId))
{
// enqueue perspective neighbours to expansion list
farthestResult = fartherFromResult();
var neighbourDistance = targetCosts(neighbourId);
if (W.Count < ef || neighbourDistance < farthestResult.Item2)
{
C.Add(neighbourId, neighbourDistance);
W.Add(neighbourId, neighbourDistance);
if (W.Count > ef)
{
fartherPopFromResult();
}
}
v.Add(neighbourId);
}
}
}
C.Clear();
v.Clear();
}
/// <summary>
/// Algorithm 2
/// </summary>
/// <param name="q">query element</param>
/// <param name="ep">enter points ep</param>
/// <returns>Output: ef closest neighbors to q</returns>
internal void KNearestAtLayer(int entryPointId, Func<int, float> targetCosts, IDictionary<int, float> W, int ef, HashSet<int> activeNodes)
{
/*
* v ep // set of visited elements
* C ep // set of candidates
* W ep // dynamic list of found nearest neighbors
* while C > 0
* c extract nearest element from C to q
* f get furthest element from W to q
* if distance(c, q) > distance(f, q)
* break // all elements in W are evaluated
* for each e neighbourhood(c) at layer lc // update C and W
* if e v
* v v e
* f get furthest element from W to q
* if distance(e, q) < distance(f, q) or W < ef
* C C e
* W W e
* if W > ef
* remove furthest element from W to q
* return W
*/
var v = new VisitedBitSet(_vectors.Count, 1);
// v ← ep // set of visited elements
v.Add(entryPointId);
// C ← ep // set of candidates
var C = new Dictionary<int, float>();
C.Add(entryPointId, targetCosts(entryPointId));
// W ← ep // dynamic list of found nearest neighbors
if (activeNodes.Contains(entryPointId))
{
W.Add(entryPointId, C[entryPointId]);
}
var popCandidate = new Func<(int, float)>(() => { var pair = C.OrderBy(e => e.Value).First(); C.Remove(pair.Key); return (pair.Key, pair.Value); });
var farthestDistance = new Func<float>(() => { var pair = W.OrderByDescending(e => e.Value).First(); return pair.Value; });
var fartherPopFromResult = new Action(() => { var pair = W.OrderByDescending(e => e.Value).First(); W.Remove(pair.Key); });
// run bfs
while (C.Count > 0)
{
// get next candidate to check and expand
var toExpand = popCandidate();
if (W.Count > 0)
{
if (toExpand.Item2 > farthestDistance())
{
// the closest candidate is farther than farthest result
break;
}
}
// expand candidate
var neighboursIds = GetNeighbors(toExpand.Item1).ToArray();
for (int i = 0; i < neighboursIds.Length; ++i)
{
int neighbourId = neighboursIds[i];
if (!v.Contains(neighbourId))
{
// enqueue perspective neighbours to expansion list
var neighbourDistance = targetCosts(neighbourId);
if (activeNodes.Contains(neighbourId))
{
if (W.Count < ef || (W.Count > 0 && neighbourDistance < farthestDistance()))
{
W.Add(neighbourId, neighbourDistance);
if (W.Count > ef)
{
fartherPopFromResult();
}
}
}
if (W.Count < ef)
{
C.Add(neighbourId, neighbourDistance);
}
v.Add(neighbourId);
}
}
}
C.Clear();
v.Clear();
}
/// <summary>
/// Algorithm 3
/// </summary>
internal IDictionary<int, float> SELECT_NEIGHBORS_SIMPLE(Func<int, float> distance, IDictionary<int, float> candidates, int M)
{
var bestN = M;
var W = new Dictionary<int, float>(candidates);
if (W.Count > bestN)
{
var popFarther = new Action(() => { var pair = W.OrderByDescending(e => e.Value).First(); W.Remove(pair.Key); });
while (W.Count > bestN)
{
popFarther();
}
}
// return M nearest elements from C to q
return W;
}
/// <summary>
/// Algorithm 4
/// </summary>
/// <param name="q">base element</param>
/// <param name="C">candidate elements</param>
/// <param name="extendCandidates">flag indicating whether or not to extend candidate list</param>
/// <param name="keepPrunedConnections">flag indicating whether or not to add discarded elements</param>
/// <returns>Output: M elements selected by the heuristic</returns>
internal IDictionary<int, float> SELECT_NEIGHBORS_HEURISTIC(Func<int, float> distance, IDictionary<int, float> candidates, int M)
{
// R ← ∅
var R = new Dictionary<int, float>();
// W ← C // working queue for the candidates
var W = new Dictionary<int, float>(candidates);
// if extendCandidates // extend candidates by their neighbors
if (_options.ExpandBestSelection)
{
var extendBuffer = new HashSet<int>();
// for each e ∈ C
foreach (var e in W)
{
var neighbors = GetNeighbors(e.Key);
// for each e_adj ∈ neighbourhood(e) at layer lc
foreach (var e_adj in neighbors)
{
// if eadj ∉ W
if (extendBuffer.Contains(e_adj) == false)
{
extendBuffer.Add(e_adj);
}
}
}
// W ← W eadj
foreach (var id in extendBuffer)
{
W[id] = distance(id);
}
}
// Wd ← ∅ // queue for the discarded candidates
var Wd = new Dictionary<int, float>();
var popCandidate = new Func<(int, float)>(() => { var pair = W.OrderBy(e => e.Value).First(); W.Remove(pair.Key); return (pair.Key, pair.Value); });
var fartherFromResult = new Func<(int, float)>(() => { if (R.Count == 0) return (-1, 0f); var pair = R.OrderByDescending(e => e.Value).First(); return (pair.Key, pair.Value); });
var popNearestDiscarded = new Func<(int, float)>(() => { var pair = Wd.OrderBy(e => e.Value).First(); Wd.Remove(pair.Key); return (pair.Key, pair.Value); });
// while │W│ > 0 and │R│< M
while (W.Count > 0 && R.Count < M)
{
// e ← extract nearest element from W to q
var (e, ed) = popCandidate();
var (fe, fd) = fartherFromResult();
// if e is closer to q compared to any element from R
if (R.Count == 0 ||
ed < fd)
{
// R ← R e
R.Add(e, ed);
}
else
{
// Wd ← Wd e
Wd.Add(e, ed);
}
}
// if keepPrunedConnections // add some of the discarded // connections from Wd
if (_options.KeepPrunedConnections)
{
// while │Wd│> 0 and │R│< M
while (Wd.Count > 0 && R.Count < M)
{
// R ← R extract nearest element from Wd to q
var nearest = popNearestDiscarded();
R[nearest.Item1] = nearest.Item2;
}
}
// return R
return R;
}
#endregion
private IEnumerable<int> GetNeighbors(int id) => _links.FindLinksForId(id);
public void Serialize(IBinaryWriter writer)
{
_links.Serialize(writer);
}
public void Deserialize(IBinaryReader reader)
{
_links.Deserialize(reader);
}
}
}

@ -0,0 +1,33 @@
using System.Collections.Generic;
using ZeroLevel.Services.Serialization;
namespace ZeroLevel.HNSW
{
internal sealed class ReadOnlyVectorSet<T>
: IBinarySerializable
{
private List<T> _set = new List<T>();
internal T this[int index] => _set[index];
internal int Count => _set.Count;
public void Deserialize(IBinaryReader reader)
{
int count = reader.ReadInt32();
_set = new List<T>(count + 1);
for (int i = 0; i < count; i++)
{
_set.Add(reader.ReadCompatible<T>());
}
}
public void Serialize(IBinaryWriter writer)
{
writer.WriteInt32(_set.Count);
foreach (var r in _set)
{
writer.WriteCompatible<T>(r);
}
}
}
}

@ -1,18 +1,19 @@
using System;
using System.Collections.Generic;
using System.Collections.Generic;
using System.Threading;
using ZeroLevel.Services.Serialization;
namespace ZeroLevel.HNSW
{
public class VectorSet<T>
internal sealed class VectorSet<T>
: IBinarySerializable
{
private List<T> _set = new List<T>();
private SpinLock _lock = new SpinLock();
public T this[int index] => _set[index];
public int Count => _set.Count;
internal T this[int index] => _set[index];
internal int Count => _set.Count;
public int Append(T vector)
internal int Append(T vector)
{
bool gotLock = false;
gotLock = false;
@ -29,7 +30,7 @@ namespace ZeroLevel.HNSW
}
}
public int[] Append(IEnumerable<T> vectors)
internal int[] Append(IEnumerable<T> vectors)
{
bool gotLock = false;
int startIndex, endIndex;
@ -53,5 +54,24 @@ namespace ZeroLevel.HNSW
}
return ids;
}
public void Deserialize(IBinaryReader reader)
{
int count = reader.ReadInt32();
_set = new List<T>(count + 1);
for (int i = 0; i < count; i++)
{
_set.Add(reader.ReadCompatible<T>());
}
}
public void Serialize(IBinaryWriter writer)
{
writer.WriteInt32(_set.Count);
foreach (var r in _set)
{
writer.WriteCompatible<T>(r);
}
}
}
}

@ -1,43 +1,18 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading;
using ZeroLevel.HNSW.Services;
using ZeroLevel.Services.Serialization;
namespace ZeroLevel.HNSW
{
public class ProbabilityLayerNumberGenerator
{
private const float DIVIDER = 3.361f;
private readonly float[] _probabilities;
public ProbabilityLayerNumberGenerator(int maxLayers, int M)
{
_probabilities = new float[maxLayers];
var probability = 1.0f / DIVIDER;
for (int i = 0; i < maxLayers; i++)
{
_probabilities[i] = probability;
probability /= DIVIDER;
}
}
public int GetRandomLayer()
{
var probability = DefaultRandomGenerator.Instance.NextFloat();
for (int i = 0; i < _probabilities.Length; i++)
{
if (probability > _probabilities[i])
return i;
}
return 0;
}
}
public class SmallWorld<TItem>
{
private readonly NSWOptions<TItem> _options;
private readonly VectorSet<TItem> _vectors;
private readonly Layer<TItem>[] _layers;
private Layer<TItem>[] _layers;
private int EntryPoint = 0;
private int MaxLayer = 0;
private readonly ProbabilityLayerNumberGenerator _layerLevelGenerator;
@ -55,6 +30,12 @@ namespace ZeroLevel.HNSW
}
}
internal SmallWorld(NSWOptions<TItem> options, Stream stream)
{
_options = options;
Deserialize(stream);
}
/// <summary>
/// Search in the graph K for vectors closest to a given vector
/// </summary>
@ -62,14 +43,32 @@ namespace ZeroLevel.HNSW
/// <param name="k">Count of elements for search</param>
/// <param name="activeNodes"></param>
/// <returns></returns>
public IEnumerable<(int, TItem, float)> Search(TItem vector, int k, HashSet<int> activeNodes = null)
public IEnumerable<(int, TItem, float)> Search(TItem vector, int k)
{
foreach (var pair in KNearest(vector, k, activeNodes))
foreach (var pair in KNearest(vector, k))
{
yield return (pair.Item1, _vectors[pair.Item1], pair.Item2);
}
}
public IEnumerable<(int, TItem, float)> Search(TItem vector, int k, HashSet<int> activeNodes)
{
if (activeNodes == null)
{
foreach (var pair in KNearest(vector, k))
{
yield return (pair.Item1, _vectors[pair.Item1], pair.Item2);
}
}
else
{
foreach (var pair in KNearest(vector, k, activeNodes))
{
yield return (pair.Item1, _vectors[pair.Item1], pair.Item2);
}
}
}
/// <summary>
/// Adding vectors batch
/// </summary>
@ -201,7 +200,43 @@ namespace ZeroLevel.HNSW
/// <summary>
/// Algorithm 5
/// </summary>
private IEnumerable<(int, float)> KNearest(TItem q, int k, HashSet<int> activeNodes = null)
private IEnumerable<(int, float)> KNearest(TItem q, int k)
{
_lockGraph.EnterReadLock();
try
{
if (_vectors.Count == 0)
{
return Enumerable.Empty<(int, float)>();
}
var distance = new Func<int, float>(candidate => _options.Distance(q, _vectors[candidate]));
// W ← ∅ // set for the current nearest elements
var W = new Dictionary<int, float>(k + 1);
// ep ← get enter point for hnsw
var ep = EntryPoint;
// L ← level of ep // top layer for hnsw
var L = MaxLayer;
// for lc ← L … 1
for (int layer = L; layer > 0; --layer)
{
// W ← SEARCH-LAYER(q, ep, ef = 1, lc)
_layers[layer].KNearestAtLayer(ep, distance, W, 1);
// ep ← get nearest element from W to q
ep = W.OrderBy(p => p.Value).First().Key;
W.Clear();
}
// W ← SEARCH-LAYER(q, ep, ef, lc =0)
_layers[0].KNearestAtLayer(ep, distance, W, k);
// return K nearest elements from W to q
return W.Select(p => (p.Key, p.Value));
}
finally
{
_lockGraph.ExitReadLock();
}
}
private IEnumerable<(int, float)> KNearest(TItem q, int k, HashSet<int> activeNodes)
{
_lockGraph.EnterReadLock();
try
@ -238,5 +273,37 @@ namespace ZeroLevel.HNSW
}
}
#endregion
public void Serialize(Stream stream)
{
using (var writer = new MemoryStreamWriter(stream))
{
writer.WriteInt32(EntryPoint);
writer.WriteInt32(MaxLayer);
_vectors.Serialize(writer);
writer.WriteInt32(_layers.Length);
foreach (var l in _layers)
{
l.Serialize(writer);
}
}
}
public void Deserialize(Stream stream)
{
using (var reader = new MemoryStreamReader(stream))
{
this.EntryPoint = reader.ReadInt32();
this.MaxLayer = reader.ReadInt32();
_vectors.Deserialize(reader);
var countLayers = reader.ReadInt32();
_layers = new Layer<TItem>[countLayers];
for (int i = 0; i < countLayers; i++)
{
_layers[i] = new Layer<TItem>(_options, _vectors);
_layers[i].Deserialize(reader);
}
}
}
}
}

@ -0,0 +1,11 @@
using System.IO;
namespace ZeroLevel.HNSW
{
public static class SmallWorld
{
public static SmallWorld<TItem> CreateWorld<TItem>(NSWOptions<TItem> options) => new SmallWorld<TItem>(options);
public static SmallWorld<TItem> CreateWorldFrom<TItem>(NSWOptions<TItem> options, Stream stream) => new SmallWorld<TItem>(options, stream);
public static ReadOnlySmallWorld<TItem> CreateReadOnlyWorldFrom<TItem>(NSWReadOnlyOption<TItem> options, Stream stream) => new ReadOnlySmallWorld<TItem>(options, stream);
}
}

@ -0,0 +1,30 @@
namespace ZeroLevel.HNSW.Services
{
internal sealed class ProbabilityLayerNumberGenerator
{
private const float DIVIDER = 3.361f;
private readonly float[] _probabilities;
internal ProbabilityLayerNumberGenerator(int maxLayers, int M)
{
_probabilities = new float[maxLayers];
var probability = 1.0f / DIVIDER;
for (int i = 0; i < maxLayers; i++)
{
_probabilities[i] = probability;
probability /= DIVIDER;
}
}
internal int GetRandomLayer()
{
var probability = DefaultRandomGenerator.Instance.NextFloat();
for (int i = 0; i < _probabilities.Length; i++)
{
if (probability > _probabilities[i])
return i;
}
return 0;
}
}
}

@ -6,6 +6,7 @@
<ItemGroup>
<PackageReference Include="System.Numerics.Vectors" Version="4.5.0" />
<PackageReference Include="ZeroLevel" Version="3.3.5.6" />
</ItemGroup>
</Project>

Loading…
Cancel
Save

Powered by TurnKey Linux.