Meet Flash-KMeans: An IO-Aware, Exact K-Means That Runs Over 200× Faster Than FAISS on GPUs
k-means has been an offline tool for decades. You run it once to preprocess data, then move on. A team of researchers from UC Berkeley and UT Austin released Flash-KMeans, a new open-source library that targets a different setting. Modern AI pipelines now call k-means inside training and inference loops. At that frequency, latency per call matters more than theoretical FLOPs.
Flash-KMeans is an IO-aware implementation of standard Lloyd’s k-means. It does not change the math, and it does not approximate. It only restructures how the algorithm moves data on a GPU. On an NVIDIA H200, the research team reported up to 17.9× end-to-end speedup over the best baseline. Against NVIDIA cuML they report 33×. Against FAISS they report over 200×.
What is Flash-KMeans
Flash-KMeans is a batched k-means library written in Triton GPU kernels. It ships under Apache 2.0 and installs with pip install flash-kmeans.
The output is mathematically identical to standard Lloyd’s k-means. The speedup comes from kernel-level dataflow, not from skipping work. That separates it from algorithmic methods like triangle-inequality pruning or coreset sampling.
A standard Lloyd iteration has two stages. The assignment stage computes each point’s distance to every centroid, then picks the nearest. The update stage averages the points in each cluster to form new centroids. Both stages are simple arithmetic. On GPUs, both are bottlenecked by memory, not compute.
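The two stages can be sketched in a few lines of plain Python. This is an illustrative CPU version, not the library's code: the function name `lloyd_iteration` is hypothetical, and keeping the old centroid when a cluster is empty is an assumption made only so the sketch is total.

```python
def lloyd_iteration(points, centroids):
    """One Lloyd iteration: assign each point, then recompute centroids."""
    k, dim = len(centroids), len(centroids[0])

    # Assignment stage: nearest centroid by squared Euclidean distance.
    labels = []
    for p in points:
        dists = [sum((a - b) ** 2 for a, b in zip(p, c)) for c in centroids]
        labels.append(dists.index(min(dists)))

    # Update stage: average the points assigned to each cluster.
    sums = [[0.0] * dim for _ in range(k)]
    counts = [0] * k
    for p, j in zip(points, labels):
        counts[j] += 1
        for t in range(dim):
            sums[j][t] += p[t]

    # Assumption for this sketch: an empty cluster keeps its previous centroid.
    new_centroids = [
        [s / counts[j] for s in sums[j]] if counts[j] else list(centroids[j])
        for j in range(k)
    ]
    return labels, new_centroids


points = [(0.0, 0.0), (0.0, 1.0), (10.0, 10.0), (10.0, 11.0)]
labels, centroids = lloyd_iteration(points, [(0.0, 0.0), (10.0, 10.0)])
```

Neither loop does more than a subtract, a multiply, and an add per element, which is why the cost on a GPU comes down to how the data moves rather than how much arithmetic there is.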
The Two Bottlenecks It Attacks
The first bottleneck is the assignment stage. Standard code builds a full distance matrix D of shape N×K in High Bandwidth Memory (HBM). It writes the matrix, then reads it back to run argmin. For N=65536, K=1024, d=128, B=32, the distance math takes 2.6ms. Writing and consuming D takes about 23ms. The matrix is the cost, not the arithmetic.
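A rough size calculation shows why the matrix dominates. The sketch below assumes that D is stored in FP16 (2 bytes per entry) and that each of the B batch items has its own N points and K centroids; both are assumptions for illustration, not figures from the paper.

```python
BYTES_FP16 = 2  # assumption: distances stored in FP16
N, K, d, B = 65536, 1024, 128, 32

dist_matrix_bytes = B * N * K * BYTES_FP16  # the materialized N x K matrix, per batch item
points_bytes = B * N * d * BYTES_FP16       # the input points X
centroids_bytes = B * K * d * BYTES_FP16    # the centroids C

GIB = 1024 ** 3
print(f"distance matrix D: {dist_matrix_bytes / GIB:.2f} GiB (written, then read back)")
print(f"points X:          {points_bytes / GIB:.2f} GiB")
print(f"centroids C:       {centroids_bytes / GIB:.4f} GiB")
```

Under these assumptions D is 4 GiB, eight times larger than the 0.5 GiB of input points, and it crosses HBM twice: once on the write and once on the read for argmin.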
Flash-KMeans replaces this with FlashAssign, a design that borrows from FlashAttention. FlashAssign streams tiles of points and centroids from HBM into on-chip SRAM and fuses distance computation with an online argmin, so the full N×K matrix is never materialized. This cuts the dominant IO complexity from O(NK) to O(Nd + Kd). At the kernel level, FlashAssign reaches up to a 21.2× speedup. In one case it cut assignment time from 122.5ms to 5.8ms.
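The online argmin idea can be illustrated on the CPU. The sketch below is not the Triton kernel: the function names and the `tile_k` parameter are hypothetical, and a Python list slice stands in for a tile held in SRAM. It only shows that a running (minimum, index) pair per point gives the same answer as building the whole matrix first.

```python
def sq_dist(p, c):
    return sum((a - b) ** 2 for a, b in zip(p, c))


def assign_naive(points, centroids):
    """Baseline: build the full N x K distance matrix, then argmin per row."""
    dist = [[sq_dist(p, c) for c in centroids] for p in points]
    return [row.index(min(row)) for row in dist]


def assign_streaming(points, centroids, tile_k=2):
    """Visit centroids tile by tile, keeping only a running (min, argmin) per point."""
    best_dist = [float("inf")] * len(points)
    best_idx = [-1] * len(points)
    for start in range(0, len(centroids), tile_k):
        tile = centroids[start:start + tile_k]  # stands in for a tile held on-chip
        for i, p in enumerate(points):
            for offset, c in enumerate(tile):
                dist = sq_dist(p, c)
                if dist < best_dist[i]:         # strict '<' keeps the first minimum on ties
                    best_dist[i] = dist
                    best_idx[i] = start + offset
    return best_idx


points = [(0.0, 0.0), (5.0, 5.0), (9.0, 1.0), (2.0, 8.0)]
centroids = [(0.0, 1.0), (6.0, 6.0), (10.0, 0.0), (1.0, 9.0), (4.0, 4.0)]
streamed = assign_streaming(points, centroids)
assert streamed == assign_naive(points, centroids)
```

The streaming version stores two values per point regardless of K, which is the property that lets the real kernel skip the N×K write and read.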
The second bottleneck is the centroid update stage. Standard code uses scatter-style atomic adds. Each thread adds its point into a shared sum buffer keyed by cluster id. Many threads hit the same ‘hot’ cluster at once. That causes atomic contention and hardware serialization. The research team measured only 50 GB/s effective bandwidth here on an H200.
Flash-KMeans replaces this with Sort-Inverse Update. It sorts the 1D assignment vector by cluster id using argsort, so identical cluster ids form contiguous segments. Each thread block reduces a segment on-chip, then issues one atomic add per segment. The heavy point matrix is never physically permuted. Atomic operations drop from one per point to one per segment. The kernel reaches up to a 6.3× speedup.
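A CPU sketch makes the difference in shared-buffer writes concrete. The function names here are hypothetical, a local list stands in for the on-chip partial sum, and the sketch uses exactly one segment per cluster; a real kernel would split long segments across thread blocks, so its merge count would be higher than K.

```python
def scatter_update(points, labels, k):
    """Baseline: one accumulation into the shared buffer per point."""
    dim = len(points[0])
    sums = [[0.0] * dim for _ in range(k)]
    counts = [0] * k
    for p, j in zip(points, labels):
        counts[j] += 1                      # stands in for one atomic add per point
        for t in range(dim):
            sums[j][t] += p[t]
    return sums, counts, len(points)


def sort_inverse_update(points, labels, k):
    """Sort point indices by cluster id, reduce each segment locally, merge once."""
    dim = len(points[0])
    order = sorted(range(len(labels)), key=labels.__getitem__)  # argsort of the labels
    sums = [[0.0] * dim for _ in range(k)]
    counts = [0] * k
    merges = 0
    start = 0
    while start < len(order):
        cluster = labels[order[start]]
        end = start
        local = [0.0] * dim                 # stands in for an on-chip partial sum
        while end < len(order) and labels[order[end]] == cluster:
            p = points[order[end]]          # gather via the index; points stay in place
            for t in range(dim):
                local[t] += p[t]
            end += 1
        for t in range(dim):
            sums[cluster][t] += local[t]    # one merge into the shared buffer per segment
        counts[cluster] += end - start
        merges += 1
        start = end
    return sums, counts, merges


points = [(1.0, 2.0), (3.0, 4.0), (5.0, 6.0), (7.0, 8.0), (9.0, 10.0), (11.0, 12.0)]
labels = [1, 0, 1, 1, 0, 1]
scatter_sums, scatter_counts, scatter_ops = scatter_update(points, labels, 2)
sorted_sums, sorted_counts, sorted_ops = sort_inverse_update(points, labels, 2)
```

Both functions produce the same sums and counts. Only the number of writes to the shared buffer changes, and only the small index vector is sorted, never the point data.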
Benchmark
The research team tested it on an H200 with CUDA 12.8, FP16 data, and d=128. They swept N, K, and batch size B, and compared against four optimized baselines: fast_pytorch_kmeans, fastkmeans, cuML, and FAISS.
| Comparison | Reported speedup | Workload context |
|---|---|---|
| End-to-end vs best baseline | up to 17.9× | N=8M, K=1024 (large N, small K) |
| vs NVIDIA cuML | 33× | industry library |
| vs FAISS | over 200× | industry library |
| FlashAssign kernel | up to 21.2× | N=1M, K=8192 (assignment) |
| Sort-Inverse Update kernel | up to 6.3× | N=33M, K=4096 (update) |
| Out-of-core, large scale | up to 10.5× | N=400M, K=16384 vs fastkmeans |
One failure mode matters for context. Standard PyTorch implementations run out of memory in large-K regimes because they cannot materialize the N×K matrix. FAISS is the industry-standard library under many production vector-search systems.
The library also runs out-of-core. On one billion points (K=32768, d=128), it finishes an iteration in 41.4s, against 261.8s for the baseline. It uses chunked stream overlap to hide PCIe transfer behind compute. A cache-aware compile heuristic also cuts tuning overhead by up to 175×, within 0.3% of tuned speed.
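The overlap pattern can be sketched with a background loader thread. This is only an analogy under stated assumptions: `load_chunk`, `process_chunk`, and `chunked_pass` are hypothetical names, a list slice stands in for a PCIe copy, and on a GPU the prefetch would be an asynchronous copy on a separate stream rather than a Python thread.

```python
from concurrent.futures import ThreadPoolExecutor


def load_chunk(data, start, size):
    """Stand-in for a host-to-device copy of one chunk (hypothetical helper)."""
    return data[start:start + size]


def process_chunk(chunk, centroids, sums, counts):
    """Stand-in for the compute step: assign each point, accumulate cluster sums."""
    for p in chunk:
        dists = [sum((a - b) ** 2 for a, b in zip(p, c)) for c in centroids]
        j = dists.index(min(dists))
        counts[j] += 1
        for t, value in enumerate(p):
            sums[j][t] += value


def chunked_pass(data, centroids, chunk_size):
    """Process data chunk by chunk, prefetching chunk i+1 while chunk i is computed."""
    sums = [[0.0] * len(centroids[0]) for _ in centroids]
    counts = [0] * len(centroids)
    starts = list(range(0, len(data), chunk_size))
    if not starts:
        return sums, counts
    with ThreadPoolExecutor(max_workers=1) as loader:
        pending = loader.submit(load_chunk, data, starts[0], chunk_size)
        for nxt in starts[1:] + [None]:
            chunk = pending.result()        # wait for the current chunk to arrive
            if nxt is not None:             # start the next transfer before computing
                pending = loader.submit(load_chunk, data, nxt, chunk_size)
            process_chunk(chunk, centroids, sums, counts)
    return sums, counts


data = [(float(i), float(i % 3)) for i in range(10)]
centroids = [(0.0, 0.0), (9.0, 0.0)]
chunked_sums, chunked_counts = chunked_pass(data, centroids, chunk_size=3)
```

Because only per-cluster sums and counts are carried between chunks, the full dataset never has to be resident at once, and the result matches a single-chunk pass.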
[Interactive explainer: "Flash-KMeans: exact k-means, rebuilt around GPU memory". Panels: live Lloyd's k-means on 2-D points; standard scatter-update contention versus Sort-Inverse Update; relative HBM round-trips for the N×K distance matrix versus FlashAssign, FP16.]
Speedups: Flash-KMeans paper (arXiv:2603.09229), NVIDIA H200. Repository: github.com/svg-project/flash-kmeans
Use Cases
Faster exact k-means changes what you can run online, not just offline.
- Vector search indexing: FAISS builds its search indices with k-means. Faster k-means lets you re-index as data shifts, instead of rebuilding overnight.
- Sparse attention routing: Routing Transformers and Tactic cluster tokens to route attention. Millisecond k-means makes this viable inside the inference loop.
- KV-cache compression: ClusterKV clusters tokens in semantic space to compress the cache. Cheaper clustering makes per-layer, per-step compression practical.
- Low-bit KV quantization: Recent methods cluster KV entries into codebooks, repeatedly. Faster clustering shrinks that preprocessing cost.
- Diffusion Transformers: Sparse VideoGen2 calls batched k-means during forward passes. It permutes tokens by semantic similarity to exploit sparsity.
Using It
The API mirrors faiss and sklearn. The call below clusters a batched (B, N, d) tensor.
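import torch
from flash_kmeans import batch_kmeans_Euclid
# Batched input of shape (B, N, d): 32 problems, 75600 points each, 128 features, FP16 on the GPU.
x = torch.randn(32, 75600, 128, device="cuda", dtype=torch.float16)
# Returns per-point cluster ids and the centroids; the third return value is unused here.
cluster_ids, centers, _ = batch_kmeans_Euclid(
    x, n_clusters=1000, tol=1e-4, verbose=True
)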
A scikit-learn-style interface is also available.
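from flash_kmeans import FlashKMeans
# d is the feature dimension, k the number of clusters, niter the number of iterations.
km = FlashKMeans(d=128, k=8192, niter=100)
# large_cpu_tensor is a placeholder for point data held in CPU memory.
labels = km.fit_predict(large_cpu_tensor)  # device=None uses all visible GPUs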
The kernel auto-dispatches by shape and dtype. A small-D path handles d≤512. A split-D path handles larger d without materializing the distance matrix. Multi-GPU runs trigger automatically for large-N data held in CPU memory.
Key Takeaways
- Flash-KMeans is exact, not approximate — same Lloyd’s math, sped up purely by GPU dataflow.
- FlashAssign fuses distance + online argmin, cutting assignment IO from O(NK) to O(Nd+Kd) — up to 21.2×.
- Sort-Inverse Update sorts cluster IDs into segments, replacing scatter atomics — up to 6.3×.
- Reports up to 17.9× end-to-end, 33× over cuML, and over 200× over FAISS on an H200.
- Scales out-of-core to one billion points and cuts tuning overhead up to 175×.
The post Meet Flash-KMeans: An IO-Aware, Exact K-Means That Runs Over 200× Faster Than FAISS on GPUs appeared first on MarkTechPost.
