diff --git a/TestHNSW/HNSWDemo/Program.cs b/TestHNSW/HNSWDemo/Program.cs index ba378e6..74ffc2f 100644 --- a/TestHNSW/HNSWDemo/Program.cs +++ b/TestHNSW/HNSWDemo/Program.cs @@ -9,7 +9,7 @@ namespace HNSWDemo { static void Main(string[] args) { - new LALTest().Run(); + new QuantizatorTest().Run(); // new AutoClusteringMNISTTest().Run(); // new AccuracityTest().Run(); Console.WriteLine("Completed"); diff --git a/TestHNSW/HNSWDemo/Tests/QuantizatorTest.cs b/TestHNSW/HNSWDemo/Tests/QuantizatorTest.cs index b8575be..3566779 100644 --- a/TestHNSW/HNSWDemo/Tests/QuantizatorTest.cs +++ b/TestHNSW/HNSWDemo/Tests/QuantizatorTest.cs @@ -10,7 +10,7 @@ namespace HNSWDemo.Tests : ITest { private static int Count = 500000; - private static int Dimensionality = 128; + private static int Dimensionality = 221; public void Run() { @@ -18,7 +18,7 @@ namespace HNSWDemo.Tests var min = samples.SelectMany(s => s).Min(); var max = samples.SelectMany(s => s).Max(); var q = new Quantizator(min, max); - var q_samples = samples.Select(s => q.QuantizeToLong(s)).ToArray(); + var q_samples = samples.Select(s => q.QuantizeToInt(s)).ToArray(); // comparing var list = new List(); diff --git a/ZeroLevel.HNSW/Services/Quantizator.cs b/ZeroLevel.HNSW/Services/Quantizator.cs index 0364af0..8f031fb 100644 --- a/ZeroLevel.HNSW/Services/Quantizator.cs +++ b/ZeroLevel.HNSW/Services/Quantizator.cs @@ -27,13 +27,12 @@ namespace ZeroLevel.HNSW.Services public int[] QuantizeToInt(float[] v) { - if (v.Length % 4 != 0) - { - throw new ArgumentOutOfRangeException("v.Length % 4 must be zero!"); - } - var result = new int[v.Length / 4]; + var diff = v.Length % 4; + int count = (v.Length - diff) / 4; + var result = new int[((diff == 0) ? 0 : 1) + (v.Length / 4)]; byte[] buf = new byte[4]; - for (int i = 0; i < v.Length; i += 4) + int i = 0; + for (; i < count * 4; i += 4) { buf[0] = _quantizeInRange(v[i]); buf[1] = _quantizeInRange(v[i + 1]); @@ -41,18 +40,29 @@ namespace ZeroLevel.HNSW.Services buf[3] = _quantizeInRange(v[i + 3]); result[(i >> 2)] = BitConverter.ToInt32(buf); } + if (diff != 0) + { + for (var j = 0; j < diff; j++) + { + buf[j] = _quantizeInRange(v[i + j]); + } + for (var j = diff; j < 4; j++) + { + buf[j] = 0; + } + result[(i >> 2)] = BitConverter.ToInt32(buf); + } return result; } public long[] QuantizeToLong(float[] v) { - if (v.Length % 8 != 0) - { - throw new ArgumentOutOfRangeException("v.Length % 8 must be zero!"); - } - var result = new long[v.Length / 8]; + var diff = v.Length % 8; + int count = (v.Length - diff) / 8; + var result = new long[((diff == 0) ? 0 : 1) + (v.Length / 8)]; byte[] buf = new byte[8]; - for (int i = 0; i < v.Length; i += 8) + int i = 0; + for (; i < count * 8; i += 8) { buf[0] = _quantizeInRange(v[i + 0]); buf[1] = _quantizeInRange(v[i + 1]); @@ -65,6 +75,18 @@ namespace ZeroLevel.HNSW.Services result[(i >> 3)] = BitConverter.ToInt64(buf); } + if (diff != 0) + { + for (var j = 0; j < diff; j++) + { + buf[j] = _quantizeInRange(v[i + j]); + } + for (var j = diff; j < 8; j++) + { + buf[j] = 0; + } + result[(i >> 3)] = BitConverter.ToInt64(buf); + } return result; }