Fix INSERT algorithm
pull/1/head
unknown 3 years ago
parent 28ccefa255
commit adf09b08c8

@ -0,0 +1,7 @@
namespace HNSWDemo.Model
{
public enum Gender
{
Unknown, Male, Feemale
}
}

@ -0,0 +1,51 @@
using System;
using System.Collections.Generic;
using ZeroLevel.HNSW;
namespace HNSWDemo.Model
{
public class Person
{
public Gender Gender { get; set; }
public int Age { get; set; }
public long Number { get; set; }
private static (float[], Person) Generate(int vector_size)
{
var rnd = new Random((int)Environment.TickCount);
var vector = new float[vector_size];
DefaultRandomGenerator.Instance.NextFloats(vector);
VectorUtils.NormalizeSIMD(vector);
var p = new Person();
p.Age = rnd.Next(15, 80);
var gr = rnd.Next(0, 3);
p.Gender = (gr == 0) ? Gender.Male : (gr == 1) ? Gender.Feemale : Gender.Unknown;
p.Number = CreateNumber(rnd);
return (vector, p);
}
public static List<(float[], Person)> GenerateRandom(int vectorSize, int vectorsCount)
{
var vectors = new List<(float[], Person)>();
for (int i = 0; i < vectorsCount; i++)
{
vectors.Add(Generate(vectorSize));
}
return vectors;
}
static HashSet<long> _exists = new HashSet<long>();
private static long CreateNumber(Random rnd)
{
long start_number;
do
{
start_number = 79600000000L;
start_number = start_number + rnd.Next(4, 8) * 10000000;
start_number += rnd.Next(0, 1000000);
}
while (_exists.Add(start_number) == false);
return start_number;
}
}
}

