Summary by n0mad 3 years ago
**Motivation**
A possible approach to industrialise deep and shallow embeddings in various retrieval tasks is to decouple the representations of documents (video clips/items in recommendations) and queries (users in recsys). Under this scheme, the degree of matching ("relevance") $r$ of the document with representation $x$ and the query with representation $q$ is calculated as a softmax function:
$$
r := \frac{e^{q^Tx}}{\sum_i e^{q^Tx_i}}
$$
At learning time, the two representations are independent until the last softmax layer. This allows a convenient run-time scheme: the document representations are pre-calculated and stored in some form of index, and the query-document relevance is calculated online, at query time. Since, for a fixed query, the relevance is monotone in the dot product $q^Tx$, the problem of efficient ranking boils down to efficiently finding the documents/items with high values of this dot product.
Such an approach is used in the papers describing the YouTube [1] and Yahoo [2] systems.
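The monotonicity argument above can be checked numerically. The following is a minimal sketch on synthetic data; the array names, sizes and random inputs are assumptions made for illustration and do not come from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 16))   # hypothetical document representations x_i
q = rng.normal(size=16)           # hypothetical query representation

scores = X @ q                    # dot products q^T x_i

# Softmax relevance over all documents (shifted by the max for numerical stability).
relevance = np.exp(scores - scores.max())
relevance /= relevance.sum()

# Ranking by raw dot product and ranking by softmax relevance agree.
top_by_dot = np.argsort(-scores)[:10]
top_by_relevance = np.argsort(-relevance)[:10]
print(np.array_equal(top_by_dot, top_by_relevance))
```

Because the denominator is shared by all documents for a given query, it never needs to be evaluated to produce the ranking.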
This leads to the following **problem statement** (informal): given a dataset of items $\mathbb{X}$ and a sample of queries $\mathbb{Q}$, how can one compress the dataset $\mathbb{X}$ so that dot products between the compressed items and an uncompressed query $q$ can be calculated with as little distortion as possible? The paper also considers squared Euclidean distance queries.
Formally, we are looking for a compressed representation $\hat x$ of each $x$ that minimises the error:
(in the case of dot-product)
$$
L_{dot} = \sum_{x \in \mathbb{X}, q \in \mathbb{Q}} \left(q^Tx - q^T \hat x \right)^T\left(q^Tx - q^T \hat x \right)
$$
(in the case of squared Euclidean distance)
$$
L_{Eucl} = \sum_{x \in \mathbb{X}, q \in \mathbb{Q}} \left((q - x)^2 - (q - \hat x)^2 \right)^2
$$
At the centre of the paper is the observation that these two *pairwise* loss functions can be reduced to per-point distortion losses in a modified space. In the case of $L_{dot}$:
$$
L_{dot} = \sum_{x \in \mathbb{X}, q \in \mathbb{Q}} \left(q^Tx - q^T \hat x \right)^T\left(q^Tx - q^T \hat x \right) = \sum_{x \in \mathbb{X}}(x - \hat x)^T \left( \sum_{q \in \mathbb{Q}} q q^T \right)(x - \hat x)
$$
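The middle step can be spelled out in the same notation. Each summand is a scalar, so
$$
\left(q^Tx - q^T \hat x \right)^2 = \left(q^T(x - \hat x)\right)^2 = (x - \hat x)^T q q^T (x - \hat x),
$$
and since $(x - \hat x)$ does not depend on $q$, the sum over $q \in \mathbb{Q}$ can be moved inside the quadratic form.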
Since $\left( \sum_{q \in \mathbb{Q}} q q^T \right)$ is a positive semi-definite matrix, we can factorise it as $C^TC$ and plug it back:
$$
L_{dot} =\sum_{x \in \mathbb{X}}(Cx - C\hat x)^T(Cx - C\hat x)=\sum_{z \in \mathbb{Z}}(z - \hat z)^T(z - \hat z)
$$
where $\mathbb{Z} = \{z = C x | x \in \mathbb{X}\}$ is a dataset obtained by modifying $\mathbb{X}$ by applying matrix $C$. That reduces the problem of minimising the distortion of dot-product estimates to the problem of reducing the distortion of individual points in $\mathbb{Z}$.
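The reduction can be verified numerically. The sketch below uses synthetic data and obtains $C$ from a Cholesky factorisation; that choice, the noisy `X_hat` standing in for a compressed dataset, and all names and sizes are assumptions made for illustration, not details taken from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8
X = rng.normal(size=(200, d))                 # items, one per row (synthetic)
Q = rng.normal(size=(500, d))                 # sampled queries, one per row (synthetic)
X_hat = X + 0.1 * rng.normal(size=X.shape)    # stand-in for some compressed version of X

# Pairwise loss: sum over all (x, q) of (q^T x - q^T x_hat)^2.
pairwise = ((X @ Q.T - X_hat @ Q.T) ** 2).sum()

# sum_q q q^T, factorised as C^T C via Cholesky (needs a full-rank query sample).
M = Q.T @ Q
C = np.linalg.cholesky(M).T                   # cholesky returns lower-triangular L with M = L L^T

# Per-point loss in the transformed space z = C x.
Z, Z_hat = X @ C.T, X_hat @ C.T
per_point = ((Z - Z_hat) ** 2).sum()

print(np.isclose(pairwise, per_point))
```

Any other factor $C$ with $C^TC = \sum_q qq^T$ (for example, one built from an eigendecomposition) would give the same per-point loss.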
Assuming that we have an efficient way to compress and store individual points when minimising squared distance distortion $(z - \hat z)^2$, we can turn $\mathbb{X}$ into $\mathbb{Z}$, then compress & index $\mathbb{Z}$.
Luckily, methods such as Product Quantisation, Optimised Product Quantisation and Additive Quantisation can perform the required per-point compression $z \rightarrow \hat z$ while minimising this per-point loss.
How can one use the compressed $\mathbb{Z}$ at run-time? The query itself must be modified: $q \rightarrow r = (C^{-1})^Tq$, which requires $C$ to be invertible, i.e. $\sum_{q \in \mathbb{Q}} q q^T$ to have full rank (here $r$ denotes the transformed query, not the relevance defined earlier). Then an estimate of $r^T z$ is also an estimate of $q^Tx$:
$$
r^T z = q^TC^{-1}z = q^TC^{-1} \cdot C x = q^T x
$$
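The run-time scheme can be sketched end to end as follows. This is only an illustration under stated assumptions: the data is synthetic, $C$ comes from a Cholesky factorisation, and a uniform rounding quantiser with a hypothetical `step` stands in for PQ/OPQ, which the paper actually uses.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8
X = rng.normal(size=(200, d))          # items (synthetic)
Q = rng.normal(size=(500, d))          # sampled queries (synthetic)

C = np.linalg.cholesky(Q.T @ Q).T      # sum_q q q^T = C^T C
Z = X @ C.T                            # index time: z = C x

step = 0.5                             # hypothetical quantiser resolution
Z_hat = np.round(Z / step) * step      # toy stand-in for PQ/OPQ: z -> z_hat

q = rng.normal(size=d)                 # a new query arriving at run-time
r = np.linalg.solve(C.T, q)            # r = (C^{-1})^T q, without forming the inverse

exact = X @ q                          # q^T x
estimate = Z_hat @ r                   # r^T z_hat, computed from compressed data only

print(np.allclose(Z @ r, exact))       # the identity r^T z = q^T x holds before compression
print(np.abs(estimate - exact).max())  # distortion introduced by the toy quantiser
```

Solving the triangular system for each query avoids computing $C^{-1}$ explicitly; with a real product quantiser, `Z_hat @ r` would be replaced by table look-ups over codewords.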
A similar reduction is performed for the pairwise squared Euclidean distance loss.
Next, the authors demonstrate that the obtained estimates of the distance and the dot product are unbiased when the underlying compression in $\mathbb{Z}$-space is performed by Optimised Product Quantisation (OPQ).
The paper concludes with several experiments demonstrating that Pairwise Quantisation recovers the dot product and the squared Euclidean distance better than OPQ.
**Overall** the approach seems practical and elegant, and it addresses an important problem. One possible issue, which the authors also mention, is that the optimisation is performed w.r.t. *all* pairs of sampled queries and datapoints. In practical applications, one only needs to accurately recover the dot product/distance to the closest vectors. For instance, the correct ranking of the documents at positions 100-200 does not matter, while positions 1-10 are extremely important.
[1] https://static.googleusercontent.com/media/research.google.com/ru//pubs/archive/45530.pdf
[2] http://www.kdd.org/kdd2016/papers/files/adf0361-yinA.pdf
