\(\newcommand{\abs}[1]{\lvert#1\rvert}\) \(\newcommand{\norm}[1]{\lVert#1\rVert}\) \(\newcommand{\innerproduct}[2]{\langle#1, #2\rangle}\) \(\newcommand{\Tr}[1]{\operatorname{Tr}\mleft(#1\mright)}\) \(\DeclareMathOperator*{\argmin}{argmin}\) \(\DeclareMathOperator*{\argmax}{argmax}\) \(\DeclareMathOperator{\diag}{diag}\) \(\newcommand{\converge}[1]{\xrightarrow{\makebox[2em][c]{\)\scriptstyle#1\(}}}\) \(\newcommand{\quotes}[1]{``#1''}\) \(\newcommand\ddfrac[2]{\frac{\displaystyle #1}{\displaystyle #2}}\) \(\newcommand{\vect}[1]{\boldsymbol{\mathbf{#1}}}\) \(\newcommand{\E}{\mathbb{E}}\) \(\newcommand{\Var}{\mathrm{Var}}\) \(\newcommand{\Cov}{\mathrm{Cov}}\) \(\renewcommand{\N}{\mathbb{N}}\) \(\renewcommand{\Z}{\mathbb{Z}}\) \(\renewcommand{\R}{\mathbb{R}}\) \(\newcommand{\Q}{\mathbb{Q}}\) \(\newcommand{\C}{\mathbb{C}}\) \(\newcommand{\bbP}{\mathbb{P}}\) \(\newcommand{\rmF}{\mathrm{F}}\) \(\newcommand{\iid}{\mathrm{iid}}\) \(\newcommand{\distas}[1]{\overset{#1}{\sim}}\) \(\newcommand{\cA}{\mathcal{A}}\) \(\newcommand{\cB}{\mathcal{B}}\) \(\newcommand{\cC}{\mathcal{C}}\) \(\newcommand{\cD}{\mathcal{D}}\) \(\newcommand{\cE}{\mathcal{E}}\) \(\newcommand{\cF}{\mathcal{F}}\) \(\newcommand{\cG}{\mathcal{G}}\) \(\newcommand{\cH}{\mathcal{H}}\) \(\newcommand{\cI}{\mathcal{I}}\) \(\newcommand{\cJ}{\mathcal{J}}\) \(\newcommand{\cL}{\mathcal{L}}\) \(\newcommand{\cM}{\mathcal{M}}\) \(\newcommand{\cP}{\mathcal{P}}\) \(\newcommand{\cO}{\mathcal{O}}\) \(\newcommand{\cQ}{\mathcal{Q}}\) \(\newcommand{\cU}{\mathcal{U}}\) \(\newcommand{\cV}{\mathcal{V}}\) \(\newcommand{\cN}{\mathcal{N}}\) \(\newcommand{\cT}{\mathcal{T}}\) \(\newcommand{\cX}{\mathcal{X}}\) \(\newcommand{\cY}{\mathcal{Y}}\) \(\newcommand{\cZ}{\mathcal{Z}}\) \(\newcommand{\cS}{\mathcal{S}}\) \(\newcommand{\shorteqnote}[1]{ & \textcolor{blue}{\text{\small #1}}}\) \(\newcommand{\qimplies}{\quad\Longrightarrow\quad}\) \(\newcommand{\defeq}{\stackrel{\triangle}{=}}\) \(\newcommand{\longdefeq}{\stackrel{\text{def}}{=}}\) \(\newcommand{\equivto}{\iff}\)
In this blog post, we’ll explore the Stochastic Expectation-Maximization (SEM) algorithm and its application to Gaussian Mixture Models (GMMs). Building upon the traditional EM algorithm, SEM introduces stochasticity into the estimation process, offering potential advantages in terms of escaping local optima and handling large datasets. We’ll carefully examine the mathematical derivations and provide a concrete example of GMMs using SEM.
The Expectation-Maximization (EM) algorithm is a powerful tool for finding maximum likelihood estimates in models with latent variables, such as Gaussian Mixture Models (GMMs). However, EM can sometimes get trapped in local maxima, especially in complex or high-dimensional spaces.
Moreover, note that the function $Q_t(\Theta) := E_{Z \mid X; \Theta_t}\left[ \log p(X, Z; \Theta) \right]$ which appears in the E-step is simply the conditional expectation of the complete-data log-likelihood given the observed data $X$, computed as if the true parameter value were $\Theta_t$. This gives some probabilistic insight into the $Q$ function. Being a conditional expectation, $Q_t(\Theta)$ is, for every $\Theta$, an estimate of the complete-data log-likelihood built only on the information contained in the incomplete data and on the current guess $\Theta_t$. In some sense, it is the “best” estimate we can make without knowing $Z$, because the conditional expectation is, by definition, the estimator that minimizes the conditional mean squared error:
\[Q_t(\Theta) = \argmin_{\mu} \int \left[\log p(X, Z; \Theta) -\mu \right]^2 P(Z \mid X; \Theta_t) dZ\]
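To make that last claim explicit, one can differentiate the quadratic objective with respect to the constant $\mu$ and set the derivative to zero; the minimizer is exactly the conditional expectation:
\[\frac{\partial}{\partial \mu} \int \left[\log p(X, Z; \Theta) - \mu \right]^2 P(Z \mid X; \Theta_t)\, dZ = -2 \int \left[\log p(X, Z; \Theta) - \mu \right] P(Z \mid X; \Theta_t)\, dZ = 0 \qimplies \mu = E_{Z \mid X; \Theta_t}\left[ \log p(X, Z; \Theta) \right] = Q_t(\Theta).\]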
The Stochastic EM (SEM) algorithm introduces randomness into the E-step, which can help the algorithm explore the parameter space more thoroughly and potentially escape local optima. Instead of computing expected values over the latent variables, SEM samples them according to their posterior distributions given the current parameter estimates.
Key Differences Between EM and SEM: in the E-step, EM keeps the full posterior $P(Z \mid X; \Theta_t)$ and works with soft responsibilities, whereas SEM replaces the E-step with an S-step that draws a single sample of $Z$ from this posterior, so each data point receives a hard component assignment; the M-step then maximizes the completed-data log-likelihood exactly as in standard EM.
By introducing stochasticity, SEM can offer advantages in terms of convergence properties and computational efficiency, particularly for large datasets.
Let’s consider a GMM with $K$ Gaussian components. Our goal is to estimate the parameters $\Theta = \{ \pi_k, \vect{\mu}_k, \vect{\Sigma}_k \}_{k=1}^K$, where $\pi_k$ are the mixing weights (with $\pi_k \ge 0$ and $\sum_{k=1}^K \pi_k = 1$), $\vect{\mu}_k$ are the component means, and $\vect{\Sigma}_k$ are the component covariance matrices.
Given observed data $X = \{\vect{x}_1, \dots, \vect{x}_n\}$, we introduce latent variables $Z = \{z_1, \dots, z_n\}$, where $z_i \in \{1, \dots, K\}$ indicates the component assignment for $\vect{x}_i$.
Our objective is to maximize the complete-data log-likelihood:
\[\log p(X, Z \mid \Theta) = \sum_{i=1}^n \log \left( \pi_{z_i} \mathcal{N}(\vect{x}_i \mid \vect{\mu}_{z_i}, \vect{\Sigma}_{z_i}) \right)\]
In SEM, we alternate between sampling the latent variables $Z$ and updating the parameters $\Theta$.
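As a concrete reference for the code sketches later in this post, here is a minimal NumPy/SciPy helper that evaluates this complete-data log-likelihood for a given assignment vector. The function name `complete_data_loglik` and the argument layout (`X` an $(n, d)$ array, `z` an integer vector, and `pis`, `mus`, `Sigmas` arrays of shapes $(K,)$, $(K, d)$, $(K, d, d)$) are naming assumptions of this sketch, not part of any library:

```python
import numpy as np
from scipy.stats import multivariate_normal

def complete_data_loglik(X, z, pis, mus, Sigmas):
    """Complete-data log-likelihood: sum_i log( pi_{z_i} * N(x_i | mu_{z_i}, Sigma_{z_i}) )."""
    total = 0.0
    for k in range(len(pis)):
        Xk = X[z == k]                     # points assigned to component k
        if len(Xk) > 0:
            total += len(Xk) * np.log(pis[k])
            total += multivariate_normal.logpdf(Xk, mean=mus[k], cov=Sigmas[k]).sum()
    return total
```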
S-Step: For each data point $\vect{x}_i$, sample $z_i$ from the posterior distribution:
\[P(z_i = k \mid \vect{x}_i; \Theta_t) = \gamma_{t, i, k} = \frac{\pi_{t, k} \mathcal{N}(\vect{x}_i \mid \vect{\mu}_{t, k}, \vect{\Sigma}_{t, k})}{\sum_{j=1}^K \pi_{t, j} \mathcal{N}(\vect{x}_i \mid \vect{\mu}_{t, j}, \vect{\Sigma}_{t, j})}\]
Rather than computing the expected value over $Z$, we sample each $z_i$ according to $\gamma_{t, i, k}$.
M-Step: Given the sampled $Z$, we maximize the complete-data log-likelihood with respect to $\Theta$:
\[\Theta_{t+1} = \arg\max_{\Theta} \sum_{i=1}^n \log \left( \pi_{z_i} \mathcal{N}(\vect{x}_i \mid \vect{\mu}_{z_i}, \vect{\Sigma}_{z_i}) \right)\]
We can derive the update equations for $\pi_k$, $\vect{\mu}_k$, and $\vect{\Sigma}_k$.
For each data point $\vect{x}_i$:
Compute Responsibilities:
\[\gamma_{t, i, k} = P(z_i = k \mid \vect{x}_i; \Theta_t) = \frac{\pi_{t, k} \mathcal{N}(\vect{x}_i \mid \vect{\mu}_{t, k}, \vect{\Sigma}_{t, k})}{\sum_{j=1}^K \pi_{t, j} \mathcal{N}(\vect{x}_i \mid \vect{\mu}_{t, j}, \vect{\Sigma}_{t, j})}\]
Sample $z_i$: draw $z_i$ from the categorical distribution over $\{1, \dots, K\}$ with probabilities $(\gamma_{t, i, 1}, \dots, \gamma_{t, i, K})$.
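As an illustration, here is a hedged sketch of this S-step, using the same imports and array conventions as the previous snippet; the helper name `s_step` and the explicit random generator argument are my own choices:

```python
def s_step(X, pis, mus, Sigmas, rng):
    """Compute responsibilities gamma and sample one component index z_i per data point."""
    n, K = X.shape[0], len(pis)
    gamma = np.zeros((n, K))
    for k in range(K):
        # Unnormalized posterior: pi_k * N(x_i | mu_k, Sigma_k)
        gamma[:, k] = pis[k] * multivariate_normal.pdf(X, mean=mus[k], cov=Sigmas[k])
    gamma /= gamma.sum(axis=1, keepdims=True)          # normalize over components
    # S-step: draw z_i ~ Categorical(gamma[i, :]) instead of keeping the soft weights
    z = np.array([rng.choice(K, p=gamma[i]) for i in range(n)])
    return gamma, z
```

Here `rng` is expected to be a `numpy.random.Generator`, e.g. `np.random.default_rng(0)`.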
Given the sampled $ Z $, we update the parameters.
The part of the complete-data log-likelihood that depends on $\pi$ is:
\[L(\pi) = \sum_{i=1}^n \log \pi_{z_i}\]
We maximize it subject to the constraint $\sum_{k=1}^K \pi_k = 1$, using the method of Lagrange multipliers.
Formulate Lagrangian:
\[\mathcal{L}(\pi, \lambda) = \sum_{i=1}^n \log \pi_{z_i} + \lambda \left( \sum_{k=1}^K \pi_k - 1 \right)\]
Compute Gradient and Set to Zero:
For each $\pi_k$:
\[\frac{\partial \mathcal{L}}{\partial \pi_k} = \sum_{i: z_i = k} \frac{1}{\pi_k} + \lambda = 0\]
Solve for $\pi_k$:
\[\pi_k = -\frac{N_k}{\lambda}\]
where $N_k = \sum_{i=1}^n \delta(z_i = k)$ is the number of data points assigned to component $k$.
Apply Constraint:
\[\sum_{k=1}^K \pi_k = -\frac{1}{\lambda} \sum_{k=1}^K N_k = 1 \implies \lambda = -n\]
Final Update:
\[\pi_k = \frac{N_k}{n}\]
For the means $\vect{\mu}_k$, we need to maximize:
\[L(\vect{\mu}_k) = \sum_{i: z_i = k} \log \mathcal{N}(\vect{x}_i \mid \vect{\mu}_k, \vect{\Sigma}_k)\]
This is equivalent to minimizing:
\[\sum_{i: z_i = k} (\vect{x}_i - \vect{\mu}_k)^T \vect{\Sigma}_k^{-1} (\vect{x}_i - \vect{\mu}_k)\]
Setting the derivative w.r.t. $\vect{\mu}_k$ to zero:
\[\frac{\partial L}{\partial \vect{\mu}_k} = \sum_{i: z_i = k} \vect{\Sigma}_k^{-1} (\vect{x}_i - \vect{\mu}_k) = 0\]
Solving for $\vect{\mu}_k$:
\[\vect{\mu}_k = \frac{1}{N_k} \sum_{i: z_i = k} \vect{x}_i\]
For the covariances $\vect{\Sigma}_k$, we maximize:
\[L(\vect{\Sigma}_k) = \sum_{i: z_i = k} \log \mathcal{N}(\vect{x}_i \mid \vect{\mu}_k, \vect{\Sigma}_k)\]
This involves minimizing:
\[\sum_{i: z_i = k} \left[ \log \det \vect{\Sigma}_k + (\vect{x}_i - \vect{\mu}_k)^T \vect{\Sigma}_k^{-1} (\vect{x}_i - \vect{\mu}_k) \right]\]
Setting the derivative w.r.t. $\vect{\Sigma}_k$ to zero:
Compute Gradient: using the standard matrix-calculus identities $\frac{\partial}{\partial \vect{\Sigma}} \log \det \vect{\Sigma} = \vect{\Sigma}^{-1}$ and $\frac{\partial}{\partial \vect{\Sigma}} \left[ (\vect{x} - \vect{\mu})^T \vect{\Sigma}^{-1} (\vect{x} - \vect{\mu}) \right] = -\vect{\Sigma}^{-1} (\vect{x} - \vect{\mu})(\vect{x} - \vect{\mu})^T \vect{\Sigma}^{-1}$ (treating $\vect{\Sigma}$ as unconstrained), the gradient of the log-likelihood is
\[\frac{\partial L}{\partial \vect{\Sigma}_k} = -\frac{N_k}{2} \vect{\Sigma}_k^{-1} + \frac{1}{2} \vect{\Sigma}_k^{-1} \left( \sum_{i: z_i = k} (\vect{x}_i - \vect{\mu}_k)(\vect{x}_i - \vect{\mu}_k)^T \right) \vect{\Sigma}_k^{-1} = 0\]
Solve for $\vect{\Sigma}_k$:
\[\vect{\Sigma}_k = \frac{1}{N_k} \sum_{i: z_i = k} (\vect{x}_i - \vect{\mu}_k)(\vect{x}_i - \vect{\mu}_k)^T\]
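These three updates translate directly into code. Below is a minimal sketch under the same naming assumptions as the earlier snippets; the small ridge term `eps` added to each covariance and the fallback for empty components are my own safeguards, not part of the derivation:

```python
def m_step(X, z, K, eps=1e-6):
    """Update (pi_k, mu_k, Sigma_k) from the sampled hard assignments z."""
    n, d = X.shape
    pis, mus, Sigmas = np.zeros(K), np.zeros((K, d)), np.zeros((K, d, d))
    for k in range(K):
        Xk = X[z == k]                    # points sampled into component k
        Nk = len(Xk)
        pis[k] = Nk / n                   # pi_k = N_k / n
        if Nk == 0:                       # empty component: fall back to global statistics
            mus[k] = X.mean(axis=0)
            Sigmas[k] = np.cov(X.T) + eps * np.eye(d)
            continue
        mus[k] = Xk.mean(axis=0)          # mu_k = (1/N_k) * sum_{i: z_i = k} x_i
        diff = Xk - mus[k]
        Sigmas[k] = diff.T @ diff / Nk + eps * np.eye(d)   # Sigma_k plus a small ridge
    return pis, mus, Sigmas
```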
The SEM algorithm for GMMs can be summarized as follows:
Initialize $\Theta_0 = \{ \pi_{0, k}, \vect{\mu}_{0, k}, \vect{\Sigma}_{0, k} \}_{k=1}^K$.
For $t = 0, 1, 2, \dots$ until convergence:
S-Step: compute the responsibilities $\gamma_{t, i, k}$ under $\Theta_t$ and sample each $z_i$ from them.
M-Step: update $\pi_k$, $\vect{\mu}_k$, and $\vect{\Sigma}_k$ from the sampled assignments using the formulas above, giving $\Theta_{t+1}$.
Check for Convergence: because of the sampling, the parameters and the log-likelihood fluctuate from iteration to iteration, so in practice one stops when the observed-data log-likelihood roughly stabilizes or after a fixed number of iterations. A runnable sketch of this loop is given below.
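Putting the pieces together, here is a hedged end-to-end sketch that reuses the hypothetical `s_step` and `m_step` helpers defined earlier; the initialization scheme (uniform weights, means at randomly chosen data points, the global covariance for each component) and the stopping rule on the observed-data log-likelihood are simple choices assumed for illustration:

```python
def observed_loglik(X, pis, mus, Sigmas):
    """Observed-data log-likelihood: sum_i log sum_k pi_k * N(x_i | mu_k, Sigma_k)."""
    dens = np.column_stack([
        pis[k] * multivariate_normal.pdf(X, mean=mus[k], cov=Sigmas[k])
        for k in range(len(pis))
    ])
    return np.log(dens.sum(axis=1)).sum()

def sem_gmm(X, K, n_iter=200, tol=1e-4, seed=0):
    """Stochastic EM for a K-component Gaussian mixture on an (n, d) data matrix X."""
    rng = np.random.default_rng(seed)
    n, d = X.shape
    pis = np.full(K, 1.0 / K)                            # uniform initial weights
    mus = X[rng.choice(n, size=K, replace=False)]        # means at random data points
    Sigmas = np.array([np.cov(X.T) + 1e-6 * np.eye(d) for _ in range(K)])
    prev_ll = -np.inf
    for t in range(n_iter):
        gamma, z = s_step(X, pis, mus, Sigmas, rng)      # S-step: sample assignments
        pis, mus, Sigmas = m_step(X, z, K)               # M-step: refit the parameters
        ll = observed_loglik(X, pis, mus, Sigmas)
        # The log-likelihood fluctuates because of the sampling; stop once it stabilizes.
        if abs(ll - prev_ll) < tol * max(1.0, abs(prev_ll)):
            break
        prev_ll = ll
    return pis, mus, Sigmas, gamma
```

Calling `pis, mus, Sigmas, gamma = sem_gmm(X, K=3)` on an $(n, d)$ data array then returns the final mixture weights, means, covariances, and the responsibilities from the last iteration.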
The Stochastic EM algorithm offers an alternative to the traditional EM algorithm by incorporating randomness into the estimation process. This stochasticity can help the algorithm avoid local maxima and explore the parameter space more effectively.
Summary of SEM Algorithm for GMMs: initialize $\Theta_0$; at each iteration, compute the responsibilities $\gamma_{t, i, k}$, sample a component assignment $z_i$ for every data point (S-step), and re-estimate $\pi_k$, $\vect{\mu}_k$, and $\vect{\Sigma}_k$ from those assignments (M-step); repeat until the convergence criterion is met.
Advantages of SEM: the injected randomness helps the algorithm escape local maxima and explore the parameter space more thoroughly, and working with hard assignments keeps each iteration simple and cheap, which is attractive for large datasets.
Considerations: because the latent variables are sampled, the parameter estimates fluctuate from one iteration to the next rather than converging monotonically, so the final estimate is often taken as an average over late iterations or as the iterate with the highest observed-data log-likelihood, and results depend on the random seed.
In conclusion, SEM provides a valuable tool for mixture model estimation, particularly in complex scenarios where traditional EM may struggle. By carefully implementing and analyzing SEM, we can leverage its strengths to achieve robust clustering and parameter estimation.