import numpy as np
from ..base import BaseNumericEncoder
from ...helper import batching, train_required
class VladEncoder(BaseNumericEncoder):
    """Encode each set of local vectors into a single VLAD vector.

    Every input chunk is a collection of vectors ("frames"). Each frame is
    assigned to its nearest k-means centroid, the residuals
    ``frame - centroid`` are summed per centroid, and the resulting
    ``(num_clusters, dim)`` matrix is flattened and L2-normalised.

    Centroids are learned with ``faiss.Kmeans`` in :meth:`train`. Nearest-
    centroid assignment uses a faiss ``IndexFlatL2`` when
    ``using_faiss_pred`` is true, otherwise an equivalent NumPy computation.
    """

    batch_size = 2048

    def __init__(self, num_clusters: int,
                 using_faiss_pred: bool = False,
                 *args, **kwargs):
        """
        :param num_clusters: number of k-means centroids (codebook size)
        :param using_faiss_pred: assign vectors to centroids through a faiss
            ``IndexFlatL2`` instead of the NumPy distance computation
        """
        super().__init__(*args, **kwargs)
        self.num_clusters = num_clusters
        self.using_faiss_pred = using_faiss_pred
        self.centroids = None
        # Derived from the centroids for the NumPy assignment path:
        # per-centroid squared norms, shape (1, k), and centroids^T, (d, k).
        self.centroids_l2 = None
        self.centroids_trans = None
        self.index_flat = None

    def _set_centroids(self, centroids) -> None:
        """Store ``centroids`` and rebuild every cached value derived from them."""
        self.centroids = centroids
        if centroids is None:
            # Untrained: nothing to derive, and no index can be built.
            self.centroids_l2 = None
            self.centroids_trans = None
            self.index_flat = None
            return
        self.centroids_l2 = np.sum(centroids ** 2, axis=1).reshape([1, -1])
        self.centroids_trans = np.transpose(centroids)
        if self.using_faiss_pred:
            self.faiss_index()
        else:
            self.index_flat = None

    def kmeans_train(self, vecs: np.ndarray) -> None:
        """Fit ``num_clusters`` centroids to the rows of ``vecs``.

        :param vecs: 2-D array with one training vector per row
        """
        import faiss
        # faiss only accepts C-contiguous float32 input.
        vecs = np.ascontiguousarray(vecs, dtype=np.float32)
        kmeans = faiss.Kmeans(vecs.shape[1], self.num_clusters, niter=5, verbose=False)
        kmeans.train(vecs)
        self._set_centroids(kmeans.centroids)

    def faiss_index(self) -> None:
        """Build an exact (flat) L2 faiss index over the current centroids."""
        import faiss
        centroids = np.ascontiguousarray(self.centroids, dtype=np.float32)
        self.index_flat = faiss.IndexFlatL2(centroids.shape[1])
        self.index_flat.add(centroids)

    def kmeans_pred(self, vecs: np.ndarray) -> np.ndarray:
        """Return the index of the nearest centroid for each row of ``vecs``.

        :param vecs: 2-D array with one vector per row
        :return: 1-D integer array of centroid indices, one per row
        """
        if self.using_faiss_pred:
            query = np.ascontiguousarray(vecs, dtype=np.float32)
            _, pred = self.index_flat.search(query, 1)
            return np.reshape(pred, [-1])
        # Squared L2 distance |v|^2 + |c|^2 - 2 v.c, shape (n, k).
        vecs_l2 = np.sum(vecs ** 2, axis=1).reshape([-1, 1])
        dist = vecs_l2 + self.centroids_l2 - 2 * np.matmul(vecs, self.centroids_trans)
        # The nearest centroid is the one with the SMALLEST distance; this
        # must agree with the faiss path above.
        return np.argmin(dist, axis=-1).reshape([-1]).astype(np.int32)

    @batching
    def train(self, vecs: np.ndarray, *args, **kwargs) -> None:
        """Learn the k-means codebook from ``vecs``.

        All leading axes are flattened, so every vector along the last axis
        is one training sample.

        NOTE(review): ``@batching`` presumably splits the input into batches
        of ``batch_size``; if so, each batch retrains k-means from scratch
        and only the last batch's centroids survive. Confirm intended.
        """
        vecs = vecs.reshape([-1, vecs.shape[-1]])
        assert len(vecs) > self.num_clusters, 'number of data should be larger than number of clusters'
        self.kmeans_train(vecs)

    @train_required
    @batching
    def encode(self, vecs: np.ndarray, *args, **kwargs) -> np.ndarray:
        """Encode each chunk of vectors into one L2-normalised VLAD vector.

        :param vecs: iterable of chunks, each a 2-D array of frame vectors
            (presumably shaped ``(num_chunks, num_frames, dim)``; confirm
            against callers)
        :return: float32 array of shape ``(num_chunks, num_clusters * dim)``
        """
        num_centroids, dim = self.centroids.shape[0], self.centroids.shape[1]
        output = []
        for chunk in vecs:
            chunk = np.asarray(chunk)
            res = np.zeros((num_centroids, dim))
            if len(chunk):
                labels = self.kmeans_pred(chunk)
                # Unbuffered scatter-add: frames sharing a centroid all
                # accumulate, in order, into that centroid's row.
                np.add.at(res, labels, chunk - self.centroids[labels])
            res = res.reshape([-1])
            norm = np.sum(res ** 2) ** 0.5
            # An all-zero residual (e.g. empty chunk) would divide by zero and
            # produce NaNs; leave it as the zero vector instead.
            output.append(res / norm if norm > 0 else res)
        return np.array(output, dtype=np.float32)

    def _copy_from(self, x: 'VladEncoder') -> None:
        """Adopt the configuration and trained centroids of another encoder."""
        self.num_clusters = x.num_clusters
        # Must be set before _set_centroids, which decides whether to build
        # a faiss index from it.
        self.using_faiss_pred = x.using_faiss_pred
        self._set_centroids(x.centroids)

    def __setstate__(self, state):
        super().__setstate__(state)
        # The faiss index is never pickled (see __getstate__); rebuild it
        # from the restored centroids when they exist.
        self.index_flat = None
        if self.using_faiss_pred and self.centroids is not None:
            self.faiss_index()

    def __getstate__(self):
        # Copy so that dropping the index never touches the live instance,
        # whatever object the base class hands back.
        state = dict(super().__getstate__())
        state.pop('index_flat', None)
        return state