@ -1,596 +1,17 @@
using System; using HNSWDemo.Tests;
using System.Collections.Generic; using System;
using System.Diagnostics;
using System.Drawing;
using System.IO;
using System.Linq;
using ZeroLevel.HNSW;
using ZeroLevel.HNSW.Services;
namespace HNSWDemo namespace HNSWDemo
{ {
class Program class Program
{ {
public class VectorsDirectCompare
{
private const int HALF_LONG_BITS = 32;
private readonly IList<float[]> _vectors;
private readonly Func<float[], float[], float> _distance;
public VectorsDirectCompare(List<float[]> vectors, Func<float[], float[], float> distance)
{
_vectors = vectors;
_distance = distance;
}
public IEnumerable<(int, float)> KNearest(float[] v, int k)
{
var weights = new Dictionary<int, float>();
for (int i = 0; i < _vectors.Count; i++)
{
var d = _distance(v, _vectors[i]);
weights[i] = d;
}
return weights.OrderBy(p => p.Value).Take(k).Select(p => (p.Key, p.Value));
}
public List<HashSet<int>> DetectClusters()
{
var links = new SortedList<long, float>();
for (int i = 0; i < _vectors.Count; i++)
{
for (int j = i + 1; j < _vectors.Count; j++)
{
long k = (((long)(i)) << HALF_LONG_BITS) + j;
links.Add(k, _distance(_vectors[i], _vectors[j]));
}
}
// 1. Find R - bound between intra-cluster distances and out-of-cluster distances
var histogram = new Histogram(HistogramMode.SQRT, links.Values);
int threshold = histogram.OTSU();
var min = histogram.Bounds[threshold - 1];
var max = histogram.Bounds[threshold];
var R = (max + min) / 2;
// 2. Get links with distances less than R
var resultLinks = new SortedList<long, float>();
foreach (var pair in links)
{
if (pair.Value < R)
{
resultLinks.Add(pair.Key, pair.Value);
}
}
// 3. Extract clusters
List<HashSet<int>> clusters = new List<HashSet<int>>();
foreach (var pair in resultLinks)
{
var k = pair.Key;
var id1 = (int)(k >> HALF_LONG_BITS);
var id2 = (int)(k - (((long)id1) << HALF_LONG_BITS));
bool found = false;
foreach (var c in clusters)
{
if (c.Contains(id1))
{
c.Add(id2);
found = true;
break;
}
else if (c.Contains(id2))
{
c.Add(id1);
found = true;
break;
}
}
if (found == false)
{
var c = new HashSet<int>();
c.Add(id1);
c.Add(id2);
clusters.Add(c);
}
}
return clusters;
}
}
public class QVectorsDirectCompare
{
private const int HALF_LONG_BITS = 32;
private readonly IList<byte[]> _vectors;
private readonly Func<byte[], byte[], float> _distance;
public QVectorsDirectCompare(List<byte[]> vectors, Func<byte[], byte[], float> distance)
{
_vectors = vectors;
_distance = distance;
}
public IEnumerable<(int, float)> KNearest(byte[] v, int k)
{
var weights = new Dictionary<int, float>();
for (int i = 0; i < _vectors.Count; i++)
{
var d = _distance(v, _vectors[i]);
weights[i] = d;
}
return weights.OrderBy(p => p.Value).Take(k).Select(p => (p.Key, p.Value));
}
public List<HashSet<int>> DetectClusters()
{
var links = new SortedList<long, float>();
for (int i = 0; i < _vectors.Count; i++)
{
for (int j = i + 1; j < _vectors.Count; j++)
{
long k = (((long)(i)) << HALF_LONG_BITS) + j;
links.Add(k, _distance(_vectors[i], _vectors[j]));
}
}
// 1. Find R - bound between intra-cluster distances and out-of-cluster distances
var histogram = new Histogram(HistogramMode.SQRT, links.Values);
int threshold = histogram.OTSU();
var min = histogram.Bounds[threshold - 1];
var max = histogram.Bounds[threshold];
var R = (max + min) / 2;
// 2. Get links with distances less than R
var resultLinks = new SortedList<long, float>();
foreach (var pair in links)
{
if (pair.Value < R)
{
resultLinks.Add(pair.Key, pair.Value);
}
}
// 3. Extract clusters
List<HashSet<int>> clusters = new List<HashSet<int>>();
foreach (var pair in resultLinks)
{
var k = pair.Key;
var id1 = (int)(k >> HALF_LONG_BITS);
var id2 = (int)(k - (((long)id1) << HALF_LONG_BITS));
bool found = false;
foreach (var c in clusters)
{
if (c.Contains(id1))
{
c.Add(id2);
found = true;
break;
}
else if (c.Contains(id2))
{
c.Add(id1);
found = true;
break;
}
}
if (found == false)
{
var c = new HashSet<int>();
c.Add(id1);
c.Add(id2);
clusters.Add(c);
}
}
return clusters;
}
}
public class QLVectorsDirectCompare
{
private const int HALF_LONG_BITS = 32;
private readonly IList<long[]> _vectors;
private readonly Func<long[], long[], float> _distance;
public QLVectorsDirectCompare(List<long[]> vectors, Func<long[], long[], float> distance)
{
_vectors = vectors;
_distance = distance;
}
public IEnumerable<(int, float)> KNearest(long[] v, int k)
{
var weights = new Dictionary<int, float>();
for (int i = 0; i < _vectors.Count; i++)
{
var d = _distance(v, _vectors[i]);
weights[i] = d;
}
return weights.OrderBy(p => p.Value).Take(k).Select(p => (p.Key, p.Value));
}
public List<HashSet<int>> DetectClusters()
{
var links = new SortedList<long, float>();
for (int i = 0; i < _vectors.Count; i++)
{
for (int j = i + 1; j < _vectors.Count; j++)
{
long k = (((long)(i)) << HALF_LONG_BITS) + j;
links.Add(k, _distance(_vectors[i], _vectors[j]));
}
}
// 1. Find R - bound between intra-cluster distances and out-of-cluster distances
var histogram = new Histogram(HistogramMode.SQRT, links.Values);
int threshold = histogram.OTSU();
var min = histogram.Bounds[threshold - 1];
var max = histogram.Bounds[threshold];
var R = (max + min) / 2;
// 2. Get links with distances less than R
var resultLinks = new SortedList<long, float>();
foreach (var pair in links)
{
if (pair.Value < R)
{
resultLinks.Add(pair.Key, pair.Value);
}
}
// 3. Extract clusters
List<HashSet<int>> clusters = new List<HashSet<int>>();
foreach (var pair in resultLinks)
{
var k = pair.Key;
var id1 = (int)(k >> HALF_LONG_BITS);
var id2 = (int)(k - (((long)id1) << HALF_LONG_BITS));
bool found = false;
foreach (var c in clusters)
{
if (c.Contains(id1))
{
c.Add(id2);
found = true;
break;
}
else if (c.Contains(id2))
{
c.Add(id1);
found = true;
break;
}
}
if (found == false)
{
var c = new HashSet<int>();
c.Add(id1);
c.Add(id2);
clusters.Add(c);
}
}
return clusters;
}
}
public enum Gender
{
Unknown, Male, Feemale
}
public class Person
{
public Gender Gender { get; set; }
public int Age { get; set; }
public long Number { get; set; }
private static (float[], Person) Generate(int vector_size)
{
var rnd = new Random((int)Environment.TickCount);
var vector = new float[vector_size];
DefaultRandomGenerator.Instance.NextFloats(vector);
VectorUtils.NormalizeSIMD(vector);
var p = new Person();
p.Age = rnd.Next(15, 80);
var gr = rnd.Next(0, 3);
p.Gender = (gr == 0) ? Gender.Male : (gr == 1) ? Gender.Feemale : Gender.Unknown;
p.Number = CreateNumber(rnd);
return (vector, p);
}
public static List<(float[], Person)> GenerateRandom(int vectorSize, int vectorsCount)
{
var vectors = new List<(float[], Person)>();
for (int i = 0; i < vectorsCount; i++)
{
vectors.Add(Generate(vectorSize));
}
return vectors;
}
static HashSet<long> _exists = new HashSet<long>();
private static long CreateNumber(Random rnd)
{
long start_number;
do
{
start_number = 79600000000L;
start_number = start_number + rnd.Next(4, 8) * 10000000;
start_number += rnd.Next(0, 1000000);
}
while (_exists.Add(start_number) == false);
return start_number;
}
}
private static List<float[]> RandomVectors(int vectorSize, int vectorsCount)
{
var vectors = new List<float[]>();
for (int i = 0; i < vectorsCount; i++)
{
var vector = new float[vectorSize];
DefaultRandomGenerator.Instance.NextFloats(vector);
VectorUtils.NormalizeSIMD(vector);
vectors.Add(vector);
}
return vectors;
}
static void Main(string[] args) static void Main(string[] args)
{ {
QuantizatorTest(); new AutoClusteringTest().Run();
Console.WriteLine("Completed"); Console.WriteLine("Completed");
Console.ReadKey(); Console.ReadKey();
} }
static void QAccuracityTest()
{
int K = 200;
var count = 5000;
var testCount = 500;
var dimensionality = 128;
var totalHits = new List<int>();
var timewatchesNP = new List<float>();
var timewatchesHNSW = new List<float>();
var q = new Quantizator(-1f, 1f);
var samples = RandomVectors(dimensionality, count).Select(v => q.QuantizeToLong(v)).ToList();
var sw = new Stopwatch();
var test = new QLVectorsDirectCompare(samples, CosineDistance.NonOptimized);
var world = new SmallWorld<long[]>(NSWOptions<long[]>.Create(8, 12, 100, 100, CosineDistance.NonOptimized));
sw.Start();
var ids = world.AddItems(samples.ToArray());
sw.Stop();
Console.WriteLine($"Insert {ids.Length} items: {sw.ElapsedMilliseconds} ms");
Console.WriteLine("Start test");
var test_vectors = RandomVectors(dimensionality, testCount).Select(v => q.QuantizeToLong(v)).ToList();
foreach (var v in test_vectors)
{
sw.Restart();
var gt = test.KNearest(v, K).ToDictionary(p => p.Item1, p => p.Item2);
sw.Stop();
timewatchesNP.Add(sw.ElapsedMilliseconds);
sw.Restart();
var result = world.Search(v, K);
sw.Stop();
timewatchesHNSW.Add(sw.ElapsedMilliseconds);
var hits = 0;
foreach (var r in result)
{
if (gt.ContainsKey(r.Item1))
{
hits++;
}
}
totalHits.Add(hits);
}
Console.WriteLine($"MIN Accuracity: {totalHits.Min() * 100 / K}%");
Console.WriteLine($"AVG Accuracity: {totalHits.Average() * 100 / K}%");
Console.WriteLine($"MAX Accuracity: {totalHits.Max() * 100 / K}%");
Console.WriteLine($"MIN HNSW TIME: {timewatchesHNSW.Min()} ms");
Console.WriteLine($"AVG HNSW TIME: {timewatchesHNSW.Average()} ms");
Console.WriteLine($"MAX HNSW TIME: {timewatchesHNSW.Max()} ms");
Console.WriteLine($"MIN NP TIME: {timewatchesNP.Min()} ms");
Console.WriteLine($"AVG NP TIME: {timewatchesNP.Average()} ms");
Console.WriteLine($"MAX NP TIME: {timewatchesNP.Max()} ms");
}
static void QInsertTimeExplosionTest()
{
var count = 10000;
var iterationCount = 100;
var dimensionality = 128;
var sw = new Stopwatch();
var world = new SmallWorld<long[]>(NSWOptions<long[]>.Create(6, 12, 100, 100, CosineDistance.NonOptimized));
var q = new Quantizator(-1f, 1f);
for (int i = 0; i < iterationCount; i++)
{
var samples = RandomVectors(dimensionality, count);
sw.Restart();
var ids = world.AddItems(samples.Select(v => q.QuantizeToLong(v)).ToArray());
sw.Stop();
Console.WriteLine($"ITERATION: [{i.ToString("D4")}] COUNT: [{ids.Length}] ELAPSED [{sw.ElapsedMilliseconds} ms]");
}
}
static void AccuracityTest()
{
int K = 200;
var count = 3000;
var testCount = 500;
var dimensionality = 128;
var totalHits = new List<int>();
var timewatchesNP = new List<float>();
var timewatchesHNSW = new List<float>();
var samples = RandomVectors(dimensionality, count);
var sw = new Stopwatch();
var test = new VectorsDirectCompare(samples, CosineDistance.NonOptimized);
var world = new SmallWorld<float[]>(NSWOptions<float[]>.Create(8, 12, 100, 100, CosineDistance.NonOptimized));
sw.Start();
var ids = world.AddItems(samples.ToArray());
sw.Stop();
/*
byte[] dump;
using (var ms = new MemoryStream())
{
world.Serialize(ms);
dump = ms.ToArray();
}
Console.WriteLine($"Full dump size: {dump.Length} bytes");
ReadOnlySmallWorld<float[]> world;
using (var ms = new MemoryStream(dump))
{
world = SmallWorld.CreateReadOnlyWorldFrom<float[]>(NSWReadOnlyOption<float[]>.Create(100, CosineDistance.NonOptimized, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple), ms);
}
*/
Console.WriteLine($"Insert {ids.Length} items: {sw.ElapsedMilliseconds} ms");
Console.WriteLine("Start test");
var test_vectors = RandomVectors(dimensionality, testCount);
foreach (var v in test_vectors)
{
sw.Restart();
var gt = test.KNearest(v, K).ToDictionary(p => p.Item1, p => p.Item2);
sw.Stop();
timewatchesNP.Add(sw.ElapsedMilliseconds);
sw.Restart();
var result = world.Search(v, K);
sw.Stop();
timewatchesHNSW.Add(sw.ElapsedMilliseconds);
var hits = 0;
foreach (var r in result)
{
if (gt.ContainsKey(r.Item1))
{
hits++;
}
}
totalHits.Add(hits);
}
Console.WriteLine($"MIN Accuracity: {totalHits.Min() * 100 / K}%");
Console.WriteLine($"AVG Accuracity: {totalHits.Average() * 100 / K}%");
Console.WriteLine($"MAX Accuracity: {totalHits.Max() * 100 / K}%");
Console.WriteLine($"MIN HNSW TIME: {timewatchesHNSW.Min()} ms");
Console.WriteLine($"AVG HNSW TIME: {timewatchesHNSW.Average()} ms");
Console.WriteLine($"MAX HNSW TIME: {timewatchesHNSW.Max()} ms");
Console.WriteLine($"MIN NP TIME: {timewatchesNP.Min()} ms");
Console.WriteLine($"AVG NP TIME: {timewatchesNP.Average()} ms");
Console.WriteLine($"MAX NP TIME: {timewatchesNP.Max()} ms");
}
static void QuantizatorTest()
{
var samples = RandomVectors(128, 500000);
var min = samples.SelectMany(s => s).Min();
var max = samples.SelectMany(s => s).Max();
var q = new Quantizator(min, max);
var q_samples = samples.Select(s => q.QuantizeToLong(s)).ToArray();
// comparing
var list = new List<float>();
for (int i = 0; i < samples.Count - 1; i++)
{
var v1 = samples[i];
var v2 = samples[i + 1];
var dist = CosineDistance.NonOptimized(v1, v2);
var qv1 = q_samples[i];
var qv2 = q_samples[i + 1];
var qdist = CosineDistance.NonOptimized(qv1, qv2);
list.Add(Math.Abs(dist - qdist));
}
Console.WriteLine($"Min diff: {list.Min()}");
Console.WriteLine($"Avg diff: {list.Average()}");
Console.WriteLine($"Max diff: {list.Max()}");
}
static void SaveRestoreTest()
{
var count = 1000;
var dimensionality = 128;
var samples = RandomVectors(dimensionality, count);
var world = new SmallWorld<float[]>(NSWOptions<float[]>.Create(6, 15, 200, 200, CosineDistance.ForUnits));
var sw = new Stopwatch();
sw.Start();
var ids = world.AddItems(samples.ToArray());
sw.Stop();
Console.WriteLine($"Insert {ids.Length} items on {sw.ElapsedMilliseconds} ms");
Console.WriteLine("Start test");
byte[] dump;
using (var ms = new MemoryStream())
{
world.Serialize(ms);
dump = ms.ToArray();
}
Console.WriteLine($"Full dump size: {dump.Length} bytes");
byte[] testDump;
var restoredWorld = new SmallWorld<float[]>(NSWOptions<float[]>.Create(6, 15, 200, 200, CosineDistance.ForUnits));
using (var ms = new MemoryStream(dump))
{
restoredWorld.Deserialize(ms);
}
using (var ms = new MemoryStream())
{
restoredWorld.Serialize(ms);
testDump = ms.ToArray();
}
if (testDump.Length != dump.Length)
{
Console.WriteLine($"Incorrect restored size. Got {testDump.Length}. Expected: {dump.Length}");
return;
}
}
static void InsertTimeExplosionTest()
{
var count = 10000;
var iterationCount = 100;
var dimensionality = 128;
var sw = new Stopwatch();
var world = new SmallWorld<float[]>(NSWOptions<float[]>.Create(6, 12, 100, 100, CosineDistance.NonOptimized));
for (int i = 0; i < iterationCount; i++)
{
var samples = RandomVectors(dimensionality, count);
sw.Restart();
var ids = world.AddItems(samples.ToArray());
sw.Stop();
Console.WriteLine($"ITERATION: [{i.ToString("D4")}] COUNT: [{ids.Length}] ELAPSED [{sw.ElapsedMilliseconds} ms]");
}
}
/* /*
static void TestOnMnist() static void TestOnMnist()
{ {
@ -646,272 +67,8 @@ namespace HNSWDemo
{ {
Console.WriteLine($"Cluster {i + 1} countains {clusters[i].Count} items"); Console.WriteLine($"Cluster {i + 1} countains {clusters[i].Count} items");
} }
}
static void AutoClusteringTest()
{
var vectors = RandomVectors(128, 3000);
var world = SmallWorld.CreateWorld<float[]>(NSWOptions<float[]>.Create(8, 16, 200, 200, Metrics.L2Euclidean, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple));
world.AddItems(vectors);
var clusters = AutomaticGraphClusterer.DetectClusters(world);
Console.WriteLine($"Found {clusters.Count} clusters");
for (int i = 0; i < clusters.Count; i++)
{
Console.WriteLine($"Cluster {i + 1} countains {clusters[i].Count} items");
}
}
static void HistogramTest()
{
var vectors = RandomVectors(128, 3000);
var world = SmallWorld.CreateWorld<float[]>(NSWOptions<float[]>.Create(8, 16, 200, 200, Metrics.L2Euclidean, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple));
world.AddItems(vectors);
var histogram = world.GetHistogram();
int threshold = histogram.OTSU();
var min = histogram.Bounds[threshold - 1];
var max = histogram.Bounds[threshold];
var R = (max + min) / 2;
DrawHistogram(histogram, @"D:\hist.jpg");
}
static void DrawHistogram(Histogram histogram, string filename)
{
var wb = 1200 / histogram.Values.Length;
var k = 600.0f / (float)histogram.Values.Max();
var maxes = histogram.GetMaximums().ToDictionary(m => m.Index, m => m);
int threshold = histogram.OTSU();
using (var bmp = new Bitmap(1200, 600))
{
using (var g = Graphics.FromImage(bmp))
{
for (int i = 0; i<histogram.Values.Length; i++)
{
var height = (int)(histogram.Values[i] * k);
if (maxes.ContainsKey(i))
{
g.DrawRectangle(Pens.Red, i* wb, bmp.Height - height, wb, height);
g.DrawRectangle(Pens.Red, i* wb + 1, bmp.Height - height, wb - 1, height);
}
else
{
g.DrawRectangle(Pens.Blue, i * wb, bmp.Height - height, wb, height);
}
if (i == threshold)
{
g.DrawLine(Pens.Green, i * wb + wb / 2, 0, i * wb + wb / 2, bmp.Height);
}
}
}
bmp.Save(filename);
}
}
static void TransformToCompactWorldTest()
{
var count = 10000;
var dimensionality = 128;
var samples = RandomVectors(dimensionality, count);
var world = new SmallWorld<float[]>(NSWOptions<float[]>.Create(6, 15, 200, 200, CosineDistance.ForUnits));
var ids = world.AddItems(samples.ToArray());
Console.WriteLine("Start test");
byte[] dump;
using (var ms = new MemoryStream())
{
world.Serialize(ms);
dump = ms.ToArray();
}
Console.WriteLine($"Full dump size: {dump.Length} bytes");
ReadOnlySmallWorld<float[]> compactWorld;
using (var ms = new MemoryStream(dump))
{
compactWorld = SmallWorld.CreateReadOnlyWorldFrom<float[]>(NSWReadOnlyOption<float[]>.Create(200, CosineDistance.ForUnits), ms);
}
// Compare worlds outputs
int K = 200;
var hits = 0;
var miss = 0;
var testCount = 1000;
var sw = new Stopwatch();
var timewatchesHNSW = new List<float>();
var timewatchesHNSWCompact = new List<float>();
var test_vectors = RandomVectors(dimensionality, testCount);
foreach (var v in test_vectors)
{
sw.Restart();
var gt = world.Search(v, K).Select(e => e.Item1).ToHashSet();
sw.Stop();
timewatchesHNSW.Add(sw.ElapsedMilliseconds);
sw.Restart();
var result = compactWorld.Search(v, K).Select(e => e.Item1).ToHashSet();
sw.Stop();
timewatchesHNSWCompact.Add(sw.ElapsedMilliseconds);
foreach (var r in result)
{
if (gt.Contains(r))
{
hits++;
}
else
{
miss++;
}
}
}
byte[] smallWorldDump;
using (var ms = new MemoryStream())
{
compactWorld.Serialize(ms);
smallWorldDump = ms.ToArray();
}
var p = smallWorldDump.Length * 100.0f / dump.Length;
Console.WriteLine($"Compact dump size: {smallWorldDump.Length} bytes. Decrease: {100 - p}%");
Console.WriteLine($"HITS: {hits}");
Console.WriteLine($"MISSES: {miss}");
Console.WriteLine($"MIN HNSW TIME: {timewatchesHNSW.Min()} ms");
Console.WriteLine($"AVG HNSW TIME: {timewatchesHNSW.Average()} ms");
Console.WriteLine($"MAX HNSW TIME: {timewatchesHNSW.Max()} ms");
Console.WriteLine($"MIN HNSWCompact TIME: {timewatchesHNSWCompact.Min()} ms");
Console.WriteLine($"AVG HNSWCompact TIME: {timewatchesHNSWCompact.Average()} ms");
Console.WriteLine($"MAX HNSWCompact TIME: {timewatchesHNSWCompact.Max()} ms");
} }
static void TransformToCompactWorldTestWithAccuracity()
{
var count = 10000;
var dimensionality = 128;
var samples = RandomVectors(dimensionality, count);
var test = new VectorsDirectCompare(samples, CosineDistance.ForUnits);
var world = new SmallWorld<float[]>(NSWOptions<float[]>.Create(6, 15, 200, 200, CosineDistance.ForUnits));
var ids = world.AddItems(samples.ToArray());
Console.WriteLine("Start test");
byte[] dump;
using (var ms = new MemoryStream())
{
world.Serialize(ms);
dump = ms.ToArray();
}
ReadOnlySmallWorld<float[]> compactWorld;
using (var ms = new MemoryStream(dump))
{
compactWorld = SmallWorld.CreateReadOnlyWorldFrom<float[]>(NSWReadOnlyOption<float[]>.Create(200, CosineDistance.ForUnits), ms);
}
// Compare worlds outputs
int K = 200;
var hits = 0;
var miss = 0;
var testCount = 2000;
var sw = new Stopwatch();
var timewatchesNP = new List<float>();
var timewatchesHNSW = new List<float>();
var timewatchesHNSWCompact = new List<float>();
var test_vectors = RandomVectors(dimensionality, testCount);
var totalHitsHNSW = new List<int>();
var totalHitsHNSWCompact = new List<int>();
foreach (var v in test_vectors)
{
var npHitsHNSW = 0;
var npHitsHNSWCompact = 0;
sw.Restart();
var gtNP = test.KNearest(v, K).Select(p => p.Item1).ToHashSet();
sw.Stop();
timewatchesNP.Add(sw.ElapsedMilliseconds);
sw.Restart();
var gt = world.Search(v, K).Select(e => e.Item1).ToHashSet();
sw.Stop();
timewatchesHNSW.Add(sw.ElapsedMilliseconds);
sw.Restart();
var result = compactWorld.Search(v, K).Select(e => e.Item1).ToHashSet();
sw.Stop();
timewatchesHNSWCompact.Add(sw.ElapsedMilliseconds);
foreach (var r in result)
{
if (gt.Contains(r))
{
hits++;
}
else
{
miss++;
}
if (gtNP.Contains(r))
{
npHitsHNSWCompact++;
}
}
foreach (var r in gt)
{
if (gtNP.Contains(r))
{
npHitsHNSW++;
}
}
totalHitsHNSW.Add(npHitsHNSW);
totalHitsHNSWCompact.Add(npHitsHNSWCompact);
}
byte[] smallWorldDump;
using (var ms = new MemoryStream())
{
compactWorld.Serialize(ms);
smallWorldDump = ms.ToArray();
}
var p = smallWorldDump.Length * 100.0f / dump.Length;
Console.WriteLine($"Full dump size: {dump.Length} bytes");
Console.WriteLine($"Compact dump size: {smallWorldDump.Length} bytes. Decrease: {100 - p}%");
Console.WriteLine($"HITS: {hits}");
Console.WriteLine($"MISSES: {miss}");
Console.WriteLine($"MIN NP TIME: {timewatchesNP.Min()} ms");
Console.WriteLine($"AVG NP TIME: {timewatchesNP.Average()} ms");
Console.WriteLine($"MAX NP TIME: {timewatchesNP.Max()} ms");
Console.WriteLine($"MIN HNSW TIME: {timewatchesHNSW.Min()} ms");
Console.WriteLine($"AVG HNSW TIME: {timewatchesHNSW.Average()} ms");
Console.WriteLine($"MAX HNSW TIME: {timewatchesHNSW.Max()} ms");
Console.WriteLine($"MIN HNSWCompact TIME: {timewatchesHNSWCompact.Min()} ms");
Console.WriteLine($"AVG HNSWCompact TIME: {timewatchesHNSWCompact.Average()} ms");
Console.WriteLine($"MAX HNSWCompact TIME: {timewatchesHNSWCompact.Max()} ms");
Console.WriteLine($"MIN HNSW Accuracity: {totalHitsHNSW.Min() * 100 / K}%");
Console.WriteLine($"AVG HNSW Accuracity: {totalHitsHNSW.Average() * 100 / K}%");
Console.WriteLine($"MAX HNSW Accuracity: {totalHitsHNSW.Max() * 100 / K}%");
Console.WriteLine($"MIN HNSWCompact Accuracity: {totalHitsHNSWCompact.Min() * 100 / K}%");
Console.WriteLine($"AVG HNSWCompact Accuracity: {totalHitsHNSWCompact.Average() * 100 / K}%");
Console.WriteLine($"MAX HNSWCompact Accuracity: {totalHitsHNSWCompact.Max() * 100 / K}%");
}
static void FilterTest() static void FilterTest()
{ {

@ -0,0 +1,75 @@
using HNSWDemo.Utils;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using ZeroLevel.HNSW;
namespace HNSWDemo.Tests
{
public class AccuracityTest
: ITest
{
private static int K = 200;
private static int count = 3000;
private static int testCount = 500;
private static int dimensionality = 128;
public void Run()
{
var totalHits = new List<int>();
var timewatchesNP = new List<float>();
var timewatchesHNSW = new List<float>();
var samples = VectorUtils.RandomVectors(dimensionality, count);
var sw = new Stopwatch();
var test = new VectorsDirectCompare(samples, CosineDistance.NonOptimized);
var world = new SmallWorld<float[]>(NSWOptions<float[]>.Create(8, 12, 100, 100, CosineDistance.NonOptimized));
sw.Start();
var ids = world.AddItems(samples.ToArray());
sw.Stop();
Console.WriteLine($"Insert {ids.Length} items: {sw.ElapsedMilliseconds} ms");
Console.WriteLine("Start test");
var test_vectors = VectorUtils.RandomVectors(dimensionality, testCount);
foreach (var v in test_vectors)
{
sw.Restart();
var gt = test.KNearest(v, K).ToDictionary(p => p.Item1, p => p.Item2);
sw.Stop();
timewatchesNP.Add(sw.ElapsedMilliseconds);
sw.Restart();
var result = world.Search(v, K);
sw.Stop();
timewatchesHNSW.Add(sw.ElapsedMilliseconds);
var hits = 0;
foreach (var r in result)
{
if (gt.ContainsKey(r.Item1))
{
hits++;
}
}
totalHits.Add(hits);
}
Console.WriteLine($"MIN Accuracity: {totalHits.Min() * 100 / K}%");
Console.WriteLine($"AVG Accuracity: {totalHits.Average() * 100 / K}%");
Console.WriteLine($"MAX Accuracity: {totalHits.Max() * 100 / K}%");
Console.WriteLine($"MIN HNSW TIME: {timewatchesHNSW.Min()} ms");
Console.WriteLine($"AVG HNSW TIME: {timewatchesHNSW.Average()} ms");
Console.WriteLine($"MAX HNSW TIME: {timewatchesHNSW.Max()} ms");
Console.WriteLine($"MIN NP TIME: {timewatchesNP.Min()} ms");
Console.WriteLine($"AVG NP TIME: {timewatchesNP.Average()} ms");
Console.WriteLine($"MAX NP TIME: {timewatchesNP.Max()} ms");
}
}
}

@ -0,0 +1,26 @@
using System;
using ZeroLevel.HNSW;
using ZeroLevel.HNSW.Services;
namespace HNSWDemo.Tests
{
public class AutoClusteringTest
: ITest
{
private static int Count = 3000;
private static int Dimensionality = 128;
public void Run()
{
var vectors = VectorUtils.RandomVectors(Dimensionality, Count);
var world = SmallWorld.CreateWorld<float[]>(NSWOptions<float[]>.Create(8, 16, 200, 200, Metrics.L2Euclidean));
world.AddItems(vectors);
var clusters = AutomaticGraphClusterer.DetectClusters(world);
Console.WriteLine($"Found {clusters.Count} clusters");
for (int i = 0; i < clusters.Count; i++)
{
Console.WriteLine($"Cluster {i + 1} countains {clusters[i].Count} items");
}
}
}
}

@ -0,0 +1,69 @@
using System;
using System.Drawing;
using System.Linq;
using ZeroLevel.HNSW;
namespace HNSWDemo.Tests
{
public class HistogramTest
: ITest
{
private static int Count = 3000;
private static int Dimensionality = 128;
private static int Width = 3000;
private static int Height = 3000;
public void Run()
{
var vectors = VectorUtils.RandomVectors(Dimensionality, Count);
var world = SmallWorld.CreateWorld<float[]>(NSWOptions<float[]>.Create(8, 16, 200, 200, Metrics.L2Euclidean));
world.AddItems(vectors);
var distance = new Func<int, int, float>((id1, id2) => Metrics.L2Euclidean(world.GetVector(id1), world.GetVector(id2)));
var weights = world.GetLinks().SelectMany(pair => pair.Value.Select(id => distance(pair.Key, id)));
var histogram = new Histogram(HistogramMode.SQRT, weights);
histogram.Smooth();
int threshold = histogram.OTSU();
var min = histogram.Bounds[threshold - 1];
var max = histogram.Bounds[threshold];
var R = (max + min) / 2;
DrawHistogram(histogram, @"D:\hist.jpg");
}
static void DrawHistogram(Histogram histogram, string filename)
{
var wb = Width / histogram.Values.Length;
var k = ((float)Height) / (float)histogram.Values.Max();
var maxes = histogram.GetMaximums().ToDictionary(m => m.Index, m => m);
int threshold = histogram.OTSU();
using (var bmp = new Bitmap(Width, Height))
{
using (var g = Graphics.FromImage(bmp))
{
for (int i = 0; i < histogram.Values.Length; i++)
{
var height = (int)(histogram.Values[i] * k);
if (maxes.ContainsKey(i))
{
g.DrawRectangle(Pens.Red, i * wb, bmp.Height - height, wb, height);
g.DrawRectangle(Pens.Red, i * wb + 1, bmp.Height - height, wb - 1, height);
}
else
{
g.DrawRectangle(Pens.Blue, i * wb, bmp.Height - height, wb, height);
}
if (i == threshold)
{
g.DrawLine(Pens.Green, i * wb + wb / 2, 0, i * wb + wb / 2, bmp.Height);
}
}
}
bmp.Save(filename);
}
}
}
}

@ -0,0 +1,7 @@
namespace HNSWDemo.Tests
{
public interface ITest
{
void Run();
}
}

@ -0,0 +1,28 @@
using System;
using System.Diagnostics;
using ZeroLevel.HNSW;
namespace HNSWDemo.Tests
{
public class InsertTimeExplosionTest
: ITest
{
private static int Count = 10000;
private static int IterationCount = 100;
private static int Dimensionality = 128;
public void Run()
{
var sw = new Stopwatch();
var world = new SmallWorld<float[]>(NSWOptions<float[]>.Create(6, 12, 100, 100, CosineDistance.NonOptimized));
for (int i = 0; i < IterationCount; i++)
{
var samples = VectorUtils.RandomVectors(Dimensionality, Count);
sw.Restart();
var ids = world.AddItems(samples.ToArray());
sw.Stop();
Console.WriteLine($"ITERATION: [{i.ToString("D4")}] COUNT: [{ids.Length}] ELAPSED [{sw.ElapsedMilliseconds} ms]");
}
}
}
}

@ -0,0 +1,43 @@
using System;
using System.Collections.Generic;
using System.Linq;
using ZeroLevel.HNSW;
using ZeroLevel.HNSW.Services;
namespace HNSWDemo.Tests
{
public class QuantizatorTest
: ITest
{
private static int Count = 500000;
private static int Dimensionality = 128;
public void Run()
{
var samples = VectorUtils.RandomVectors(Dimensionality, Count);
var min = samples.SelectMany(s => s).Min();
var max = samples.SelectMany(s => s).Max();
var q = new Quantizator(min, max);
var q_samples = samples.Select(s => q.QuantizeToLong(s)).ToArray();
// comparing
var list = new List<float>();
for (int i = 0; i < samples.Count - 1; i++)
{
var v1 = samples[i];
var v2 = samples[i + 1];
var dist = CosineDistance.NonOptimized(v1, v2);
var qv1 = q_samples[i];
var qv2 = q_samples[i + 1];
var qdist = CosineDistance.NonOptimized(qv1, qv2);
list.Add(Math.Abs(dist - qdist));
}
Console.WriteLine($"Min diff: {list.Min()}");
Console.WriteLine($"Avg diff: {list.Average()}");
Console.WriteLine($"Max diff: {list.Max()}");
}
}
}

@ -0,0 +1,79 @@
using HNSWDemo.Utils;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using ZeroLevel.HNSW;
using ZeroLevel.HNSW.Services;
namespace HNSWDemo.Tests
{
public class QuantizeAccuracityTest
: ITest
{
private static int Count = 5000;
private static int Dimensionality = 128;
private static int K = 200;
private static int TestCount =500;
public void Run()
{
var totalHits = new List<int>();
var timewatchesNP = new List<float>();
var timewatchesHNSW = new List<float>();
var q = new Quantizator(-1f, 1f);
var s = VectorUtils.RandomVectors(Dimensionality, Count);
var samples = s.Select(v => q.QuantizeToLong(v)).ToList();
var sw = new Stopwatch();
var test = new VectorsDirectCompare(s, CosineDistance.NonOptimized);
var world = new SmallWorld<long[]>(NSWOptions<long[]>.Create(6, 8, 100, 100, CosineDistance.NonOptimized));
sw.Start();
var ids = world.AddItems(samples.ToArray());
sw.Stop();
Console.WriteLine($"Insert {ids.Length} items: {sw.ElapsedMilliseconds} ms");
Console.WriteLine("Start test");
var tv = VectorUtils.RandomVectors(Dimensionality, TestCount);
var test_vectors = tv.Select(v => q.QuantizeToLong(v)).ToList();
for (int i = 0; i < tv.Count; i++)
{
sw.Restart();
var gt = test.KNearest(tv[i], K).ToDictionary(p => p.Item1, p => p.Item2);
sw.Stop();
timewatchesNP.Add(sw.ElapsedMilliseconds);
sw.Restart();
var result = world.Search(test_vectors[i], K);
sw.Stop();
timewatchesHNSW.Add(sw.ElapsedMilliseconds);
var hits = 0;
foreach (var r in result)
{
if (gt.ContainsKey(r.Item1))
{
hits++;
}
}
totalHits.Add(hits);
}
Console.WriteLine($"MIN Accuracity: {totalHits.Min() * 100 / K}%");
Console.WriteLine($"AVG Accuracity: {totalHits.Average() * 100 / K}%");
Console.WriteLine($"MAX Accuracity: {totalHits.Max() * 100 / K}%");
Console.WriteLine($"MIN HNSW TIME: {timewatchesHNSW.Min()} ms");
Console.WriteLine($"AVG HNSW TIME: {timewatchesHNSW.Average()} ms");
Console.WriteLine($"MAX HNSW TIME: {timewatchesHNSW.Max()} ms");
Console.WriteLine($"MIN NP TIME: {timewatchesNP.Min()} ms");
Console.WriteLine($"AVG NP TIME: {timewatchesNP.Average()} ms");
Console.WriteLine($"MAX NP TIME: {timewatchesNP.Max()} ms");
}
}
}

@ -0,0 +1,71 @@
using System;
using System.Drawing;
using System.Linq;
using ZeroLevel.HNSW;
using ZeroLevel.HNSW.Services;
namespace HNSWDemo.Tests
{
public class QuantizeHistogramTest
: ITest
{
private static int Count = 3000;
private static int Dimensionality = 128;
private static int Width = 3000;
private static int Height = 3000;
public void Run()
{
var vectors = VectorUtils.RandomVectors(Dimensionality, Count);
var q = new Quantizator(-1f, 1f);
var world = SmallWorld.CreateWorld<long[]>(NSWOptions<long[]>.Create(8, 16, 200, 200, CosineDistance.NonOptimized));
world.AddItems(vectors.Select(v => q.QuantizeToLong(v)).ToList());
var distance = new Func<int, int, float>((id1, id2) => CosineDistance.NonOptimized(world.GetVector(id1), world.GetVector(id2)));
var weights = world.GetLinks().SelectMany(pair => pair.Value.Select(id => distance(pair.Key, id)));
var histogram = new Histogram(HistogramMode.SQRT, weights);
histogram.Smooth();
int threshold = histogram.OTSU();
var min = histogram.Bounds[threshold - 1];
var max = histogram.Bounds[threshold];
var R = (max + min) / 2;
DrawHistogram(histogram, @"D:\hist.jpg");
}
static void DrawHistogram(Histogram histogram, string filename)
{
var wb = Width / histogram.Values.Length;
var k = ((float)Height) / (float)histogram.Values.Max();
var maxes = histogram.GetMaximums().ToDictionary(m => m.Index, m => m);
int threshold = histogram.OTSU();
using (var bmp = new Bitmap(Width, Height))
{
using (var g = Graphics.FromImage(bmp))
{
for (int i = 0; i < histogram.Values.Length; i++)
{
var height = (int)(histogram.Values[i] * k);
if (maxes.ContainsKey(i))
{
g.DrawRectangle(Pens.Red, i * wb, bmp.Height - height, wb, height);
g.DrawRectangle(Pens.Red, i * wb + 1, bmp.Height - height, wb - 1, height);
}
else
{
g.DrawRectangle(Pens.Blue, i * wb, bmp.Height - height, wb, height);
}
if (i == threshold)
{
g.DrawLine(Pens.Green, i * wb + wb / 2, 0, i * wb + wb / 2, bmp.Height);
}
}
}
bmp.Save(filename);
}
}
}
}

@ -0,0 +1,31 @@
using System;
using System.Diagnostics;
using System.Linq;
using ZeroLevel.HNSW;
using ZeroLevel.HNSW.Services;
namespace HNSWDemo.Tests
{
public class QuantizeInsertTimeExplosionTest
: ITest
{
private static int Count = 10000;
private static int IterationCount = 100;
private static int Dimensionality = 128;
public void Run()
{
var sw = new Stopwatch();
var world = new SmallWorld<long[]>(NSWOptions<long[]>.Create(6, 12, 100, 100, CosineDistance.NonOptimized));
var q = new Quantizator(-1f, 1f);
for (int i = 0; i < IterationCount; i++)
{
var samples = VectorUtils.RandomVectors(Dimensionality, Count);
sw.Restart();
var ids = world.AddItems(samples.Select(v => q.QuantizeToLong(v)).ToArray());
sw.Stop();
Console.WriteLine($"ITERATION: [{i.ToString("D4")}] COUNT: [{ids.Length}] ELAPSED [{sw.ElapsedMilliseconds} ms]");
}
}
}
}

@ -0,0 +1,52 @@
using System;
using System.Diagnostics;
using System.IO;
using ZeroLevel.HNSW;
namespace HNSWDemo.Tests
{
public class SaveRestoreTest
: ITest
{
private static int Count = 1000;
private static int Dimensionality = 128;
public void Run()
{
var samples = VectorUtils.RandomVectors(Dimensionality, Count);
var world = new SmallWorld<float[]>(NSWOptions<float[]>.Create(6, 15, 200, 200, CosineDistance.ForUnits));
var sw = new Stopwatch();
sw.Start();
var ids = world.AddItems(samples.ToArray());
sw.Stop();
Console.WriteLine($"Insert {ids.Length} items on {sw.ElapsedMilliseconds} ms");
Console.WriteLine("Start test");
byte[] dump;
using (var ms = new MemoryStream())
{
world.Serialize(ms);
dump = ms.ToArray();
}
Console.WriteLine($"Full dump size: {dump.Length} bytes");
byte[] testDump;
var restoredWorld = new SmallWorld<float[]>(NSWOptions<float[]>.Create(6, 15, 200, 200, CosineDistance.ForUnits));
using (var ms = new MemoryStream(dump))
{
restoredWorld.Deserialize(ms);
}
using (var ms = new MemoryStream())
{
restoredWorld.Serialize(ms);
testDump = ms.ToArray();
}
if (testDump.Length != dump.Length)
{
Console.WriteLine($"Incorrect restored size. Got {testDump.Length}. Expected: {dump.Length}");
return;
}
}
}
}

@ -0,0 +1,95 @@
using System;
using System.Collections.Generic;
using System.Linq;
using ZeroLevel.HNSW;
namespace HNSWDemo.Utils
{
public class QLVectorsDirectCompare
{
private const int HALF_LONG_BITS = 32;
private readonly IList<long[]> _vectors;
private readonly Func<long[], long[], float> _distance;
public QLVectorsDirectCompare(List<long[]> vectors, Func<long[], long[], float> distance)
{
_vectors = vectors;
_distance = distance;
}
public IEnumerable<(int, float)> KNearest(long[] v, int k)
{
var weights = new Dictionary<int, float>();
for (int i = 0; i < _vectors.Count; i++)
{
var d = _distance(v, _vectors[i]);
weights[i] = d;
}
return weights.OrderBy(p => p.Value).Take(k).Select(p => (p.Key, p.Value));
}
public List<HashSet<int>> DetectClusters()
{
var links = new SortedList<long, float>();
for (int i = 0; i < _vectors.Count; i++)
{
for (int j = i + 1; j < _vectors.Count; j++)
{
long k = (((long)(i)) << HALF_LONG_BITS) + j;
links.Add(k, _distance(_vectors[i], _vectors[j]));
}
}
// 1. Find R - bound between intra-cluster distances and out-of-cluster distances
var histogram = new Histogram(HistogramMode.SQRT, links.Values);
int threshold = histogram.OTSU();
var min = histogram.Bounds[threshold - 1];
var max = histogram.Bounds[threshold];
var R = (max + min) / 2;
// 2. Get links with distances less than R
var resultLinks = new SortedList<long, float>();
foreach (var pair in links)
{
if (pair.Value < R)
{
resultLinks.Add(pair.Key, pair.Value);
}
}
// 3. Extract clusters
List<HashSet<int>> clusters = new List<HashSet<int>>();
foreach (var pair in resultLinks)
{
var k = pair.Key;
var id1 = (int)(k >> HALF_LONG_BITS);
var id2 = (int)(k - (((long)id1) << HALF_LONG_BITS));
bool found = false;
foreach (var c in clusters)
{
if (c.Contains(id1))
{
c.Add(id2);
found = true;
break;
}
else if (c.Contains(id2))
{
c.Add(id1);
found = true;
break;
}
}
if (found == false)
{
var c = new HashSet<int>();
c.Add(id1);
c.Add(id2);
clusters.Add(c);
}
}
return clusters;
}
}
}

@ -0,0 +1,95 @@
using System;
using System.Collections.Generic;
using System.Linq;
using ZeroLevel.HNSW;
namespace HNSWDemo.Utils
{
public class QVectorsDirectCompare
{
private const int HALF_LONG_BITS = 32;
private readonly IList<byte[]> _vectors;
private readonly Func<byte[], byte[], float> _distance;
public QVectorsDirectCompare(List<byte[]> vectors, Func<byte[], byte[], float> distance)
{
_vectors = vectors;
_distance = distance;
}
public IEnumerable<(int, float)> KNearest(byte[] v, int k)
{
var weights = new Dictionary<int, float>();
for (int i = 0; i < _vectors.Count; i++)
{
var d = _distance(v, _vectors[i]);
weights[i] = d;
}
return weights.OrderBy(p => p.Value).Take(k).Select(p => (p.Key, p.Value));
}
public List<HashSet<int>> DetectClusters()
{
var links = new SortedList<long, float>();
for (int i = 0; i < _vectors.Count; i++)
{
for (int j = i + 1; j < _vectors.Count; j++)
{
long k = (((long)i) << HALF_LONG_BITS) + j;
links.Add(k, _distance(_vectors[i], _vectors[j]));
}
}
// 1. Find R - bound between intra-cluster distances and out-of-cluster distances
var histogram = new Histogram(HistogramMode.SQRT, links.Values);
int threshold = histogram.OTSU();
var min = histogram.Bounds[threshold - 1];
var max = histogram.Bounds[threshold];
var R = (max + min) / 2;
// 2. Get links with distances less than R
var resultLinks = new SortedList<long, float>();
foreach (var pair in links)
{
if (pair.Value < R)
{
resultLinks.Add(pair.Key, pair.Value);
}
}
// 3. Extract clusters
List<HashSet<int>> clusters = new List<HashSet<int>>();
foreach (var pair in resultLinks)
{
var k = pair.Key;
var id1 = (int)(k >> HALF_LONG_BITS);
var id2 = (int)(k - (((long)id1) << HALF_LONG_BITS));
bool found = false;
foreach (var c in clusters)
{
if (c.Contains(id1))
{
c.Add(id2);
found = true;
break;
}
else if (c.Contains(id2))
{
c.Add(id1);
found = true;
break;
}
}
if (found == false)
{
var c = new HashSet<int>();
c.Add(id1);
c.Add(id2);
clusters.Add(c);
}
}
return clusters;
}
}
}

@ -0,0 +1,95 @@
using System;
using System.Collections.Generic;
using System.Linq;
using ZeroLevel.HNSW;
namespace HNSWDemo.Utils
{
public class VectorsDirectCompare
{
private const int HALF_LONG_BITS = 32;
private readonly IList<float[]> _vectors;
private readonly Func<float[], float[], float> _distance;
public VectorsDirectCompare(List<float[]> vectors, Func<float[], float[], float> distance)
{
_vectors = vectors;
_distance = distance;
}
public IEnumerable<(int, float)> KNearest(float[] v, int k)
{
var weights = new Dictionary<int, float>();
for (int i = 0; i < _vectors.Count; i++)
{
var d = _distance(v, _vectors[i]);
weights[i] = d;
}
return weights.OrderBy(p => p.Value).Take(k).Select(p => (p.Key, p.Value));
}
public List<HashSet<int>> DetectClusters()
{
var links = new SortedList<long, float>();
for (int i = 0; i < _vectors.Count; i++)
{
for (int j = i + 1; j < _vectors.Count; j++)
{
long k = (((long)(i)) << HALF_LONG_BITS) + j;
links.Add(k, _distance(_vectors[i], _vectors[j]));
}
}
// 1. Find R - bound between intra-cluster distances and out-of-cluster distances
var histogram = new Histogram(HistogramMode.SQRT, links.Values);
int threshold = histogram.OTSU();
var min = histogram.Bounds[threshold - 1];
var max = histogram.Bounds[threshold];
var R = (max + min) / 2;
// 2. Get links with distances less than R
var resultLinks = new SortedList<long, float>();
foreach (var pair in links)
{
if (pair.Value < R)
{
resultLinks.Add(pair.Key, pair.Value);
}
}
// 3. Extract clusters
List<HashSet<int>> clusters = new List<HashSet<int>>();
foreach (var pair in resultLinks)
{
var k = pair.Key;
var id1 = (int)(k >> HALF_LONG_BITS);
var id2 = (int)(k - (((long)id1) << HALF_LONG_BITS));
bool found = false;
foreach (var c in clusters)
{
if (c.Contains(id1))
{
c.Add(id2);
found = true;
break;
}
else if (c.Contains(id2))
{
c.Add(id1);
found = true;
break;
}
}
if (found == false)
{
var c = new HashSet<int>();
c.Add(id1);
c.Add(id2);
clusters.Add(c);
}
}
return clusters;
}
}
}

@ -16,6 +16,12 @@ namespace ZeroLevel.HNSW
/// Max search buffer for inserting /// Max search buffer for inserting
/// </summary> /// </summary>
public readonly int EFConstruction; public readonly int EFConstruction;
public static NSWOptions<float[]> Create(int v1, int v2, int v3, int v4, Func<float[], float[], float> l2Euclidean, object selectionHeuristic)
{
throw new NotImplementedException();
}
/// <summary> /// <summary>
/// Distance function beetween vectors /// Distance function beetween vectors
/// </summary> /// </summary>

@ -1,4 +1,6 @@
using System.Collections.Generic; using System;
using System.Collections.Generic;
using System.Linq;
namespace ZeroLevel.HNSW.Services namespace ZeroLevel.HNSW.Services
{ {
@ -6,11 +8,20 @@ namespace ZeroLevel.HNSW.Services
{ {
private const int HALF_LONG_BITS = 32; private const int HALF_LONG_BITS = 32;
/*public static List<HashSet<int>> DetectClusters<T>(SmallWorld<T> world) private class Link
{ {
var links = world.GetNSWLinks(); public int Id1;
public int Id2;
public float Distance;
}
public static List<HashSet<int>> DetectClusters<T>(SmallWorld<T> world)
{
var distance = world.DistanceFunction;
var links = world.GetLinks().SelectMany(pair => pair.Value.Select(id => new Link { Id1 = pair.Key, Id2 = id, Distance = distance(pair.Key, id) })).ToList();
// 1. Find R - bound between intra-cluster distances and out-of-cluster distances // 1. Find R - bound between intra-cluster distances and out-of-cluster distances
var histogram = new Histogram(HistogramMode.SQRT, links.Values); var histogram = new Histogram(HistogramMode.SQRT, links.Select(l => l.Distance));
int threshold = histogram.OTSU(); int threshold = histogram.OTSU();
var min = histogram.Bounds[threshold - 1]; var min = histogram.Bounds[threshold - 1];
var max = histogram.Bounds[threshold]; var max = histogram.Bounds[threshold];
@ -18,23 +29,21 @@ namespace ZeroLevel.HNSW.Services
// 2. Get links with distances less than R // 2. Get links with distances less than R
var resultLinks = new SortedList<long, float>(); var resultLinks = new List<Link>();
foreach (var pair in links) foreach (var l in links)
{ {
if (pair.Value < R) if (l.Distance < R)
{ {
resultLinks.Add(pair.Key, pair.Value); resultLinks.Add(l);
} }
} }
// 3. Extract clusters // 3. Extract clusters
List<HashSet<int>> clusters = new List<HashSet<int>>(); List<HashSet<int>> clusters = new List<HashSet<int>>();
foreach (var pair in resultLinks) foreach (var l in resultLinks)
{ {
var k = pair.Key; var id1 = l.Id1;
var id1 = (int)(k >> HALF_LONG_BITS); var id2 = l.Id2;
var id2 = (int)(k - (((long)id1) << HALF_LONG_BITS));
bool found = false; bool found = false;
foreach (var c in clusters) foreach (var c in clusters)
{ {
@ -60,6 +69,6 @@ namespace ZeroLevel.HNSW.Services
} }
} }
return clusters; return clusters;
}*/ }
} }
} }

