[NOTES] Expectation Maximization
This is an extremely long and mathematically rigorous essay on the Expectation-Maximization (EM) algorithm, from the lens of a Computer Science graduate who has not taken a math-heavy course in a while.
Some of the best long-form essays, be they in text or video form, have one common denominator: They come from individuals who have developed a deep obsession with the subject of the essay. Obsession has positive connotations in this context. However, I'm not obsessed with EM and it holds no special place in my heart (as you will see soon). It's just that I have set a 4-hour timer and am forcing myself to learn as much as I can about EM. Maybe a more general "proxy" for the quality of an essay is "hours spent meditating on the subject" (readers will be the judge). I'm sure that trained mathematicians will view this as blasphemy.
I
I'm not a big fan of the unnecessary amount of formalism that some math textbooks use. In fact, one of my controversial beliefs is that some (only some) mathematicians deliberately gate-keep knowledge by finding the most unidiomatic formalism for stating a simple property. I'm especially annoyed by this gate-keeping because you will rarely find it happening in software engineering. For example, the following code would get me brutal PR reviews and maybe a call from my manager:
```cpp
#include <bitset>

// Gatekept version of checking if a number is even
bool isEvenUnidiomatically(int number) {
    std::bitset<32> binaryRepresentation(number);
    return binaryRepresentation[0] == 0;  // idiomatic version: return number % 2 == 0;
}
```
Anyway, let's get into notations and get past the gatekeeping.
Background
EM is an optimization algorithm and is:
- Iterative
- Motivated by a notion of missing-ness
- Built around the conditional distribution of what is missing given what is observed
- Able to reliably find a (local) optimum through stable, uphill steps
Notation
- $Y$, $X$, $Z$: Y is the complete data, X is the observed data, Z is the latent (missing) data. This may seem silly to write, but note that Y is NOT a function of X and Z. Y is also not X + Z.
- X = M(Y): this is a "many-to-fewer" mapping. Remember that X is a "part of" Y. Don't confuse this with the X we use in traditional ML techniques, where X represents the training dataset.
- $f_X(x|\theta)$: density function of the observed data. Read as: "density of X given theta". It is given by
$$ f_X(x|\theta) = \int_{\{y \,:\, M(y) = x\}} f_Y(y|\theta)\, dy $$
- Read the above equation as: the density of the observed data X is the integral of the density of the full data Y given theta. This effectively sums (or integrates) the probabilities of all the different ways in which the observed data could occur, according to the model.
- $f_Y(y|\theta)$: density function of Y (the complete data). Read as "density of Y given theta"
- $f_{Z|X}(z|x, \theta)$: conditional density of the missing data given the observed data
- $\theta$: the parameter that we wanna estimate (aka free parameters/coefficients)
Marginalization Model
Marginalization refers to the process of summing or integrating out one or more variables from a joint distribution to be able to study the distribution of the remaining variables.
- Read as "probability density of x given theta is given as integral of probability density of y given theta over all possible ways in which x can occur"
- OR "the density of observed data X is integral of density of full data Y given theta. This effectively sums (or integrates) the probabilities of all the different ways in which the observed data x could occur, according to the model."
If you see definition blocks like these, try to familiarize yourself with the terms to better grasp the math that will come.
"Frequentist Setting"
Frequentist Setting refers to an interpretation of probability AS A LONG-RUNNING REPEATED EXPERIMENT. Ex: when we claim that the chance of a coin toss outcome is 50% heads/tails, in the frequentist setting we are saying "If you toss a coin, measure the outcome, and repeat this process for a very long time, you will find the probability of either side occurring to be 50% (based on the frequency)".
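You can simulate that claim directly (a toy sketch; the 0.5 is the assumed true bias):

```python
import random

# Estimate P(heads) as a long-run frequency over many simulated tosses
flips = [random.random() < 0.5 for _ in range(1_000_000)]
print(sum(flips) / len(flips))  # ≈ 0.5
```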
Optimization
Finding roots or finding some minima/maxima.
Density function
Think of "density" as just another term for "area under the function's curve"
Likelihood and Log-Likelihood ($L(\theta | x)$)
Likelihood is the probability of seeing some observed data given a parameter estimate, viewed as a function of the parameter. Optimization is all about finding the roots or the maxima/minima of a complicated likelihood function. Log-likelihood is just a log-transform of a likelihood function; it often makes calculations easier. You will see the likelihood denoted as $L(\theta | x)$ and its log as $l(\theta | x)$.
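As a toy illustration (the data here is made up, and the exponential density $\theta e^{-\theta x}$ is borrowed from Example 1 below), we can compare the log-likelihood of the same data under two candidate values of $\theta$:

```python
import math

data = [1.2, 0.5, 2.3]  # made-up observations

# Log-likelihood of i.i.d. Exp(theta) data: each point contributes log(theta) - theta * x
def log_lik(theta, xs):
    return sum(math.log(theta) - theta * x for x in xs)

# The theta with the higher log-likelihood explains the data better
print(log_lik(0.5, data), log_lik(1.0, data))
```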
EM Algorithm
EM is an optimization algorithm, as stated above. It is used for estimating the parameters of probabilistic models. Typically, we deal with a (log-)likelihood function and want to find the set of parameters that are most likely to have generated the sample data.
Suppose we want to maximize $l(\theta|x)$ with respect to $\theta$. We define the following function:
$$ Q(\theta|\theta^{(t)}) = E\left[ \log f_Y(Y|\theta) \,\middle|\, x, \theta^{(t)} \right] $$
Read as: "Expectation of the joint log-likelihood of the complete data, conditional on the observed data X = x". Can also be written as:
$$ Q(\theta|\theta^{(t)}) = \int \log f_Y(y|\theta)\, f_{Z|X}(z|x, \theta^{(t)})\, dz $$
If the missing data were discrete, we'd get a summation instead of an integral.
Now the EM algorithm boils down to the following two steps:
- E-Step: Compute $Q(\theta|\theta^{(t)})$ (the expectation of the joint log-likelihood of the complete data, given the parameter estimate $\theta^{(t)}$ at iteration t)
- M-Step: Maximize $Q(\theta|\theta^{(t)})$ with respect to $\theta$. Set $\theta^{(t+1)}$ (theta at iteration t+1) to the maximizer of $Q$. A generic sketch of this loop follows.
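In code, the whole algorithm is just a loop (a minimal sketch for a scalar parameter; `e_step` and `m_step` are hypothetical placeholders for whatever a concrete model defines, as in the examples later):

```python
# Generic EM loop: alternate E and M steps until the estimate stabilizes
def em(theta, e_step, m_step, max_iter=100, tol=1e-9):
    for _ in range(max_iter):
        Q = e_step(theta)                 # E step: build Q(. | theta_t)
        new_theta = m_step(Q)             # M step: maximize Q over theta
        if abs(new_theta - theta) < tol:  # convergence check
            return new_theta
        theta = new_theta
    return theta
```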
At this point you will ask, what the fuck is E? And how the fuck did we go from E to an integral? Yes, it looks intimidating if you have never had a background in statistics, but I will explain what it means with an example, because it's so silly.
II
"When you have mastered numbers, you will in fact no longer be reading numbers, any more than you read words when reading books. You will be reading meanings." (Harold Geneen, "Managing")
If it were up to me, I'd reduce the formal definition of the EM algorithm to like three lines and avoid all the math jargon, cuz it did less than 1% of the work to build my intuition. Let's just dive into the examples and get dirty.
Example 1.
This is probably going to be the only detailed example. The rest of the examples will assume that the reader has understood and internalised these steps.
$$ Y_1, Y_2 \overset{\text{i.i.d.}}{\sim} \text{Exp}(\theta), \qquad \text{observed: } y_1 = 5, \qquad \text{missing: } y_2 $$
Let's break this down. Read the above statement as "Y1 and Y2 are independently and identically distributed (i.i.d.) and follow (the ~ symbol) an exponential distribution with parameter theta". We get to see $y_1 = 5$, but $y_2$ went missing.
Expectation E[X]
MEAN! AVERAGE! or more formally: "Expectation, also known as the expected value or the mean. It's a measure of the central tendency of a random variable, providing a weighted average of all possible values that the random variable can take on, with the weights being the probabilities of those values. 🥰"
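For a fair six-sided die, for instance:
$$ E[X] = \sum_{x=1}^{6} x \cdot \frac{1}{6} = \frac{1 + 2 + 3 + 4 + 5 + 6}{6} = 3.5 $$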
Now the first step is to get the probability density function of an exponential distribution. The probability density function (PDF) for an exponential distribution is given by:
$$ f(x|\theta) = \theta e^{-\theta x}, \quad x \geq 0 $$
and its mean is
$$ E(X) = \frac{1}{\theta} $$
For $n$ i.i.d. samples, the likelihood is the product of the individual densities:
$$ L(\theta | x_1, x_2, \ldots, x_n) = \prod_{i=1}^{n} f(x_i|\theta) $$
For the exponential distribution this becomes:
$$ L(\theta | x_1, x_2, \ldots, x_n) = \prod_{i=1}^{n} \theta e^{-\theta x_i} = \theta^n e^{-\theta \sum_{i=1}^{n} x_i} $$
Taking the log:
$$ \ln L(\theta | x_1, x_2, \ldots, x_n) = n \ln(\theta) - \theta \sum_{i=1}^{n} x_i $$
For our complete data ($n = 2$):
$$ \log{L(\theta | y)} = \log{f_Y(y|\theta)} = 2\log{\theta} - \theta y_1 - \theta y_2 $$
For the E step, the missing $y_2$ gets replaced by its conditional expectation. Since $Y_2$ is independent of the observed $Y_1$, $E[Y_2 | x, \theta^{(t)}] = 1/\theta^{(t)}$, so:
$$ Q(\theta|\theta^{(t)}) = 2\log{\theta} - 5\theta - \frac{\theta}{\theta^{(t)}} $$
Ok what about the M step?
The maximizer of $Q(\theta|\theta^{(t)})$ is the root of $\frac{dQ}{d\theta} = \frac{2}{\theta} - 5 - \frac{1}{\theta^{(t)}} = 0$. Thus:
$$ \theta^{(t+1)} = \frac{2\theta^{(t)}}{5\theta^{(t)} + 1} $$
This converges quickly to $\hat{\theta} = 0.2$.
The code to implement the above thing looks as simple as follows:
```python
import math

# Q(theta | theta_t): the E step has already replaced the missing y2
# with its conditional expectation 1/theta_t
Q = lambda y1, theta, theta_t: 2 * math.log(theta) - y1 * theta - theta / theta_t
update = lambda theta: 2 * theta / (5 * theta + 1)  # M step: closed-form maximizer
tol = 2e-09
theta = 5.0  # initial guess
for i in range(100):
    new_theta = update(theta)
    err = (new_theta - theta) ** 2
    theta = new_theta
    if err <= tol:  # convergence check: stop once the updates stall
        break
print(theta, Q(5, theta, theta))  # theta converges to 0.2
```
We must note that since this is a maximization problem, with each iteration the value of $Q$ (and of the observed-data log-likelihood) increases! Woohoo!
Example 2.
Okay, the language from here on will get extremely terse, no bullshit, focusing only on the most important findings in EM. Examples from this section onwards will be harder to keep up with, but they will build a more general understanding of how EM works and how it can be applied.
This example is covered in Computational Statistics by G. H. Givens and J. A. Hoeting (colostate.edu).
Wing color is determined by a single gene with three possible alleles, which we denote C, I, and T. C is dominant to I, and T is recessive to I. Thus genotypes CC, CI, and CT result in the carbonaria phenotype, having solid black coloring. The TT genotype results in the typica phenotype, with light-colored wings. The II and IT genotypes produce an intermediate insularia phenotype, with mottled wings.
Key points:
Genotypes are expressed as a combination of two alleles. They cannot be directly observed.
The expression of a genotype is called a phenotype. A phenotype is the physical trait that can be directly observed in nature.
Carbonaria is dominant to insularia. Insularia is dominant to typica.
Hardy-Weinberg equilibrium states that if the two alleles match, the genotype frequency is the square of the allele frequency. If the two alleles are different, the genotype frequency is 2 times the product of the individual allele frequencies.
- Allele frequencies: $p_C$, $p_I$, $p_T$, with $p_C + p_I + p_T = 1$
- Genotype frequencies (by Hardy-Weinberg): $p_C^2$, $2p_Cp_I$, $2p_Cp_T$, $p_I^2$, $2p_Ip_T$, $p_T^2$ for CC, CI, CT, II, IT, TT respectively

How is the problem modelled?
Y: genotype counts of the data, $y = (n_{CC}, n_{CI}, n_{CT}, n_{II}, n_{IT}, n_{TT})$
X: phenotype counts of the data, $x = (n_C, n_I, n_T)$ (things we have observed in nature)
Notice how we haven't defined a model for Z. Here you do not need to model Z separately: the missing information is just the part of Y that X doesn't pin down, so figuring out a mapping between X and Y is enough for the EM algorithm to work.
The said mapping is given as follows: $$ x = M(y) = (n_{CC} + n_{CI} + n_{CT},\; n_{II} + n_{IT},\; n_{TT}) $$
Here the mapping simply refers to how corresponding terms are related. For example, $n_C = n_{CC} + n_{CI} + n_{CT}$: every genotype containing a dominant C allele is observed as the carbonaria phenotype.
The log-likelihood is multinomial and is given as follows: $$ \log {f_Y(y|\mathbf{p})} = n_{CC} \log {p_C^2} + n_{CI} \log {2 p_C p_I} + n_{CT} \log {2 p_C p_T} + n_{II} \log {p_I^2} + n_{IT} \log {2 p_I p_T} + n_{TT} \log {p_T^2} + \log \binom{n}{n_{CC}\; n_{CI}\; n_{CT}\; n_{II}\; n_{IT}\; n_{TT}} $$
One must now ask, what are the parameters in this case?
- Parameters are "free", i.e. they are those components of the log-likelihood that can be adjusted. OR: parameters are all those values that ARE NOT observed values.
- In the log-likelihood function you will notice that we are given neither $\mathbf{p} = (p_C, p_I, p_T)$ nor the genotype counts $y$. Does that mean we estimate both $\mathbf{p}$ and $y$? Well no! Read the next part.
- This is where we exploit the Hardy-Weinberg principle/equilibrium and relate our allele frequencies to genotype frequencies. We express the expectation of each count in $y$ as a function of $\mathbf{p}$.
Now let's understand the Q function for this. The Q function performs what is called optimization transfer, i.e. it converts the original log-likelihood into what is called a "minorizing function". Mathematically it is expressed as follows: $$ \large l(\theta|x) \geq Q(\theta|\theta^{(t)}) + l(\theta^{(t)}|x) - Q(\theta^{(t)}|\theta^{(t)}) = G(\theta|\theta^{(t)}) $$
Note that our Q function is going to be essentially the same as our Log-likelihood function but y is replaced with its corresponding expectation! The Q function for the above problem looks as follows: $$ Q(p|p^{(t)}) = n_{CC}^{(t)} \log{p_C^{2}} + n_{CI}^{(t)} \log{2p_Cp_I} + n_{CT}^{(t)} \log{2p_Cp_T} + n_{II}^{(t)} \log{p_I^{2}} + n_{IT}^{(t)} \log{2p_Ip_T} + n_{TT} \log{p_T^{2}} + k(n_C, n_I, n_T, p^{(t)}) $$
Well, this is the exact same as the log-likelihood function, right? No:
- The (t) at the top represents the estimated value at iteration t
- Recall that the Q function is expressed as $Q(\mathbf{p}|\mathbf{p}^{(t)}) = E[\log f_Y(y|\mathbf{p}) \mid x, \mathbf{p}^{(t)}]$. The expectation (E) replaces the main input of the function with its expected value (read: average value).
- Notice how the Q function is NOT a function of the unobserved y anymore! It's only a function of the parameters!
- Ok then how do we know the expected value of Y?
Computing these expectations is the key learning for the E step.
- As an example, the expectation for $n_{CC}$ is given as follows: $$ n_{CC}^{(t)} = E[N_{CC} \mid n_C, n_I, n_T, \mathbf{p}^{(t)}] = \frac{n_C\, (p_C^{(t)})^2}{(p_C^{(t)})^2 + 2p_C^{(t)}p_I^{(t)} + 2p_C^{(t)}p_T^{(t)}} $$
- How did we arrive at this? Well, think of $n_C$ as the total number of moths with the carbonaria phenotype. The remaining part is simply a proportion. A rule of thumb for calculating the proportion: put the genotype frequency in question in the numerator, and in the denominator put the sum of all the genotype frequencies that produce the carbonaria phenotype.
- Let's take another example. How would $n_{IT}^{(t)}$ look?
$$ n_{IT}^{(t)} = E[N_{IT} \mid n_C, n_I, n_T, \mathbf{p}^{(t)}] = \frac{2\, n_I\, p_I^{(t)} p_T^{(t)}}{(p_I^{(t)})^2 + 2 p_I^{(t)} p_T^{(t)}} $$
- Notice how the denominator only contains the sum of the genotype frequencies that produce the insularia phenotype (II and IT).
- One final example to seal the deal: assume you now have an additional count, call it $n_U$, of moths whose phenotype is only known to be insularia or typica (we couldn't tell which). How would $n_{IT}^{(t)}$ look now?
- Answer: We start by writing a new mapping for Y and X, with the observed data now $x = (n_C, n_I, n_T, n_U)$. The expectation picks up a second term for the new count: $$ n_{IT}^{(t)} = \frac{2\, n_I\, p_I^{(t)} p_T^{(t)}}{(p_I^{(t)})^2 + 2 p_I^{(t)} p_T^{(t)}} + \frac{2\, n_U\, p_I^{(t)} p_T^{(t)}}{(p_I^{(t)})^2 + 2 p_I^{(t)} p_T^{(t)} + (p_T^{(t)})^2} $$
- Notice how for the second term, the proportion is estimated by including all the genotype frequencies where insularia and typica alleles occur (II, IT, and TT).
- How do I get the M step? Well, you simply take the derivative of the Q function w.r.t. each parameter in $\mathbf{p}$ and then solve for the general solution of $\mathbf{p}$. One simple trick for writing the update step is to just write it like this: $$ p_C^{(t+1)} = \frac{2n_{CC}^{(t)} + n_{CI}^{(t)} + n_{CT}^{(t)}}{2n} $$
- Well, this is trivial to see, because each allele frequency HAS to be the ratio of the expected number of copies of that allele (summed over the corresponding genotypes) to the total number of alleles, $2n$. Sometimes, instead of doing the full derivation, just use common sense. The M step is essentially your opportunity to update your parameter estimates; the updated params are then used to compute the next Q function. A runnable sketch of the whole thing follows.
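Here is a minimal Python sketch of the basic moth EM (without the $n_U$ extension). The phenotype counts, initial guess, and tolerance are illustrative placeholders:

```python
import numpy as np

nC, nI, nT = 85, 196, 341        # illustrative phenotype counts
n = nC + nI + nT

p = np.array([1/3, 1/3, 1/3])    # initial guess for (pC, pI, pT)

for _ in range(100):
    pC, pI, pT = p
    # E step: expected genotype counts given current allele frequencies
    denC = pC**2 + 2*pC*pI + 2*pC*pT
    nCC, nCI, nCT = nC * pC**2 / denC, nC * 2*pC*pI / denC, nC * 2*pC*pT / denC
    denI = pI**2 + 2*pI*pT
    nII, nIT = nI * pI**2 / denI, nI * 2*pI*pT / denI
    nTT = nT                      # fully observed, no expectation needed
    # M step: allele frequency = expected allele count / total alleles (2n)
    new_p = np.array([
        (2*nCC + nCI + nCT) / (2*n),
        (2*nII + nIT + nCI) / (2*n),
        (2*nTT + nIT + nCT) / (2*n),
    ])
    if np.max(np.abs(new_p - p)) < 1e-9:
        break
    p = new_p

print(p)  # estimated (pC, pI, pT)
```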
III
Terence Tao views mathematics as climbing peaks, trying to reach various goals. "Some peaks are just out of your reach", he adds. I'm fond of this take, as it implies that there exists a terrain of mathematics (perhaps too grand for the human brain to comprehend). We are simply blind travellers on this vast and largely hidden super-structure. "Intuition" can be thought of as a viewport into this super-structure. Sadly, because humanity does not YET have a hive-mind, every new generation has to reconstruct this viewport over the course of several years. What hurts even more is that every person's viewport spawns at some random co-ordinates on this super-structure; someone might be cursed with a viewport that shows flat terrain till the end of the horizon (people who have poor math intuition), while someone else might have a viewport for a region that gives them a more zoomed-out or larger view of this super-structure. In my case, I was cursed with a viewport onto flat terrain; however, because of the brown-family peer pressure, I was forced to take higher-level math. In just a couple of months, I had "an intelligence explosion" and I could suddenly "see" more of math.
I've been collecting a bunch of tricks to effectively expand my viewport. Before we get dirtier with EM math, I'll share one controversial trick. This trick has to do with the "sensation of understanding". It feels weird to type it out, but I experience the moment of "I understand X" almost as a clear sensation, similar to taste or touch. And one interesting personal revelation is that the sensation of understanding is strongest when I'm able to successfully place a concept in my mind relative to other concepts. We'll come back to this, but let me quickly tell you about the time I was trying to understand a complicated math proof and was unable to tell how the author went from step 1 to step 2. Because I was running late for an exam, I just made up a shitty + vulgar mnemonic to remember the steps. Then the most interesting thing happened when I tried to study another proof a while later. I was able to follow the steps of this new, unseen proof using the same shitty + vulgar mnemonic as a proxy for the real underlying intuition. It gave a sensation very similar to that of understanding the move from step 1 to step 2.
So my biggest takeaway was this: if you don't understand how some part of the math works, just make up vulgar assumptions (vulgar is easier to recall). There is an occlusion in your viewport and you gotta move on. I promise your vulgar assumption eventually gets replaced with the real intuition (if you spend enough time playing around with the math). Now, for this next part, don't focus too much on reasoning or intuitive understanding. Just follow along and get used to it.
EM in Exponential Families
Sufficient Statistics
All the necessary statistics/info about the sample data that is needed for inference about the probability density that generated the sample data. Often denoted by $s(y)$.
Steps:
E Step: Compute the expected values of the sufficient statistics for the complete data: $$ \large s^{(t)} = E[s(Y) \mid x, \theta^{(t)}] = \int s(y) f_{Z|X}(z \mid x, \theta^{(t)})\, dz $$
M step: Now, with the newly acquired $s^{(t)}$, solve for $\theta^{(t+1)}$ by finding the root of $E[s(Y) \mid \theta] = s^{(t)}$. Notice how it's the un-conditional expectation (no dependency on x) of the sufficient statistic.
Run the above two steps till convergence. A concrete sketch follows.
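To make this concrete, here is a minimal sketch under assumptions I'm adding for illustration: i.i.d. $\text{Exp}(\theta)$ data where $m$ values went missing completely at random, with sufficient statistic $s(y) = \sum_i y_i$:

```python
# EM via sufficient statistics for Exp(theta) with m missing observations
observed = [2.1, 0.7, 3.3, 1.9]   # made-up observed values
m = 2                              # number of missing observations
n = len(observed) + m

theta = 1.0                        # initial guess
for _ in range(100):
    # E step: expected sufficient statistic; each missing Y has mean 1/theta
    s_t = sum(observed) + m / theta
    # M step: solve E[s(Y) | theta] = n / theta = s_t for theta
    new_theta = n / s_t
    if abs(new_theta - theta) < 1e-12:
        break
    theta = new_theta

print(theta)  # matches the MLE from the observed values alone
```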
Missing Info Principle
We start with the following:
$$ Q(\theta|\theta^{(t)}) = l(\theta|x) + H(\theta|\theta^{(t)}) $$
where $H(\theta|\theta^{(t)}) = E[\log f_{Z|X}(Z|x, \theta) \mid x, \theta^{(t)}]$. Read as: "H is the expected value of the log-likelihood of the hidden/missing data, given the parameter estimates at time t".
This intuitively makes sense because if Q function is the expectation of the full data then subtracting the expectation of missing data from it will simply give the expectation of the observed data.
Information
In stats, information or Fisher information is an attribute of a random variable. It reveals how much information the random variable has about a parameter. It is given as the negative of Expectation of the second order partial derivative of a log-likelihood function 😭 OR: $$ \Large I(\theta) = -E\left[ \frac{\partial^2}{\partial \theta^2} \log L(\theta) \right] $$
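As a quick sanity check, take the exponential log-likelihood from Example 1 with $n$ samples:
$$ l(\theta) = n \log \theta - \theta \sum_{i=1}^{n} x_i, \qquad \frac{\partial^2 l}{\partial \theta^2} = -\frac{n}{\theta^2}, \qquad I(\theta) = -E\left[-\frac{n}{\theta^2}\right] = \frac{n}{\theta^2} $$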
Now, taking the information of $l(\theta|x) = Q(\theta|\theta^{(t)}) - H(\theta|\theta^{(t)})$, we get:
$$ -l''(\theta|x) = -Q''(\theta|\theta^{(t)}) + H''(\theta|\theta^{(t)}) $$
Finally, the missing info principle is:
$$ \hat{i}_X(\theta) = \hat{i}_Y(\theta) - \hat{i}_{Z|X}(\theta) $$
i.e. the observed information equals the complete information minus the missing information.
Supplemented EM (SEM) algorithm
This is used to find the variance of the parameter estimates.
EM Mapping
Simply the process that takes the parameters from iteration t to iteration t+1, i.e. the E and the M step together. Denoted by $\psi$.
Just remember this: $$ \theta^{(t+1)} = \psi(\theta^{(t)}) $$
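For instance, in Example 1 the E and M steps collapsed into a single update formula, so there the EM mapping is simply:
$$ \psi(\theta) = \frac{2\theta}{5\theta + 1} $$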
IV
The previous sections cover enough basics about EM. This part covers some hard-core EM application examples. This is an exercise in modelling: taking real-world problems and casting them into concrete models.
Example 3.
You are given a dataset of points $x_1, \ldots, x_n$ assumed to be generated from a mixture of two Gaussian distributions with unknown means $\mu_1, \mu_2$ and known, equal variances $\sigma^2$. The mixing coefficient $\pi$ indicates the probability that a given data point comes from the first Gaussian distribution, and $1 - \pi$ is the probability for the second distribution. The goal is to estimate the parameters $\mu_1$, $\mu_2$, and $\pi$ using the Expectation-Maximization (EM) algorithm.
If you have followed the EM-algorithm steps so far, then it'd be obvious what needs to be done. To outline the steps roughly:
- Start with the probability density function for GMM.
- Put a $\prod$ symbol before it (cuz i.i.d.) and boom, this is our likelihood function.
- Take a log-transform to get the log-likelihood function.
- Compute the Q function as the expected value of the log-likelihood function.
- Compute the M step by maximizing the Q function w.r.t the parameters. Seems easy right?
In the case of Gaussian Mixture Models, we encounter a $\log \sum$ in step 3, which is truly disastrous, because the log-transform of a sum pretty much does not simplify. We use some awesome-sauce tricks to bypass this log-transform. Let me elaborate a little more:
Some important notation: the formula for a normal distribution looks as follows:
$$ \mathcal{N}(x \mid \mu, \sigma^2) = \frac{1}{\sqrt{2\pi\sigma^2}}\, e^{-\frac{(x - \mu)^2}{2\sigma^2}} $$
Likelihood for GMM
The likelihood is given by:
$$ L(\theta | x) = \prod_{i=1}^{n} p(x_i | \theta) = \prod_{i=1}^{n} \sum_{k=1}^{2} \pi_k\, \mathcal{N}(x_i \mid \mu_k, \sigma^2) $$
where $\pi_1 = \pi$ and $\pi_2 = 1 - \pi$.
The sum above is over the possible cluster assignments $k$ for point $i$. Taking the log,
$$ \log L(\theta | x) = \sum_{i=1}^{n} \log \sum_{k=1}^{2} \pi_k\, \mathcal{N}(x_i \mid \mu_k, \sigma^2) $$
And there you go! You have the god-forsaken log of a sum that we cannot eliminate. BUT WE TOTALLY CAN ELIMINATE IT WITH SOME MATH WIZARDRY! USE JENSEN'S INEQUALITY MF.
Using Jensen's inequality, we can elegantly handle that summation symbol. I highly suggest you follow this source. For any distribution $q_i$ over the cluster assignments of point $i$:
$$ \log \sum_{k} \pi_k\, \mathcal{N}(x_i \mid \mu_k, \sigma^2) \geq \sum_{k} q_{ik} \log \frac{\pi_k\, \mathcal{N}(x_i \mid \mu_k, \sigma^2)}{q_{ik}} $$
The lower bound on the right is called the auxiliary function. Now back to our original problem: the optimal choice of $q$ for this problem is just the responsibility function.
E-step:
In the E-step, we compute the responsibility $\gamma(z_{i1})$ that the first Gaussian distribution has for generating each data point $x_i$. This is given by:
$$ \gamma(z_{i1}) = \frac{\pi\, \mathcal{N}(x_i \mid \mu_1, \sigma^2)}{\pi\, \mathcal{N}(x_i \mid \mu_1, \sigma^2) + (1 - \pi)\, \mathcal{N}(x_i \mid \mu_2, \sigma^2)} $$
where $\mathcal{N}(x \mid \mu, \sigma^2)$ is the probability density function of the Gaussian distribution, given earlier.
M-step:
In the M-step, we update the parameters $\mu_1$, $\mu_2$, and $\pi$ using the responsibilities computed in the E-step.
- Update $\mu_1$ and $\mu_2$:
The new estimates for $\mu_1$ and $\mu_2$ are calculated by taking the weighted average of all data points, weighted by their responsibilities:
$$ \mu_1^{(t+1)} = \frac{\sum_{i=1}^{n} \gamma(z_{i1})\, x_i}{\sum_{i=1}^{n} \gamma(z_{i1})}, \qquad \mu_2^{(t+1)} = \frac{\sum_{i=1}^{n} \left(1 - \gamma(z_{i1})\right) x_i}{\sum_{i=1}^{n} \left(1 - \gamma(z_{i1})\right)} $$
- Update $\pi$:
The new estimate for $\pi$ is the average responsibility that the first Gaussian distribution has for generating all data points:
$$ \pi^{(t+1)} = \frac{1}{n} \sum_{i=1}^{n} \gamma(z_{i1}) $$
The EM algorithm iteratively applies these E-step and M-step computations until the parameter estimates converge, that is, until the changes in the estimates of $\mu_1$, $\mu_2$, and $\pi$ between successive iterations are sufficiently small. Each iteration consists of using the current parameter estimates to compute the responsibilities in the E-step, and then updating the parameters based on these responsibilities in the M-step, thus maximizing the likelihood of the observed data given the model.
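A minimal NumPy sketch of this (the data, the initial guesses, and the known $\sigma^2 = 1$ are all assumptions for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
# Made-up data: a mix of two Gaussians with known sigma^2 = 1
x = np.concatenate([rng.normal(-2, 1, 300), rng.normal(3, 1, 700)])

mu1, mu2, pi, sigma2 = -1.0, 1.0, 0.5, 1.0  # initial guesses; sigma^2 known

def normal_pdf(x, mu, sigma2):
    return np.exp(-(x - mu) ** 2 / (2 * sigma2)) / np.sqrt(2 * np.pi * sigma2)

for _ in range(200):
    # E step: responsibility of component 1 for each point
    w1 = pi * normal_pdf(x, mu1, sigma2)
    w2 = (1 - pi) * normal_pdf(x, mu2, sigma2)
    gamma = w1 / (w1 + w2)
    # M step: responsibility-weighted updates
    new_mu1 = np.sum(gamma * x) / np.sum(gamma)
    new_mu2 = np.sum((1 - gamma) * x) / np.sum(1 - gamma)
    new_pi = np.mean(gamma)
    if max(abs(new_mu1 - mu1), abs(new_mu2 - mu2), abs(new_pi - pi)) < 1e-9:
        break
    mu1, mu2, pi = new_mu1, new_mu2, new_pi

print(mu1, mu2, pi)  # should land near (-2, 3, 0.3)
```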
Example 4.
You are given a dataset from a series of experiments where two coins, A and B, are flipped, but the identity of the coin used in each experiment is unknown. Each experiment consists of a fixed number of coin flips, and only the number of heads observed is recorded.
Objective: Your task is to use the Expectation-Maximization (EM) algorithm to estimate the bias (probability of landing heads) of each coin, denoted as $\theta_A$ and $\theta_B$.
Given Data:
- Number of experiments: $M$
- Number of flips in the $i$-th experiment: $N_i$
- Number of heads observed in the $i$-th experiment: $X_i$
- Initial bias estimates: $\theta_A^{(0)}$ and $\theta_B^{(0)}$
Questions:
- Derive the E-Step of the EM algorithm for this problem. Calculate the posterior probability that coin A was used in the $i$-th experiment, given the observed data and current bias estimates.
- Derive the M-Step. Update the bias estimates for $\theta_A$ and $\theta_B$ based on the posterior probabilities computed in the E-Step.
- Explain how the EM algorithm iteratively improves the estimates of $\theta_A$ and $\theta_B$, and discuss the convergence criterion for the algorithm.
Solution:
E-Step: In the E-Step, we calculate the expected value of the latent variable $Z_{iA}$, which indicates whether coin A was used in the $i$-th experiment. The posterior probability that coin A was used, given the observed data and the current parameter estimates, is calculated as follows: $$ \large P(Z_{iA} = 1 | X_i; \theta_A, \theta_B) = \frac{\theta_A^{X_i} (1 - \theta_A)^{N_i - X_i}}{\theta_A^{X_i} (1 - \theta_A)^{N_i - X_i} + \theta_B^{X_i} (1 - \theta_B)^{N_i - X_i}} $$ (The binomial coefficients cancel in the ratio.)
M-Step: In the M-Step, we update the parameters and to maximize the expected complete-data log-likelihood. The updates are given by: $$ \large \theta_A^{(new)} = \frac{\sum_{i=1}^M P(Z_{iA} = 1 | X_i; \theta_A, \theta_B) X_i}{\sum_{i=1}^M P(Z_{iA} = 1 | X_i; \theta_A, \theta_B) N_i} $$ $$ \large \theta_B^{(new)} = \frac{\sum_{i=1}^M (1 - P(Z_{iA} = 1 | X_i; \theta_A, \theta_B)) X_i}{\sum_{i=1}^M (1 - P(Z_{iA} = 1 | X_i; \theta_A, \theta_B)) N_i} $$
Iterative Improvement and Convergence: The EM algorithm iteratively applies the E-Step and M-Step until the changes in the parameter estimates $\theta_A$ and $\theta_B$ become negligible, indicating convergence. A common convergence criterion is the absolute change in the log-likelihood function, which is expected to increase with each iteration. When the increase falls below a predefined threshold, the algorithm is considered to have converged.
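A minimal sketch of the whole thing (the flip counts and initial guesses are made up for illustration):

```python
import numpy as np

# Made-up data: M = 5 experiments of 10 flips each, X heads observed
N = np.array([10, 10, 10, 10, 10])
X = np.array([5, 9, 8, 4, 7])

theta_A, theta_B = 0.6, 0.5  # initial bias estimates

for _ in range(1000):
    # E step: posterior probability that coin A produced each experiment
    # (binomial coefficients cancel in the ratio)
    like_A = theta_A**X * (1 - theta_A)**(N - X)
    like_B = theta_B**X * (1 - theta_B)**(N - X)
    p_A = like_A / (like_A + like_B)
    # M step: responsibility-weighted fraction of heads for each coin
    new_A = np.sum(p_A * X) / np.sum(p_A * N)
    new_B = np.sum((1 - p_A) * X) / np.sum((1 - p_A) * N)
    if max(abs(new_A - theta_A), abs(new_B - theta_B)) < 1e-12:
        break
    theta_A, theta_B = new_A, new_B

print(theta_A, theta_B)
```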