Start optimizations

pull/1/head
unknown 3 years ago
parent 2d2616ddce
commit ecd223ebdb

@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<OutputType>Exe</OutputType>
@ -13,4 +13,10 @@
<ProjectReference Include="..\..\ZeroLevel.HNSW\ZeroLevel.HNSW.csproj" />
</ItemGroup>
<ItemGroup>
<None Update="t10k-images.idx3-ubyte">
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
</None>
</ItemGroup>
</Project>

@ -6,6 +6,7 @@ using System.IO;
using System.Linq;
using ZeroLevel.HNSW;
using ZeroLevel.HNSW.Services;
using ZeroLevel.HNSW.Services.OPT;
namespace HNSWDemo
{
@ -13,6 +14,7 @@ namespace HNSWDemo
{
public class VectorsDirectCompare
{
private const int HALF_LONG_BITS = 32;
private readonly IList<float[]> _vectors;
private readonly Func<float[], float[], float> _distance;
@ -32,6 +34,70 @@ namespace HNSWDemo
}
return weights.OrderBy(p => p.Value).Take(k).Select(p => (p.Key, p.Value));
}
public List<HashSet<int>> DetectClusters()
{
var links = new SortedList<long, float>();
for (int i = 0; i < _vectors.Count; i++)
{
for (int j = i + 1; j < _vectors.Count; j++)
{
long k = (((long)(i)) << HALF_LONG_BITS) + j;
links.Add(k, _distance(_vectors[i], _vectors[j]));
}
}
// 1. Find R - bound between intra-cluster distances and out-of-cluster distances
var histogram = new Histogram(HistogramMode.SQRT, links.Values);
int threshold = histogram.OTSU();
var min = histogram.Bounds[threshold - 1];
var max = histogram.Bounds[threshold];
var R = (max + min) / 2;
// 2. Get links with distances less than R
var resultLinks = new SortedList<long, float>();
foreach (var pair in links)
{
if (pair.Value < R)
{
resultLinks.Add(pair.Key, pair.Value);
}
}
// 3. Extract clusters
List<HashSet<int>> clusters = new List<HashSet<int>>();
foreach (var pair in resultLinks)
{
var k = pair.Key;
var id1 = (int)(k >> HALF_LONG_BITS);
var id2 = (int)(k - (((long)id1) << HALF_LONG_BITS));
bool found = false;
foreach (var c in clusters)
{
if (c.Contains(id1))
{
c.Add(id2);
found = true;
break;
}
else if (c.Contains(id2))
{
c.Add(id1);
found = true;
break;
}
}
if (found == false)
{
var c = new HashSet<int>();
c.Add(id1);
c.Add(id2);
clusters.Add(c);
}
}
return clusters;
}
}
public enum Gender
@ -91,7 +157,7 @@ namespace HNSWDemo
{
var vector = new float[vectorSize];
DefaultRandomGenerator.Instance.NextFloats(vector);
//VectorUtils.NormalizeSIMD(vector);
VectorUtils.NormalizeSIMD(vector);
vectors.Add(vector);
}
return vectors;
@ -100,11 +166,107 @@ namespace HNSWDemo
static void Main(string[] args)
{
AutoClusteringTest();
var samples = RandomVectors(128, 600);
var opt_world = new OptWorld<float[]>(NSWOptions<float[]>.Create(8, 15, 200, 200, Metrics.L2Euclidean, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple));
opt_world.AddItems(samples);
//AccuracityTest();
Console.WriteLine("Completed");
Console.ReadKey();
}
static void BinaryHeapTest()
{
var heap = new BinaryHeap();
heap.Push(1, .03f);
heap.Push(2, .05f);
heap.Push(3, .01f);
heap.Push(4, 1.03f);
heap.Push(5, 2.03f);
heap.Push(6, .73f);
var n = heap.Nearest;
Console.WriteLine($"Nearest: [{n.Item1}] = {n.Item2}");
var f = heap.Farthest;
Console.WriteLine($"Farthest: [{f.Item1}] = {f.Item2}");
Console.WriteLine("From nearest to farthest");
while (heap.Count > 0)
{
var i = heap.PopNearest();
Console.WriteLine($"[{i.Item1}] = {i.Item2}");
}
heap.Push(1, .03f);
heap.Push(2, .05f);
heap.Push(3, .01f);
heap.Push(4, 1.03f);
heap.Push(5, 2.03f);
heap.Push(6, .73f);
Console.WriteLine("From farthest to nearest");
while (heap.Count > 0)
{
var i = heap.PopFarthest();
Console.WriteLine($"[{i.Item1}] = {i.Item2}");
}
}
static void TestOnMnist()
{
int imageCount, rowCount, colCount;
var buf = new byte[4];
var image = new byte[28 * 28];
var vectors = new List<float[]>();
using (var fs = new FileStream("t10k-images.idx3-ubyte", FileMode.Open, FileAccess.Read, FileShare.None))
{
// first 4 bytes is a magic number
fs.Read(buf, 0, 4);
// second 4 bytes is the number of images
fs.Read(buf, 0, 4);
imageCount = BitConverter.ToInt32(buf.Reverse().ToArray(), 0);
// third 4 bytes is the row count
fs.Read(buf, 0, 4);
rowCount = BitConverter.ToInt32(buf.Reverse().ToArray(), 0);
// fourth 4 bytes is the column count
fs.Read(buf, 0, 4);
colCount = BitConverter.ToInt32(buf.Reverse().ToArray(), 0);
for (int i = 0; i < imageCount; i++)
{
fs.Read(image, 0, image.Length);
vectors.Add(image.Select(b => (float)b).ToArray());
}
}
//var direct = new VectorsDirectCompare(vectors, Metrics.L2Euclidean);
var options = NSWOptions<float[]>.Create(8, 16, 200, 200, Metrics.L2Euclidean, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple);
SmallWorld<float[]> world;
if (File.Exists("graph.bin"))
{
using (var fs = new FileStream("graph.bin", FileMode.Open, FileAccess.Read, FileShare.None))
{
world = SmallWorld.CreateWorldFrom<float[]>(options, fs);
}
}
else
{
world = SmallWorld.CreateWorld<float[]>(options);
world.AddItems(vectors);
using (var fs = new FileStream("graph.bin", FileMode.Create, FileAccess.Write, FileShare.None))
{
world.Serialize(fs);
}
}
var clusters = AutomaticGraphClusterer.DetectClusters(world);
Console.WriteLine($"Found {clusters.Count} clusters");
for (int i = 0; i < clusters.Count; i++)
{
Console.WriteLine($"Cluster {i + 1} countains {clusters[i].Count} items");
}
}
static void AutoClusteringTest()
{
var vectors = RandomVectors(128, 3000);
@ -114,7 +276,7 @@ namespace HNSWDemo
Console.WriteLine($"Found {clusters.Count} clusters");
for (int i = 0; i < clusters.Count; i++)
{
Console.WriteLine($"Cluster {i+1} countains {clusters[i].Count} items");
Console.WriteLine($"Cluster {i + 1} countains {clusters[i].Count} items");
}
}
@ -135,10 +297,10 @@ namespace HNSWDemo
static void DrawHistogram(Histogram histogram, string filename)
{
/* while (histogram.CountSignChanges() > 3)
{
histogram.Smooth();
}*/
/* while (histogram.CountSignChanges() > 3)
{
histogram.Smooth();
}*/
var wb = 1200 / histogram.Values.Length;
var k = 600.0f / (float)histogram.Values.Max();
@ -147,8 +309,7 @@ namespace HNSWDemo
using (var bmp = new Bitmap(1200, 600))
{
using (var g = Graphics.FromImage(bmp))
{
using (var g = Graphics.FromImage(bmp)) {
for (int i = 0; i < histogram.Values.Length; i++)
{
var height = (int)(histogram.Values[i] * k);
@ -481,26 +642,37 @@ namespace HNSWDemo
static void AccuracityTest()
{
int K = 200;
var count = 5000;
var count = 2000;
var testCount = 1000;
var dimensionality = 128;
var totalHits = new List<int>();
var timewatchesNP = new List<float>();
var timewatchesHNSW = new List<float>();
var totalOptHits = new List<int>();
var timewatchesOptHNSW = new List<float>();
var samples = RandomVectors(dimensionality, count);
var sw = new Stopwatch();
var test = new VectorsDirectCompare(samples, CosineDistance.ForUnits);
var world = new SmallWorld<float[]>(NSWOptions<float[]>.Create(8, 15, 200, 200, CosineDistance.ForUnits, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple));
var test = new VectorsDirectCompare(samples, Metrics.L2Euclidean);
var world = new SmallWorld<float[]>(NSWOptions<float[]>.Create(8, 15, 200, 200, Metrics.L2Euclidean, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple));
var opt_world = new OptWorld<float[]>(NSWOptions<float[]>.Create(8, 15, 200, 200, Metrics.L2Euclidean, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple));
sw.Start();
var ids = world.AddItems(samples.ToArray());
sw.Stop();
Console.WriteLine($"Insert {ids.Length} items: {sw.ElapsedMilliseconds} ms");
Console.WriteLine($"Insert {ids.Length} items on {sw.ElapsedMilliseconds} ms");
sw.Restart();
opt_world.AddItems(samples.ToArray());
sw.Stop();
Console.WriteLine($"Insert {ids.Length} items in OPT: {sw.ElapsedMilliseconds} ms");
Console.WriteLine("Start test");
var test_vectors = RandomVectors(dimensionality, testCount);
foreach (var v in test_vectors)
{
@ -512,6 +684,7 @@ namespace HNSWDemo
sw.Restart();
var result = world.Search(v, K);
sw.Stop();
timewatchesHNSW.Add(sw.ElapsedMilliseconds);
var hits = 0;
foreach (var r in result)
@ -522,15 +695,39 @@ namespace HNSWDemo
}
}
totalHits.Add(hits);
sw.Restart();
result = opt_world.Search(v, K);
sw.Stop();
timewatchesOptHNSW.Add(sw.ElapsedMilliseconds);
hits = 0;
foreach (var r in result)
{
if (gt.ContainsKey(r.Item1))
{
hits++;
}
}
totalOptHits.Add(hits);
}
Console.WriteLine($"MIN Accuracity: {totalHits.Min() * 100 / K}%");
Console.WriteLine($"AVG Accuracity: {totalHits.Average() * 100 / K}%");
Console.WriteLine($"MAX Accuracity: {totalHits.Max() * 100 / K}%");
Console.WriteLine($"MIN Opt Accuracity: {totalOptHits.Min() * 100 / K}%");
Console.WriteLine($"AVG Opt Accuracity: {totalOptHits.Average() * 100 / K}%");
Console.WriteLine($"MAX Opt Accuracity: {totalOptHits.Max() * 100 / K}%");
Console.WriteLine($"MIN HNSW TIME: {timewatchesHNSW.Min()} ms");
Console.WriteLine($"AVG HNSW TIME: {timewatchesHNSW.Average()} ms");
Console.WriteLine($"MAX HNSW TIME: {timewatchesHNSW.Max()} ms");
Console.WriteLine($"MIN Opt HNSW TIME: {timewatchesOptHNSW.Min()} ms");
Console.WriteLine($"AVG Opt HNSW TIME: {timewatchesOptHNSW.Average()} ms");
Console.WriteLine($"MAX Opt HNSW TIME: {timewatchesOptHNSW.Max()} ms");
Console.WriteLine($"MIN NP TIME: {timewatchesNP.Min()} ms");
Console.WriteLine($"AVG NP TIME: {timewatchesNP.Average()} ms");
Console.WriteLine($"MAX NP TIME: {timewatchesNP.Max()} ms");

@ -21,7 +21,7 @@ namespace ZeroLevel.HNSW
public float[] Bounds { get; }
public int[] Values { get; }
internal Histogram(HistogramMode mode, IList<float> data)
public Histogram(HistogramMode mode, IList<float> data)
{
Mode = mode;
Min = data.Min();
@ -171,6 +171,13 @@ namespace ZeroLevel.HNSW
threshold = k;
}
}
/*
var local_max = Values[threshold];
for (int i = threshold + 1; i < Values.Length; i++)
{
}
*/
return threshold;
}
#endregion

@ -16,6 +16,7 @@ namespace ZeroLevel.HNSW.Services
var max = histogram.Bounds[threshold];
var R = (max + min) / 2;
// 2. Get links with distances less than R
var resultLinks = new SortedList<long, float>();
foreach (var pair in links)

@ -1,80 +1,90 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace ZeroLevel.HNSW.Services
namespace ZeroLevel.HNSW
{
/// <summary>
/// Binary heap wrapper around the <see cref="IList{T}"/>
/// It's a max-heap implementation i.e. the maximum element is always on top.
/// But the order of elements can be customized by providing <see cref="IComparer{T}"/> instance.
/// </summary>
/// <typeparam name="T">The type of the items in the source list.</typeparam>
public class BinaryHeap<T>
public class BinaryHeap :
IEnumerable<(int, float)>
{
/// <summary>
/// Initializes a new instance of the <see cref="BinaryHeap{T}"/> class.
/// </summary>
/// <param name="buffer">The buffer to store heap items.</param>
public BinaryHeap(IList<T> buffer)
: this(buffer, Comparer<T>.Default)
{
}
private static BinaryHeap _empty = new BinaryHeap();
/// <summary>
/// Initializes a new instance of the <see cref="BinaryHeap{T}"/> class.
/// </summary>
/// <param name="buffer">The buffer to store heap items.</param>
/// <param name="comparer">The comparer which defines order of items.</param>
public BinaryHeap(IList<T> buffer, IComparer<T> comparer)
public static BinaryHeap Empty => _empty;
private readonly List<(int, float)> _data;
private bool _frozen = false;
public (int, float) Nearest => _data[_data.Count - 1];
public (int, float) Farthest => _data[0];
public (int, float) PopNearest()
{
if (buffer == null)
if (this._data.Any())
{
throw new ArgumentNullException(nameof(buffer));
var result = this._data[this._data.Count - 1];
this._data.RemoveAt(this._data.Count - 1);
return result;
}
return (-1, -1);
}
this.Buffer = buffer;
this.Comparer = comparer;
for (int i = 1; i < this.Buffer.Count; ++i)
public (int, float) PopFarthest()
{
if (this._data.Any())
{
this.SiftUp(i);
var result = this._data.First();
this._data[0] = this._data.Last();
this._data.RemoveAt(this._data.Count - 1);
this.SiftDown(0);
return result;
}
return (-1, -1);
}
/// <summary>
/// Gets the heap comparer.
/// </summary>
public IComparer<T> Comparer { get; private set; }
public int Count => _data.Count;
public void Clear() => _data.Clear();
/// <summary>
/// Gets the buffer of the heap.
/// Initializes a new instance of the <see cref="BinaryHeap{T}"/> class.
/// </summary>
public IList<T> Buffer { get; private set; }
/// <param name="buffer">The buffer to store heap items.</param>
public BinaryHeap(int k = -1, bool frozen = false)
{
_frozen = frozen;
if (k > 0)
_data = new List<(int, float)>(k);
else
_data = new List<(int, float)>();
}
/// <summary>
/// Pushes item to the heap.
/// </summary>
/// <param name="item">The item to push.</param>
public void Push(T item)
public void Push(int item, float distance)
{
this.Buffer.Add(item);
this.SiftUp(this.Buffer.Count - 1);
this._data.Add((item, distance));
this.SiftUp(this._data.Count - 1);
}
/// <summary>
/// Pops the item from the heap.
/// </summary>
/// <returns>The popped item.</returns>
public T Pop()
public (int, float) Pop()
{
if (this.Buffer.Any())
if (this._data.Any())
{
var result = this.Buffer.First();
var result = this._data.First();
this.Buffer[0] = this.Buffer.Last();
this.Buffer.RemoveAt(this.Buffer.Count - 1);
this._data[0] = this._data.Last();
this._data.RemoveAt(this._data.Count - 1);
this.SiftDown(0);
return result;
@ -90,21 +100,19 @@ namespace ZeroLevel.HNSW.Services
/// <param name="i">The position of item where heap property is violated.</param>
private void SiftDown(int i)
{
while (i < this.Buffer.Count)
while (i < this._data.Count)
{
int l = (2 * i) + 1;
int r = l + 1;
if (l >= this.Buffer.Count)
if (l >= this._data.Count)
{
break;
}
int m = r < this.Buffer.Count && this.Comparer.Compare(this.Buffer[l], this.Buffer[r]) < 0 ? r : l;
if (this.Comparer.Compare(this.Buffer[m], this.Buffer[i]) <= 0)
int m = ((r < this._data.Count) && this._data[l].Item2 < this._data[r].Item2) ? r : l;
if (this._data[m].Item2 <= this._data[i].Item2)
{
break;
}
this.Swap(i, m);
i = m;
}
@ -120,11 +128,10 @@ namespace ZeroLevel.HNSW.Services
while (i > 0)
{
int p = (i - 1) / 2;
if (this.Comparer.Compare(this.Buffer[i], this.Buffer[p]) <= 0)
if (this._data[i].Item2 <= this._data[p].Item2)
{
break;
}
this.Swap(i, p);
i = p;
}
@ -137,9 +144,19 @@ namespace ZeroLevel.HNSW.Services
/// <param name="j">The second index.</param>
private void Swap(int i, int j)
{
var temp = this.Buffer[i];
this.Buffer[i] = this.Buffer[j];
this.Buffer[j] = temp;
var temp = this._data[i];
this._data[i] = this._data[j];
this._data[j] = temp;
}
public IEnumerator<(int, float)> GetEnumerator()
{
return _data.GetEnumerator();
}
IEnumerator IEnumerable.GetEnumerator()
{
return _data.GetEnumerator();
}
}
}

@ -104,11 +104,23 @@ namespace ZeroLevel.HNSW
_rwLock.EnterReadLock();
try
{
foreach (var (k, v) in Search(_set, id))
if (_set.Count == 1)
{
var k = _set.Keys[0];
var v = _set[k];
var id1 = (int)(k >> HALF_LONG_BITS);
var id2 = (int)(k - (((long)id1) << HALF_LONG_BITS));
yield return (id1, id2, v);
if (id1 == id) yield return (id, id2, v);
else if (id2 == id) yield return (id1, id, v);
}
else if (_set.Count > 1)
{
foreach (var (k, v) in Search(_set, id))
{
var id1 = (int)(k >> HALF_LONG_BITS);
var id2 = (int)(k - (((long)id1) << HALF_LONG_BITS));
yield return (id1, id2, v);
}
}
}
finally

@ -53,6 +53,12 @@ namespace ZeroLevel.HNSW
int index = 0;
for (int ni = 1; ni < nearest.Length; ni++)
{
// Если осталась ссылка узла на себя, удаляем ее в первую очередь
if (nearest[ni].Item1 == nearest[ni].Item2)
{
index = ni;
break;
}
if (nearest[ni].Item3 > distance)
{
index = ni;
@ -81,6 +87,23 @@ namespace ZeroLevel.HNSW
}
#region Implementation of https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf
internal int FingEntryPointAtLayer(Func<int, float> targetCosts)
{
var set = new HashSet<int>(_links.Items().Select(p => p.Item1));
int minId = -1;
float minDist = float.MaxValue;
foreach (var id in set)
{
var d = targetCosts(id);
if (d < minDist && Math.Abs(d) > float.Epsilon)
{
minDist = d;
minId = id;
}
}
return minId;
}
/// <summary>
/// Algorithm 2
/// </summary>

@ -0,0 +1,491 @@
using System;
using System.Collections.Generic;
using System.Linq;
using ZeroLevel.Services.Serialization;
namespace ZeroLevel.HNSW.Services.OPT
{
/// <summary>
/// NSW graph
/// </summary>
internal sealed class OptLayer<TItem>
: IBinarySerializable
{
private readonly NSWOptions<TItem> _options;
private readonly VectorSet<TItem> _vectors;
private readonly CompactBiDirectionalLinksSet _links;
internal SortedList<long, float> Links => _links.Links;
/// <summary>
/// There are links е the layer
/// </summary>
internal bool HasLinks => (_links.Count > 0);
/// <summary>
/// HNSW layer
/// </summary>
/// <param name="options">HNSW graph options</param>
/// <param name="vectors">General vector set</param>
internal OptLayer(NSWOptions<TItem> options, VectorSet<TItem> vectors)
{
_options = options;
_vectors = vectors;
_links = new CompactBiDirectionalLinksSet();
}
/// <summary>
/// Adding new bidirectional link
/// </summary>
/// <param name="q">New node</param>
/// <param name="p">The node with which the connection will be made</param>
/// <param name="qpDistance"></param>
/// <param name="isMapLayer"></param>
internal void AddBidirectionallConnections(int q, int p, float qpDistance, bool isMapLayer)
{
// поиск в ширину ближайших узлов к найденному
var nearest = _links.FindLinksForId(p).ToArray();
// если у найденного узла максимальное количество связей
// if │eConn│ > Mmax // shrink connections of e
if (nearest.Length >= (isMapLayer ? _options.M * 2 : _options.M))
{
// ищем связь с самой большой дистанцией
float distance = nearest[0].Item3;
int index = 0;
for (int ni = 1; ni < nearest.Length; ni++)
{
// Если осталась ссылка узла на себя, удаляем ее в первую очередь
if (nearest[ni].Item1 == nearest[ni].Item2)
{
index = ni;
break;
}
if (nearest[ni].Item3 > distance)
{
index = ni;
distance = nearest[ni].Item3;
}
}
// делаем перелинковку вставляя новый узел между найденными
var id1 = nearest[index].Item1;
var id2 = nearest[index].Item2;
_links.Relink(id1, id2, q, qpDistance, _options.Distance(_vectors[id2], _vectors[q]));
}
else
{
if (nearest.Length == 1 && nearest[0].Item1 == nearest[0].Item2)
{
// убираем связи на самих себя
var id1 = nearest[0].Item1;
var id2 = nearest[0].Item2;
_links.Relink(id1, id2, q, qpDistance, _options.Distance(_vectors[id2], _vectors[q]));
}
else
{
// добавляем связь нового узла к найденному
_links.Add(q, p, qpDistance);
}
}
}
/// <summary>
/// Adding a node with a connection to itself
/// </summary>
/// <param name="q"></param>
internal void Append(int q)
{
_links.Add(q, q, 0);
}
#region Implementation of https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf
/// <summary>
/// Algorithm 2
/// </summary>
/// <param name="q">query element</param>
/// <param name="ep">enter points ep</param>
/// <returns>Output: ef closest neighbors to q</returns>
internal void KNearestAtLayer(int entryPointId, Func<int, float> targetCosts, BinaryHeap W, int ef)
{
/*
* v ep // set of visited elements
* C ep // set of candidates
* W ep // dynamic list of found nearest neighbors
* while C > 0
* c extract nearest element from C to q
* f get furthest element from W to q
* if distance(c, q) > distance(f, q)
* break // all elements in W are evaluated
* for each e neighbourhood(c) at layer lc // update C and W
* if e v
* v v e
* f get furthest element from W to q
* if distance(e, q) < distance(f, q) or W < ef
* C C e
* W W e
* if W > ef
* remove furthest element from W to q
* return W
*/
var v = new VisitedBitSet(_vectors.Count, _options.M);
// v ← ep // set of visited elements
v.Add(entryPointId);
var d = targetCosts(entryPointId);
// C ← ep // set of candidates
var C = new BinaryHeap();
C.Push(entryPointId, d);
// W ← ep // dynamic list of found nearest neighbors
W.Push(entryPointId, d);
// run bfs
while (C.Count > 0)
{
// get next candidate to check and expand
var toExpand = C.PopNearest();
var farthestResult = W.Farthest;
if (toExpand.Item2 > farthestResult.Item2)
{
// the closest candidate is farther than farthest result
break;
}
// expand candidate
var neighboursIds = GetNeighbors(toExpand.Item1).ToArray();
for (int i = 0; i < neighboursIds.Length; ++i)
{
int neighbourId = neighboursIds[i];
if (!v.Contains(neighbourId))
{
// enqueue perspective neighbours to expansion list
farthestResult = W.Farthest;
var neighbourDistance = targetCosts(neighbourId);
if (W.Count < ef || neighbourDistance < farthestResult.Item2)
{
C.Push(neighbourId, neighbourDistance);
W.Push(neighbourId, neighbourDistance);
if (W.Count > ef)
{
W.PopFarthest();
}
}
v.Add(neighbourId);
}
}
}
C.Clear();
v.Clear();
}
/// <summary>
/// Algorithm 2
/// </summary>
/// <param name="q">query element</param>
/// <param name="ep">enter points ep</param>
/// <returns>Output: ef closest neighbors to q</returns>
internal void KNearestAtLayer(int entryPointId, Func<int, float> targetCosts, BinaryHeap W, int ef, SearchContext context)
{
/*
* v ep // set of visited elements
* C ep // set of candidates
* W ep // dynamic list of found nearest neighbors
* while C > 0
* c extract nearest element from C to q
* f get furthest element from W to q
* if distance(c, q) > distance(f, q)
* break // all elements in W are evaluated
* for each e neighbourhood(c) at layer lc // update C and W
* if e v
* v v e
* f get furthest element from W to q
* if distance(e, q) < distance(f, q) or W < ef
* C C e
* W W e
* if W > ef
* remove furthest element from W to q
* return W
*/
var v = new VisitedBitSet(_vectors.Count, _options.M);
// v ← ep // set of visited elements
v.Add(entryPointId);
// C ← ep // set of candidates
var C = new BinaryHeap();
var d = targetCosts(entryPointId);
C.Push(entryPointId, d);
// W ← ep // dynamic list of found nearest neighbors
if (context.IsActiveNode(entryPointId))
{
W.Push(entryPointId, d);
}
// run bfs
while (C.Count > 0)
{
// get next candidate to check and expand
var toExpand = C.PopNearest();
if (W.Count > 0)
{
if (toExpand.Item2 > W.Farthest.Item2)
{
// the closest candidate is farther than farthest result
break;
}
}
// expand candidate
var neighboursIds = GetNeighbors(toExpand.Item1).ToArray();
for (int i = 0; i < neighboursIds.Length; ++i)
{
int neighbourId = neighboursIds[i];
if (!v.Contains(neighbourId))
{
// enqueue perspective neighbours to expansion list
var neighbourDistance = targetCosts(neighbourId);
if (context.IsActiveNode(neighbourId))
{
if (W.Count < ef || (W.Count > 0 && neighbourDistance < W.Farthest.Item2))
{
W.Push(neighbourId, neighbourDistance);
if (W.Count > ef)
{
W.PopFarthest();
}
}
}
if (W.Count < ef)
{
C.Push(neighbourId, neighbourDistance);
}
v.Add(neighbourId);
}
}
}
C.Clear();
v.Clear();
}
/// <summary>
/// Algorithm 2, modified for LookAlike
/// </summary>
/// <param name="q">query element</param>
/// <param name="ep">enter points ep</param>
/// <returns>Output: ef closest neighbors to q</returns>
internal void KNearestAtLayer(BinaryHeap W, int ef, SearchContext context)
{
/*
* v ep // set of visited elements
* C ep // set of candidates
* W ep // dynamic list of found nearest neighbors
* while C > 0
* c extract nearest element from C to q
* f get furthest element from W to q
* if distance(c, q) > distance(f, q)
* break // all elements in W are evaluated
* for each e neighbourhood(c) at layer lc // update C and W
* if e v
* v v e
* f get furthest element from W to q
* if distance(e, q) < distance(f, q) or W < ef
* C C e
* W W e
* if W > ef
* remove furthest element from W to q
* return W
*/
// v ← ep // set of visited elements
var v = new VisitedBitSet(_vectors.Count, _options.M);
// C ← ep // set of candidates
var C = new BinaryHeap();
foreach (var ep in context.EntryPoints)
{
var neighboursIds = GetNeighbors(ep).ToArray();
for (int i = 0; i < neighboursIds.Length; ++i)
{
C.Push(ep, _links.Distance(ep, neighboursIds[i]));
}
v.Add(ep);
}
// W ← ep // dynamic list of found nearest neighbors
// run bfs
while (C.Count > 0)
{
// get next candidate to check and expand
var toExpand = C.PopNearest();
if (W.Count > 0)
{
if (toExpand.Item2 > W.Farthest.Item2)
{
// the closest candidate is farther than farthest result
break;
}
}
if (context.IsActiveNode(toExpand.Item1))
{
if (W.Count < ef || W.Count == 0 || (W.Count > 0 && toExpand.Item2 < W.Farthest.Item2))
{
W.Push(toExpand.Item1, toExpand.Item2);
if (W.Count > ef)
{
W.PopFarthest();
}
}
}
}
if (W.Count > ef)
{
while (W.Count > ef)
{
W.PopFarthest();
}
return;
}
else
{
foreach (var c in W)
{
C.Push(c.Item1, c.Item2);
}
}
while (C.Count > 0)
{
// get next candidate to check and expand
var toExpand = C.PopNearest();
// expand candidate
var neighboursIds = GetNeighbors(toExpand.Item1).ToArray();
for (int i = 0; i < neighboursIds.Length; ++i)
{
int neighbourId = neighboursIds[i];
if (!v.Contains(neighbourId))
{
// enqueue perspective neighbours to expansion list
var neighbourDistance = _links.Distance(toExpand.Item1, neighbourId);
if (context.IsActiveNode(neighbourId))
{
if (W.Count < ef || (W.Count > 0 && neighbourDistance < W.Farthest.Item2))
{
W.Push(neighbourId, neighbourDistance);
if (W.Count > ef)
{
W.PopFarthest();
}
}
}
if (W.Count < ef)
{
C.Push(neighbourId, neighbourDistance);
}
v.Add(neighbourId);
}
}
}
C.Clear();
v.Clear();
}
/// <summary>
/// Algorithm 3
/// </summary>
internal BinaryHeap SELECT_NEIGHBORS_SIMPLE(BinaryHeap W, int M)
{
var bestN = M;
if (W.Count > bestN)
{
while (W.Count > bestN)
{
W.PopFarthest();
}
}
return W;
}
/// <summary>
/// Algorithm 4
/// </summary>
/// <param name="q">base element</param>
/// <param name="C">candidate elements</param>
/// <param name="extendCandidates">flag indicating whether or not to extend candidate list</param>
/// <param name="keepPrunedConnections">flag indicating whether or not to add discarded elements</param>
/// <returns>Output: M elements selected by the heuristic</returns>
internal BinaryHeap SELECT_NEIGHBORS_HEURISTIC(Func<int, float> distance, BinaryHeap W, int M)
{
// R ← ∅
var R = new BinaryHeap();
// W ← C // working queue for the candidates
// if extendCandidates // extend candidates by their neighbors
if (_options.ExpandBestSelection)
{
var extendBuffer = new HashSet<int>();
// for each e ∈ C
foreach (var e in W)
{
var neighbors = GetNeighbors(e.Item1);
// for each e_adj ∈ neighbourhood(e) at layer lc
foreach (var e_adj in neighbors)
{
// if eadj ∉ W
if (extendBuffer.Contains(e_adj) == false)
{
extendBuffer.Add(e_adj);
}
}
}
// W ← W eadj
foreach (var id in extendBuffer)
{
W.Push(id, distance(id));
}
}
// Wd ← ∅ // queue for the discarded candidates
var Wd = new BinaryHeap();
// while │W│ > 0 and │R│< M
while (W.Count > 0 && R.Count < M)
{
// e ← extract nearest element from W to q
var (e, ed) = W.PopNearest();
var (fe, fd) = R.PopFarthest();
// if e is closer to q compared to any element from R
if (R.Count == 0 ||
ed < fd)
{
// R ← R e
R.Push(e, ed);
}
else
{
// Wd ← Wd e
Wd.Push(e, ed);
}
}
// if keepPrunedConnections // add some of the discarded // connections from Wd
if (_options.KeepPrunedConnections)
{
// while │Wd│> 0 and │R│< M
while (Wd.Count > 0 && R.Count < M)
{
// R ← R extract nearest element from Wd to q
var nearest = Wd.PopNearest();
R.Push(nearest.Item1, nearest.Item2);
}
}
// return R
return R;
}
#endregion
private IEnumerable<int> GetNeighbors(int id) => _links.FindLinksForId(id).Select(d => d.Item2);
public void Serialize(IBinaryWriter writer)
{
_links.Serialize(writer);
}
public void Deserialize(IBinaryReader reader)
{
_links.Deserialize(reader);
}
internal Histogram GetHistogram(HistogramMode mode) => _links.CalculateHistogram(mode);
}
}

@ -0,0 +1,347 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Threading;
using ZeroLevel.Services.Serialization;
namespace ZeroLevel.HNSW.Services.OPT
{
public class OptWorld<TItem>
{
private readonly NSWOptions<TItem> _options;
private VectorSet<TItem> _vectors;
private OptLayer<TItem>[] _layers;
private int EntryPoint = 0;
private int MaxLayer = 0;
private readonly ProbabilityLayerNumberGenerator _layerLevelGenerator;
private ReaderWriterLockSlim _lockGraph = new ReaderWriterLockSlim();
internal SortedList<long, float> GetNSWLinks() => _layers[0].Links;
public OptWorld(NSWOptions<TItem> options)
{
_options = options;
_vectors = new VectorSet<TItem>();
_layers = new OptLayer<TItem>[_options.LayersCount];
_layerLevelGenerator = new ProbabilityLayerNumberGenerator(_options.LayersCount, _options.M);
for (int i = 0; i < _options.LayersCount; i++)
{
_layers[i] = new OptLayer<TItem>(_options, _vectors);
}
}
internal OptWorld(NSWOptions<TItem> options, Stream stream)
{
_options = options;
Deserialize(stream);
}
/// <summary>
/// Search in the graph K for vectors closest to a given vector
/// </summary>
/// <param name="vector">Given vector</param>
/// <param name="k">Count of elements for search</param>
/// <param name="activeNodes"></param>
/// <returns></returns>
public IEnumerable<(int, TItem, float)> Search(TItem vector, int k)
{
foreach (var pair in KNearest(vector, k))
{
yield return (pair.Item1, _vectors[pair.Item1], pair.Item2);
}
}
public IEnumerable<(int, TItem, float)> Search(TItem vector, int k, SearchContext context)
{
if (context == null)
{
foreach (var pair in KNearest(vector, k))
{
yield return (pair.Item1, _vectors[pair.Item1], pair.Item2);
}
}
else
{
foreach (var pair in KNearest(vector, k, context))
{
yield return (pair.Item1, _vectors[pair.Item1], pair.Item2);
}
}
}
public IEnumerable<(int, TItem, float)> Search(int k, SearchContext context)
{
if (context == null)
{
throw new ArgumentNullException(nameof(context));
}
else
{
foreach (var pair in KNearest(k, context))
{
yield return (pair.Item1, _vectors[pair.Item1], pair.Item2);
}
}
}
/// <summary>
/// Adding vectors batch
/// </summary>
/// <param name="vectors">Vectors</param>
/// <returns>Vector identifiers in a graph</returns>
public int[] AddItems(IEnumerable<TItem> vectors)
{
_lockGraph.EnterWriteLock();
try
{
var ids = _vectors.Append(vectors);
for (int i = 0; i < ids.Length; i++)
{
INSERT(ids[i]);
}
return ids;
}
finally
{
_lockGraph.ExitWriteLock();
}
}
#region https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf
/// <summary>
/// Algorithm 1
/// </summary>
private void INSERT(int q)
{
var distance = new Func<int, float>(candidate => _options.Distance(_vectors[q], _vectors[candidate]));
// W ← ∅ // list for the currently found nearest elements
var W = new BinaryHeap();
// ep ← get enter point for hnsw
//var ep = _layers[MaxLayer].FingEntryPointAtLayer(distance);
//if(ep == -1) ep = EntryPoint;
var ep = EntryPoint;
var epDist = distance(ep);
// L ← level of ep // top layer for hnsw
var L = MaxLayer;
// l ← ⌊-ln(unif(0..1))∙mL⌋ // new elements level
int l = _layerLevelGenerator.GetRandomLayer();
// for lc ← L … l+1
// Проход с верхнего уровня до уровня где появляется элемент, для нахождения точки входа
for (int lc = L; lc > l; --lc)
{
// W ← SEARCH-LAYER(q, ep, ef = 1, lc)
_layers[lc].KNearestAtLayer(ep, distance, W, 1);
// ep ← get the nearest element from W to q
var nearest = W.Nearest;
ep = nearest.Item1;
epDist = nearest.Item2;
W.Clear();
}
//for lc ← min(L, l) … 0
// connecting new node to the small world
for (int lc = Math.Min(L, l); lc >= 0; --lc)
{
if (_layers[lc].HasLinks == false)
{
_layers[lc].Append(q);
}
else
{
// W ← SEARCH - LAYER(q, ep, efConstruction, lc)
_layers[lc].KNearestAtLayer(ep, distance, W, _options.EFConstruction);
// ep ← W
var nearest = W.Nearest;
ep = nearest.Item1;
epDist = nearest.Item2;
// neighbors ← SELECT-NEIGHBORS(q, W, M, lc) // alg. 3 or alg. 4
var neighbors = SelectBestForConnecting(lc, distance, W);
// add bidirectionall connectionts from neighbors to q at layer lc
// for each e ∈ neighbors // shrink connections if needed
foreach (var e in neighbors)
{
// eConn ← neighbourhood(e) at layer lc
_layers[lc].AddBidirectionallConnections(q, e.Item1, e.Item2, lc == 0);
// if distance from newNode to newNeighbour is better than to bestPeer => update bestPeer
if (e.Item2 < epDist)
{
ep = e.Item1;
epDist = e.Item2;
}
}
W.Clear();
}
}
// if l > L
if (l > L)
{
// set enter point for hnsw to q
L = l;
MaxLayer = l;
EntryPoint = ep;
}
}
/// <summary>
/// Get maximum allowed connections for the given level.
/// </summary>
/// <remarks>
/// Article: Section 4.1:
/// "Selection of the Mmax0 (the maximum number of connections that an element can have in the zero layer) also
/// has a strong influence on the search performance, especially in case of high quality(high recall) search.
/// Simulations show that setting Mmax0 to M(this corresponds to kNN graphs on each layer if the neighbors
/// selection heuristic is not used) leads to a very strong performance penalty at high recall.
/// Simulations also suggest that 2∙M is a good choice for Mmax0;
/// setting the parameter higher leads to performance degradation and excessive memory usage."
/// </remarks>
/// <param name="layer">The level of the layer.</param>
/// <returns>The maximum number of connections.</returns>
private int GetM(int layer)
{
return layer == 0 ? 2 * _options.M : _options.M;
}
private BinaryHeap SelectBestForConnecting(int layer, Func<int, float> distance, BinaryHeap candidates)
{
if (_options.SelectionHeuristic == NeighbourSelectionHeuristic.SelectSimple)
return _layers[layer].SELECT_NEIGHBORS_SIMPLE(candidates, GetM(layer));
return _layers[layer].SELECT_NEIGHBORS_HEURISTIC(distance, candidates, GetM(layer));
}
/// <summary>
/// Algorithm 5
/// </summary>
private BinaryHeap KNearest(TItem q, int k)
{
_lockGraph.EnterReadLock();
try
{
if (_vectors.Count == 0)
{
return BinaryHeap.Empty;
}
var distance = new Func<int, float>(candidate => _options.Distance(q, _vectors[candidate]));
// W ← ∅ // set for the current nearest elements
var W = new BinaryHeap(k + 1);
// ep ← get enter point for hnsw
var ep = EntryPoint;
// L ← level of ep // top layer for hnsw
var L = MaxLayer;
// for lc ← L … 1
for (int layer = L; layer > 0; --layer)
{
// W ← SEARCH-LAYER(q, ep, ef = 1, lc)
_layers[layer].KNearestAtLayer(ep, distance, W, 1);
// ep ← get nearest element from W to q
ep = W.Nearest.Item1;
W.Clear();
}
// W ← SEARCH-LAYER(q, ep, ef, lc =0)
_layers[0].KNearestAtLayer(ep, distance, W, k);
// return K nearest elements from W to q
return W;
}
finally
{
_lockGraph.ExitReadLock();
}
}
private BinaryHeap KNearest(TItem q, int k, SearchContext context)
{
_lockGraph.EnterReadLock();
try
{
if (_vectors.Count == 0)
{
return BinaryHeap.Empty;
}
var distance = new Func<int, float>(candidate => _options.Distance(q, _vectors[candidate]));
// W ← ∅ // set for the current nearest elements
var W = new BinaryHeap(k + 1);
// ep ← get enter point for hnsw
var ep = EntryPoint;
// L ← level of ep // top layer for hnsw
var L = MaxLayer;
// for lc ← L … 1
for (int layer = L; layer > 0; --layer)
{
// W ← SEARCH-LAYER(q, ep, ef = 1, lc)
_layers[layer].KNearestAtLayer(ep, distance, W, 1);
// ep ← get nearest element from W to q
ep = W.Nearest.Item1;
W.Clear();
}
// W ← SEARCH-LAYER(q, ep, ef, lc =0)
_layers[0].KNearestAtLayer(ep, distance, W, k, context);
// return K nearest elements from W to q
return W;
}
finally
{
_lockGraph.ExitReadLock();
}
}
private BinaryHeap KNearest(int k, SearchContext context)
{
_lockGraph.EnterReadLock();
try
{
if (_vectors.Count == 0)
{
return BinaryHeap.Empty;
}
var distance = new Func<int, int, float>((id1, id2) => _options.Distance(_vectors[id1], _vectors[id2]));
// W ← ∅ // set for the current nearest elements
var W = new BinaryHeap(k + 1);
// W ← SEARCH-LAYER(q, ep, ef, lc =0)
_layers[0].KNearestAtLayer(W, k, context);
// return K nearest elements from W to q
return W;
}
finally
{
_lockGraph.ExitReadLock();
}
}
#endregion
public void Serialize(Stream stream)
{
using (var writer = new MemoryStreamWriter(stream))
{
writer.WriteInt32(EntryPoint);
writer.WriteInt32(MaxLayer);
_vectors.Serialize(writer);
writer.WriteInt32(_layers.Length);
foreach (var l in _layers)
{
l.Serialize(writer);
}
}
}
public void Deserialize(Stream stream)
{
using (var reader = new MemoryStreamReader(stream))
{
this.EntryPoint = reader.ReadInt32();
this.MaxLayer = reader.ReadInt32();
_vectors = new VectorSet<TItem>();
_vectors.Deserialize(reader);
var countLayers = reader.ReadInt32();
_layers = new OptLayer<TItem>[countLayers];
for (int i = 0; i < countLayers; i++)
{
_layers[i] = new OptLayer<TItem>(_options, _vectors);
_layers[i].Deserialize(reader);
}
}
}
public Histogram GetHistogram(HistogramMode mode = HistogramMode.SQRT)
=> _layers[0].GetHistogram(mode);
}
}

@ -11,7 +11,7 @@ namespace ZeroLevel.HNSW
public class SmallWorld<TItem>
{
private readonly NSWOptions<TItem> _options;
private readonly VectorSet<TItem> _vectors;
private VectorSet<TItem> _vectors;
private Layer<TItem>[] _layers;
private int EntryPoint = 0;
private int MaxLayer = 0;
@ -118,6 +118,8 @@ namespace ZeroLevel.HNSW
// W ← ∅ // list for the currently found nearest elements
IDictionary<int, float> W = new Dictionary<int, float>();
// ep ← get enter point for hnsw
//var ep = _layers[MaxLayer].FingEntryPointAtLayer(distance);
//if(ep == -1) ep = EntryPoint;
var ep = EntryPoint;
var epDist = distance(ep);
// L ← level of ep // top layer for hnsw
@ -334,6 +336,7 @@ namespace ZeroLevel.HNSW
{
this.EntryPoint = reader.ReadInt32();
this.MaxLayer = reader.ReadInt32();
_vectors = new VectorSet<TItem>();
_vectors.Deserialize(reader);
var countLayers = reader.ReadInt32();
_layers = new Layer<TItem>[countLayers];

@ -1,7 +1,8 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net5.0</TargetFramework>
<Platforms>AnyCPU;x64</Platforms>
</PropertyGroup>
<ItemGroup>

Loading…
Cancel
Save

Powered by TurnKey Linux.