HNSW fix quantization

pull/1/head
unknown 3 years ago
parent 0fdda146da
commit 2253dd7fdc

@ -9,7 +9,7 @@ namespace HNSWDemo
{ {
static void Main(string[] args) static void Main(string[] args)
{ {
new LALTest().Run(); new QuantizatorTest().Run();
// new AutoClusteringMNISTTest().Run(); // new AutoClusteringMNISTTest().Run();
// new AccuracityTest().Run(); // new AccuracityTest().Run();
Console.WriteLine("Completed"); Console.WriteLine("Completed");

@ -10,7 +10,7 @@ namespace HNSWDemo.Tests
: ITest : ITest
{ {
private static int Count = 500000; private static int Count = 500000;
private static int Dimensionality = 128; private static int Dimensionality = 221;
public void Run() public void Run()
{ {
@ -18,7 +18,7 @@ namespace HNSWDemo.Tests
var min = samples.SelectMany(s => s).Min(); var min = samples.SelectMany(s => s).Min();
var max = samples.SelectMany(s => s).Max(); var max = samples.SelectMany(s => s).Max();
var q = new Quantizator(min, max); var q = new Quantizator(min, max);
var q_samples = samples.Select(s => q.QuantizeToLong(s)).ToArray(); var q_samples = samples.Select(s => q.QuantizeToInt(s)).ToArray();
// comparing // comparing
var list = new List<float>(); var list = new List<float>();

@ -27,13 +27,12 @@ namespace ZeroLevel.HNSW.Services
public int[] QuantizeToInt(float[] v) public int[] QuantizeToInt(float[] v)
{ {
if (v.Length % 4 != 0) var diff = v.Length % 4;
{ int count = (v.Length - diff) / 4;
throw new ArgumentOutOfRangeException("v.Length % 4 must be zero!"); var result = new int[((diff == 0) ? 0 : 1) + (v.Length / 4)];
}
var result = new int[v.Length / 4];
byte[] buf = new byte[4]; byte[] buf = new byte[4];
for (int i = 0; i < v.Length; i += 4) int i = 0;
for (; i < count * 4; i += 4)
{ {
buf[0] = _quantizeInRange(v[i]); buf[0] = _quantizeInRange(v[i]);
buf[1] = _quantizeInRange(v[i + 1]); buf[1] = _quantizeInRange(v[i + 1]);
@ -41,18 +40,29 @@ namespace ZeroLevel.HNSW.Services
buf[3] = _quantizeInRange(v[i + 3]); buf[3] = _quantizeInRange(v[i + 3]);
result[(i >> 2)] = BitConverter.ToInt32(buf); result[(i >> 2)] = BitConverter.ToInt32(buf);
} }
if (diff != 0)
{
for (var j = 0; j < diff; j++)
{
buf[j] = _quantizeInRange(v[i + j]);
}
for (var j = diff; j < 4; j++)
{
buf[j] = 0;
}
result[(i >> 2)] = BitConverter.ToInt32(buf);
}
return result; return result;
} }
public long[] QuantizeToLong(float[] v) public long[] QuantizeToLong(float[] v)
{ {
if (v.Length % 8 != 0) var diff = v.Length % 8;
{ int count = (v.Length - diff) / 8;
throw new ArgumentOutOfRangeException("v.Length % 8 must be zero!"); var result = new long[((diff == 0) ? 0 : 1) + (v.Length / 8)];
}
var result = new long[v.Length / 8];
byte[] buf = new byte[8]; byte[] buf = new byte[8];
for (int i = 0; i < v.Length; i += 8) int i = 0;
for (; i < count * 8; i += 8)
{ {
buf[0] = _quantizeInRange(v[i + 0]); buf[0] = _quantizeInRange(v[i + 0]);
buf[1] = _quantizeInRange(v[i + 1]); buf[1] = _quantizeInRange(v[i + 1]);
@ -65,6 +75,18 @@ namespace ZeroLevel.HNSW.Services
result[(i >> 3)] = BitConverter.ToInt64(buf); result[(i >> 3)] = BitConverter.ToInt64(buf);
} }
if (diff != 0)
{
for (var j = 0; j < diff; j++)
{
buf[j] = _quantizeInRange(v[i + j]);
}
for (var j = diff; j < 8; j++)
{
buf[j] = 0;
}
result[(i >> 3)] = BitConverter.ToInt64(buf);
}
return result; return result;
} }

Loading…
Cancel
Save

Powered by TurnKey Linux.