Imagine we're tasked with developing a model that groups data points effectively based on their similarities. Traditionally, the approach involves supervised learning: we train the model on pairs or triplets drawn from a labeled dataset, teaching it to position similar data points closer together in the embedding space it produces. The underlying hope is that, through exposure to a representative variety of pairs or triplets, the model will learn to discern similarities and thus cluster data points effectively.
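To make the traditional setup concrete, here is a minimal sketch of a standard triplet objective (this is an illustration of the general technique, not code from our paper; the function name and margin value are our own choices):

```python
import numpy as np

def triplet_loss(anchor, positive, negative, margin=1.0):
    """Hinge loss that pushes the anchor's embedding closer to the
    positive example than to the negative one, by at least `margin`."""
    d_pos = np.linalg.norm(anchor - positive)   # distance to a similar point
    d_neg = np.linalg.norm(anchor - negative)   # distance to a dissimilar point
    return max(0.0, d_pos - d_neg + margin)
```

Training minimizes this loss over many sampled triplets, so the geometry of the embedding space is shaped only indirectly, one local comparison at a time.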
However, if clustering is our end task, one might wonder why we don't optimize the model directly against a clustering metric, such as the Rand index. Sounds straightforward, but here's the catch: clustering is a discrete process, which means clustering evaluation metrics are also discrete functions and therefore non-differentiable. Gradient descent cannot be effectively guided by such functions, which is a stumbling block for training models like Transformers directly with a loss based on the Rand index.
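To see why such a metric is non-differentiable, consider the Rand index itself: it counts, over all pairs of points, how often two clusterings agree. A minimal implementation (our own illustrative sketch, not the paper's code) makes the discreteness obvious:

```python
from itertools import combinations

def rand_index(labels_a, labels_b):
    """Fraction of point pairs on which two clusterings agree:
    grouped together in both, or separated in both."""
    pairs = list(combinations(range(len(labels_a)), 2))
    agree = sum(
        (labels_a[i] == labels_a[j]) == (labels_b[i] == labels_b[j])
        for i, j in pairs
    )
    return agree / len(pairs)
```

Because the value only changes when a point jumps from one cluster to another, the metric is piecewise constant in the model parameters: its gradient is zero almost everywhere and undefined at the jumps, so it gives gradient descent nothing to follow.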
Recent advancements by Vlastelica et al. have introduced a promising workaround. They've demonstrated a technique that generates meaningful gradients from discrete loss functions, a method they term 'blackbox backpropagation'. This opens up an intriguing possibility: could we leverage this technique to train a deep representation model specifically for clustering, guided directly by a clustering metric? Would this approach yield superior clustering results? And importantly, how does its scalability compare with traditional pair- or triplet-based training methods?
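The core of Vlastelica et al.'s idea can be sketched in a few lines: run the discrete solver once on the model's continuous output, perturb that output in the direction of the incoming gradient, run the solver again, and use the difference of the two discrete solutions as a surrogate gradient. The sketch below uses a toy argmax "solver" purely for illustration; in the clustering setting the solver would instead be a clustering algorithm over the embeddings, and `LAMBDA` is a tunable hyperparameter from their method:

```python
import numpy as np

LAMBDA = 10.0  # interpolation strength (hyperparameter of the method)

def solver(w):
    """Toy discrete solver: one-hot vector of the argmax.
    Stands in for any blackbox combinatorial solver."""
    y = np.zeros_like(w)
    y[np.argmax(w)] = 1.0
    return y

def blackbox_backward(w, grad_y, lam=LAMBDA):
    """Surrogate gradient of the loss w.r.t. the solver's input:
    re-solve on a perturbed input and difference the two discrete outputs."""
    y = solver(w)
    y_perturbed = solver(w + lam * grad_y)
    return -(y - y_perturbed) / lam
```

When the perturbation is too small to flip the solver's discrete decision, the surrogate gradient is zero; once it does flip, the finite difference points the continuous input toward the better discrete solution, which is exactly what lets gradients flow through an otherwise piecewise-constant step.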
These are the central questions we explore in our work, "Learn The Big Picture: Representation Learning for Clustering," where we delve into the potential of direct optimization for clustering tasks.
Paper: Learn The Big Picture: Representation Learning for Clustering
GitHub repo: Blackbox_clustering
If this work caught your attention, maybe you would also like to know how we leveraged clustering tasks to improve dense retrieval models. You can find the project overview here and the related paper here.