@ -2,7 +2,6 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using ZeroLevel.HNSW.Services; using ZeroLevel.HNSW.Services;
using ZeroLevel.Services.Pools;
using ZeroLevel.Services.Serialization; using ZeroLevel.Services.Serialization;
namespace ZeroLevel.HNSW namespace ZeroLevel.HNSW
@ -16,20 +15,28 @@ namespace ZeroLevel.HNSW
private readonly NSWOptions<TItem> _options; private readonly NSWOptions<TItem> _options;
private readonly VectorSet<TItem> _vectors; private readonly VectorSet<TItem> _vectors;
private readonly LinksSet _links; private readonly LinksSet _links;
//internal SortedList<long, float> Links => _links.Links; public readonly int M;
private readonly Dictionary<int, float> connections;
internal IDictionary<int, HashSet<int>> Links => _links.Links;
/// <summary> /// <summary>
/// There are links е the layer /// There are links е the layer
/// </summary> /// </summary>
internal bool HasLinks => (_links.Count > 0); internal bool HasLinks => (_links.Count > 0);
private int GetM(bool nswLayer) internal IEnumerable<int> this[int vector_index] => _links.FindNeighbors(vector_index);
{
return nswLayer ? 2 * _options.M : _options.M;
}
/// <summary> /// <summary>
/// HNSW layer /// HNSW layer
/// <remarks>
/// Article: Section 4.1:
/// "Selection of the Mmax0 (the maximum number of connections that an element can have in the zero layer) also
/// has a strong influence on the search performance, especially in case of high quality(high recall) search.
/// Simulations show that setting Mmax0 to M(this corresponds to kNN graphs on each layer if the neighbors
/// selection heuristic is not used) leads to a very strong performance penalty at high recall.
/// Simulations also suggest that 2∙M is a good choice for Mmax0;
/// setting the parameter higher leads to performance degradation and excessive memory usage."
/// </remarks>
/// </summary> /// </summary>
/// <param name="options">HNSW graph options</param> /// <param name="options">HNSW graph options</param>
/// <param name="vectors">General vector set</param> /// <param name="vectors">General vector set</param>
@ -37,7 +44,9 @@ namespace ZeroLevel.HNSW
{ {
_options = options; _options = options;
_vectors = vectors; _vectors = vectors;
_links = new LinksSet(GetM(nswLayer), (id1, id2) => options.Distance(_vectors[id1], _vectors[id2])); M = nswLayer ? 2 * _options.M : _options.M;
_links = new LinksSet(M);
connections = new Dictionary<int, float>(M + 1);
} }
internal int FindEntryPointAtLayer(Func<int, float> targetCosts) internal int FindEntryPointAtLayer(Func<int, float> targetCosts)
@ -58,13 +67,89 @@ namespace ZeroLevel.HNSW
return minId; return minId;
} }
internal void AddBidirectionallConnections(int q, int p) internal void Push(int q, int ep, MinHeap W, Func<int, float> distance)
{
if (HasLinks == false)
{
AddBidirectionallConnections(q, q);
}
else
{
// W ← SEARCH - LAYER(q, ep, efConstruction, lc)
foreach (var i in KNearestAtLayer(ep, distance, _options.EFConstruction))
{
W.Push(i);
}
int count = 0;
connections.Clear();
while (count < M && W.Count > 0)
{
var nearest = W.Pop();
var nearest_nearest = GetNeighbors(nearest.Item1).ToArray();
if (nearest_nearest.Length < M)
{
if (AddBidirectionallConnections(q, nearest.Item1))
{
connections.Add(nearest.Item1, nearest.Item2);
count++;
}
}
else
{
if ((M - count) < 2)
{
// remove link q - max_q
var max = connections.OrderBy(pair => pair.Value).First();
RemoveBidirectionallConnections(q, max.Key);
connections.Remove(max.Key);
}
// get nearest_nearest candidate
var mn_id = -1;
var mn_d = float.MinValue;
for (int i = 0; i < nearest_nearest.Length; i++)
{
var d = _options.Distance(_vectors[nearest.Item1], _vectors[nearest_nearest[i]]);
if (q != nearest_nearest[i] && connections.ContainsKey(nearest_nearest[i]) == false)
{
if (mn_id == -1 || d > mn_d)
{
mn_d = d;
mn_id = nearest_nearest[i];
}
}
}
// remove link neareset - nearest_nearest
RemoveBidirectionallConnections(nearest.Item1, mn_id);
// add link q - neareset
if (AddBidirectionallConnections(q, nearest.Item1))
{
connections.Add(nearest.Item1, nearest.Item2);
count++;
}
// add link q - max_nearest_nearest
if (AddBidirectionallConnections(q, mn_id))
{
connections.Add(mn_id, mn_d);
count++;
}
}
}
}
}
internal void RemoveBidirectionallConnections(int q, int p)
{
_links.RemoveIndex(q, p);
}
internal bool AddBidirectionallConnections(int q, int p)
{ {
if (q == p) if (q == p)
{ {
if (EntryPoint >= 0) if (EntryPoint >= 0)
{ {
_links.Add(q, EntryPoint); return _links.Add(q, EntryPoint);
} }
else else
{ {
@ -73,14 +158,13 @@ namespace ZeroLevel.HNSW
} }
else else
{ {
_links.Add(q, p); return _links.Add(q, p);
} }
return false;
} }
private int EntryPoint = -1; private int EntryPoint = -1;
internal void Trim(int id) => _links.Trim(id);
#region Implementation of https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf #region Implementation of https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf
/// <summary> /// <summary>
/// Algorithm 2 /// Algorithm 2
@ -349,7 +433,7 @@ namespace ZeroLevel.HNSW
*/ */
#endregion #endregion
private IEnumerable<int> GetNeighbors(int id) => _links.FindNeighbors(id); internal IEnumerable<int> GetNeighbors(int id) => _links.FindNeighbors(id);
public void Serialize(IBinaryWriter writer) public void Serialize(IBinaryWriter writer)
{ {

@ -6,161 +6,15 @@ using ZeroLevel.Services.Serialization;
namespace ZeroLevel.HNSW namespace ZeroLevel.HNSW
{ {
/*
internal struct Link
: IEquatable<Link>
{
public int Id;
public float Distance;
public override int GetHashCode()
{
return Id.GetHashCode();
}
public override bool Equals(object obj)
{
if (obj is Link)
return this.Equals((Link)obj);
return false;
}
public bool Equals(Link other)
{
return this.Id == other.Id;
}
}
public class LinksSetWithCachee
{
private ConcurrentDictionary<int, HashSet<Link>> _set = new ConcurrentDictionary<int, HashSet<Link>>();
internal int Count => _set.Count;
private readonly int _M;
private readonly Func<int, int, float> _distance;
public LinksSetWithCachee(int M, Func<int, int, float> distance)
{
_distance = distance;
_M = M;
}
internal IEnumerable<int> FindNeighbors(int id)
{
if (_set.ContainsKey(id))
{
return _set[id].Select(l=>l.Id);
}
return Enumerable.Empty<int>();
}
internal void RemoveIndex(int id1, int id2)
{
var link1 = new Link { Id = id1 };
var link2 = new Link { Id = id2 };
_set[id1].Remove(link2);
_set[id2].Remove(link1);
}
internal bool Add(int id1, int id2, float distance)
{
if (!_set.ContainsKey(id1))
{
_set[id1] = new HashSet<Link>();
}
if (!_set.ContainsKey(id2))
{
_set[id2] = new HashSet<Link>();
}
var r1 = _set[id1].Add(new Link { Id = id2, Distance = distance });
var r2 = _set[id2].Add(new Link { Id = id1, Distance = distance });
//TrimSet(_set[id1]);
TrimSet(id2, _set[id2]);
return r1 || r2;
}
internal void Trim(int id) => TrimSet(id, _set[id]);
private void TrimSet(int id, HashSet<Link> set)
{
if (set.Count > _M)
{
var removeCount = set.Count - _M;
var removeLinks = set.OrderByDescending(n => n.Distance).Take(removeCount).ToArray();
foreach (var l in removeLinks)
{
set.Remove(l);
}
}
}
public void Dispose()
{
_set.Clear();
_set = null;
}
private const int HALF_LONG_BITS = 32;
public void Serialize(IBinaryWriter writer)
{
writer.WriteBoolean(false); // true - set with weights
writer.WriteInt32(_set.Sum(pair => pair.Value.Count));
foreach (var record in _set)
{
var id = record.Key;
foreach (var r in record.Value)
{
var key = (((long)(id)) << HALF_LONG_BITS) + r;
writer.WriteLong(key);
}
}
}
public void Deserialize(IBinaryReader reader)
{
if (reader.ReadBoolean() == false)
{
throw new InvalidOperationException("Incompatible data format. The set does not contain weights.");
}
_set.Clear();
_set = null;
var count = reader.ReadInt32();
_set = new ConcurrentDictionary<int, HashSet<int>>();
for (int i = 0; i < count; i++)
{
var key = reader.ReadLong();
var id1 = (int)(key >> HALF_LONG_BITS);
var id2 = (int)(key - (((long)id1) << HALF_LONG_BITS));
if (!_set.ContainsKey(id1))
{
_set[id1] = new HashSet<int>();
}
_set[id1].Add(id2);
}
}
}
*/
public class LinksSet public class LinksSet
{ {
private ConcurrentDictionary<int, HashSet<int>> _set = new ConcurrentDictionary<int, HashSet<int>>(); private ConcurrentDictionary<int, HashSet<int>> _set = new ConcurrentDictionary<int, HashSet<int>>();
internal IDictionary<int, HashSet<int>> Links => _set; internal IDictionary<int, HashSet<int>> Links => _set;
internal int Count => _set.Count; internal int Count => _set.Count;
private readonly int _M; private readonly int _M;
private readonly Func<int, int, float> _distance;
public LinksSet(int M, Func<int, int, float> distance) public LinksSet(int M)
{ {
_distance = distance;
_M = M; _M = M;
} }
@ -207,27 +61,9 @@ namespace ZeroLevel.HNSW
} }
var r1 = _set[id1].Add(id2); var r1 = _set[id1].Add(id2);
var r2 = _set[id2].Add(id1); var r2 = _set[id2].Add(id1);
TrimSet(id1, _set[id1]);
TrimSet(id2, _set[id2]);
return r1 || r2; return r1 || r2;
} }
internal void Trim(int id) => TrimSet(id, _set[id]);
private void TrimSet(int id, HashSet<int> set)
{
if (set.Count > _M)
{
var removeCount = set.Count - _M;
var removeLinks = set.OrderByDescending(n => _distance(id, n)).Take(removeCount).ToArray();
foreach (var l in removeLinks)
{
set.Remove(l);
}
}
}
public void Dispose() public void Dispose()
{ {

@ -1,11 +1,12 @@
using System.Collections.Generic; using System.Collections;
using System.Collections.Generic;
using System.Threading; using System.Threading;
using ZeroLevel.Services.Serialization; using ZeroLevel.Services.Serialization;
namespace ZeroLevel.HNSW namespace ZeroLevel.HNSW
{ {
internal sealed class VectorSet<T> internal sealed class VectorSet<T>
: IBinarySerializable : IEnumerable<T>, IBinarySerializable
{ {
private List<T> _set = new List<T>(); private List<T> _set = new List<T>();
private SpinLock _lock = new SpinLock(); private SpinLock _lock = new SpinLock();
@ -73,5 +74,15 @@ namespace ZeroLevel.HNSW
writer.WriteCompatible<T>(r); writer.WriteCompatible<T>(r);
} }
} }
public IEnumerator<T> GetEnumerator()
{
return _set.GetEnumerator();
}
IEnumerator IEnumerable.GetEnumerator()
{
return _set.GetEnumerator();
}
} }
} }

@ -18,12 +18,19 @@ namespace ZeroLevel.HNSW
private readonly ProbabilityLayerNumberGenerator _layerLevelGenerator; private readonly ProbabilityLayerNumberGenerator _layerLevelGenerator;
private ReaderWriterLockSlim _lockGraph = new ReaderWriterLockSlim(); private ReaderWriterLockSlim _lockGraph = new ReaderWriterLockSlim();
public readonly Func<int, int, float> DistanceFunction;
public TItem GetVector(int id) => _vectors[id];
public IDictionary<int, HashSet<int>> GetLinks() => _layers[0].Links;
public SmallWorld(NSWOptions<TItem> options) public SmallWorld(NSWOptions<TItem> options)
{ {
_options = options; _options = options;
_vectors = new VectorSet<TItem>(); _vectors = new VectorSet<TItem>();
_layers = new Layer<TItem>[_options.LayersCount]; _layers = new Layer<TItem>[_options.LayersCount];
_layerLevelGenerator = new ProbabilityLayerNumberGenerator(_options.LayersCount, _options.M); _layerLevelGenerator = new ProbabilityLayerNumberGenerator(_options.LayersCount, _options.M);
DistanceFunction = new Func<int, int, float>((id1, id2) => _options.Distance(_vectors[id1], _vectors[id2]));
for (int i = 0; i < _options.LayersCount; i++) for (int i = 0; i < _options.LayersCount; i++)
{ {
_layers[i] = new Layer<TItem>(_options, _vectors, i == 0); _layers[i] = new Layer<TItem>(_options, _vectors, i == 0);
@ -151,42 +158,14 @@ namespace ZeroLevel.HNSW
// connecting new node to the small world // connecting new node to the small world
for (int lc = Math.Min(L, l); lc >= 0; --lc) for (int lc = Math.Min(L, l); lc >= 0; --lc)
{ {
if (_layers[lc].HasLinks == false) _layers[lc].Push(q, ep, W, distance);
{ // ep ← W
_layers[lc].AddBidirectionallConnections(q, q); if (W.TryPeek(out id, out value))
}
else
{ {
// W ← SEARCH - LAYER(q, ep, efConstruction, lc) ep = id;
foreach (var i in _layers[lc].KNearestAtLayer(ep, distance, _options.EFConstruction)) epDist = value;
{
W.Push(i);
}
// ep ← W
if (W.TryPeek(out id, out value))
{
ep = id;
epDist = value;
}
// neighbors ← SELECT-NEIGHBORS(q, W, M, lc) // alg. 3 or alg. 4
var neighbors = SelectBestForConnecting(lc, W);
// add bidirectionall connectionts from neighbors to q at layer lc
// for each e ∈ neighbors // shrink connections if needed
foreach (var e in neighbors)
{
// eConn ← neighbourhood(e) at layer lc
_layers[lc].AddBidirectionallConnections(q, e.Item1);
// if distance from newNode to newNeighbour is better than to bestPeer => update bestPeer
if (e.Item2 < epDist)
{
ep = e.Item1;
epDist = e.Item2;
}
}
W.Clear();
} }
W.Clear();
} }
// if l > L // if l > L
if (l > L) if (l > L)
@ -198,30 +177,37 @@ namespace ZeroLevel.HNSW
} }
} }
/// <summary> public void TestWorld()
/// Get maximum allowed connections for the given level.
/// </summary>
/// <remarks>
/// Article: Section 4.1:
/// "Selection of the Mmax0 (the maximum number of connections that an element can have in the zero layer) also
/// has a strong influence on the search performance, especially in case of high quality(high recall) search.
/// Simulations show that setting Mmax0 to M(this corresponds to kNN graphs on each layer if the neighbors
/// selection heuristic is not used) leads to a very strong performance penalty at high recall.
/// Simulations also suggest that 2∙M is a good choice for Mmax0;
/// setting the parameter higher leads to performance degradation and excessive memory usage."
/// </remarks>
/// <param name="layer">The level of the layer.</param>
/// <returns>The maximum number of connections.</returns>
private int GetM(int layer)
{
return layer == 0 ? 2 * _options.M : _options.M;
}
private IEnumerable<(int, float)> SelectBestForConnecting(int layer, MinHeap candidates)
{ {
int count = GetM(layer); for (var v = 0; v < _vectors.Count; v++)
while (count >= 0 && candidates.Count > 0) {
yield return candidates.Pop(); var nearest = _layers[0][v].ToArray();
if (nearest.Length > _layers[0].M)
{
Console.WriteLine($"V{v}. Count of links ({nearest.Length}) more than max ({_layers[0].M})");
}
}
// coverage test
var ep = 0;
var visited = new HashSet<int>();
var next = new Stack<int>();
next.Push(ep);
while (next.Count > 0)
{
ep = next.Pop();
visited.Add(ep);
foreach (var n in _layers[0].GetNeighbors(ep))
{
if (visited.Contains(n) == false)
{
next.Push(n);
}
}
}
if (visited.Count != _vectors.Count)
{
Console.Write($"Vectors count ({_vectors.Count}) less than BFS visited nodes count ({visited.Count})");
}
} }
/// <summary> /// <summary>
@ -385,8 +371,5 @@ namespace ZeroLevel.HNSW
} }
} }
} }
/*public Histogram GetHistogram(HistogramMode mode = HistogramMode.SQRT)
=> _layers[0].GetHistogram(mode);*/
} }
} }

@ -1,6 +1,7 @@
using System; using System;
using System.Numerics; using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
namespace ZeroLevel.HNSW namespace ZeroLevel.HNSW
{ {

@ -6,6 +6,19 @@ namespace ZeroLevel.HNSW
{ {
public static class VectorUtils public static class VectorUtils
{ {
public static List<float[]> RandomVectors(int vectorSize, int vectorsCount)
{
var vectors = new List<float[]>();
for (int i = 0; i < vectorsCount; i++)
{
var vector = new float[vectorSize];
DefaultRandomGenerator.Instance.NextFloats(vector);
VectorUtils.NormalizeSIMD(vector);
vectors.Add(vector);
}
return vectors;
}
public static float Magnitude(IList<float> vector) public static float Magnitude(IList<float> vector)
{ {
float magnitude = 0.0f; float magnitude = 0.0f;

Loading…
Cancel
Save

Powered by TurnKey Linux.