CS After Dark

[NOTES] Expectation Maximization

This is an extremely long and mathematically rigorous essay on the Expectation Maximization (EM) algorithm, written from the lens of a Computer Science graduate who has not taken a math-heavy course in a while.

Some of the best long-form essays, be they in text or video form, have one common denominator: They come from individuals who have developed a deep obsession with the subject of the essay. Obsession has positive connotations in this context. However, I'm not obsessed with EM and it holds no special place in my heart (as you will see soon). It's just that I have set a 4-hour timer and am forcing myself to learn as much as I can about EM. Maybe a more general "proxy" for the quality of an essay is "hours spent meditating on the subject" (readers will be the judge). I'm sure that trained mathematicians will view this as blasphemy.

I

I'm not a big fan of the unnecessary amount of formalism that some math textbooks use. In fact, one of my controversial beliefs is that some (only some) mathematicians deliberately gate-keep knowledge by finding the most unidiomatic formalism for stating a simple property. I'm especially annoyed by this gate-keeping because you will rarely find it happening in software engineering. For example, the following code would get me brutal PR reviews and maybe a call from my manager:

#include <bitset>

// Gatekeep version of checking if a number is even
bool isEvenUnidiomatically(int number) { 
	std::bitset<32> binaryRepresentation(number);
	return binaryRepresentation[0] == 0; 
}

Anyway, let's get into notations and get past the gatekeeping.

Background

EM is an optimization algorithm and is:

Notation

Marginalization Model

Marginalization refers to the process of summing or integrating out one or more variables from a joint distribution to be able to study the distribution of the remaining variables.

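$$f_X(\mathbf{x}|\theta) = \int_{\{\mathbf{y}:\,M(\mathbf{y})=\mathbf{x}\}} f_Y(\mathbf{y}|\theta)\,d\mathbf{y}$$

Here $M$ maps the complete data $\mathbf{y}$ to the observed data $\mathbf{x}$, so the integral runs over every complete-data value that is consistent with what we actually observed.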
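When the data are discrete, the integral turns into a sum. Here is a minimal Python sketch, assuming a made-up joint table over complete data $y = (x, z)$ in which $M$ simply drops the hidden part $z$. The weather labels and the `marginalize` helper are hypothetical, just for illustration.

```python
from collections import defaultdict

# Hypothetical joint distribution P(X = x, Z = z) over complete data y = (x, z).
# The map M(y) = x keeps the observed part and drops the hidden part z.
joint = {
    ("rain", "cloudy"): 0.30,
    ("rain", "clear"): 0.05,
    ("dry", "cloudy"): 0.20,
    ("dry", "clear"): 0.45,
}

def marginalize(joint):
    # Sum over every complete-data value y whose observed part M(y) equals x
    marginal = defaultdict(float)
    for (x, z), p in joint.items():
        marginal[x] += p
    return dict(marginal)

print(marginalize(joint))  # P(X = "rain") and P(X = "dry")
```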

If you see glossary entries like the ones below, try to familiarize yourself with these terms; they will help you grasp the math that follows.

"Frequentist Setting"
Frequentist Setting refers to an interpretation of probability AS THE LONG-RUN FREQUENCY OF A REPEATED EXPERIMENT. Ex: when we claim that a coin toss comes up heads 50% of the time, in the frequentist setting we are saying "If you toss a coin, record the outcome, and repeat this process for a very long time, then you will find that each side comes up 50% of the time (based on the frequency)".
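A quick simulation sketch of the long-run frequency idea. The `long_run_frequency` helper and the fixed seed are arbitrary illustration choices, not anything standard. The fraction of heads should drift toward 0.5 as the number of tosses grows.

```python
import random

def long_run_frequency(n_tosses, seed=0):
    # Toss a fair coin n_tosses times and return the fraction of heads
    rng = random.Random(seed)
    heads = sum(rng.random() < 0.5 for _ in range(n_tosses))
    return heads / n_tosses

for n in (10, 1_000, 100_000):
    print(n, long_run_frequency(n))
```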

Optimization
Finding roots or finding some minima/maxima.

Density function
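Think of a density as a curve whose "area under the curve" over an interval is the probability of landing in that interval. The total area is 1, but the density value at a single point is not itself a probability (it can even exceed 1).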
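A small numerical sketch, assuming an exponential density with rate `lam` and a plain midpoint rule (both my own choices for illustration). Integrating the density over an interval gives a probability, and the total area comes out close to 1. The density value at a point can exceed 1 without anything being wrong.

```python
import math

def exp_pdf(x, lam=1.0):
    # Exponential density with rate lam (zero for negative x)
    return lam * math.exp(-lam * x) if x >= 0 else 0.0

def area(f, a, b, steps=100_000):
    # Midpoint-rule approximation of the area under f between a and b
    h = (b - a) / steps
    return h * sum(f(a + (i + 0.5) * h) for i in range(steps))

print(area(exp_pdf, 0, 1))   # P(0 <= X <= 1), close to 1 - e^{-1}
print(area(exp_pdf, 0, 50))  # nearly all of the area, close to 1
print(exp_pdf(0, lam=2.0))   # a density value of 2.0: not a probability
```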

Likelihood and Log-Likelihood ($L(\theta | x)$)
Likelihood is the probability (or density) of the observed data, viewed as a function of the parameter $\theta$ with the data held fixed. It comes from the model that we assume generated our data. In our setting, optimization is all about finding the maximum of a (possibly complicated) likelihood function, often by finding the roots of its derivative. The log-likelihood is just a log-transform of the likelihood function. It often makes calculations easier because products turn into sums. You will see the likelihood denoted as $L(\theta | x)$ and the log-likelihood as $\log L(\theta | x)$.
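Here is a sketch of why the log helps in practice, using hypothetical coin-flip data. The raw likelihood of 2000 flips is a product of 2000 numbers below 1, so it underflows to zero in double precision. The log-likelihood is a sum of logs and stays a usable finite number. The helper names are mine.

```python
import math

flips = [1, 0] * 1000  # hypothetical data: 2000 coin flips, 1 = heads

def likelihood(p, data):
    # Product of per-flip probabilities: P(data | p)
    out = 1.0
    for x in data:
        out *= p if x == 1 else 1 - p
    return out

def log_likelihood(p, data):
    # Sum of per-flip log-probabilities: log P(data | p)
    return sum(math.log(p if x == 1 else 1 - p) for x in data)

print(likelihood(0.5, flips))      # underflows to 0.0
print(log_likelihood(0.5, flips))  # finite, and still comparable across p
print(log_likelihood(0.9, flips))  # smaller: p = 0.9 explains the data worse
```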

EM Algorithm

EM is an optimization algorithm as stated above. It is used for estimating the parameters of probabilistic models. Typically, we deal with a (log-)likelihood function and want to find the set of parameters under which the observed sample data are most likely.

Suppose we want to maximize $L(\theta|\mathbf{x})$ with respect to $\theta$. We define the following function:

$$Q(\theta|\theta^{(t)}) = E\{\log L(\theta|\mathbf{Y}) \mid \mathbf{x}, \theta^{(t)}\}$$

Read this as "the expectation of the joint log-likelihood of the complete data $\mathbf{Y}$, conditional on the observed data $\mathbf{X} = \mathbf{x}$ and the current estimate $\theta^{(t)}$". It can also be written as:

$$Q(\theta|\theta^{(t)}) = E\{\log f_Y(\mathbf{Y}|\theta) \mid \mathbf{x}, \theta^{(t)}\}$$

OR

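$$Q(\theta|\theta^{(t)}) = \int \left[\log f_Y(\mathbf{y}|\theta)\right] f_{Z|X}(\mathbf{z}|\mathbf{x}, \theta^{(t)})\,d\mathbf{z}$$

where $\mathbf{Z}$ is the missing data and $\mathbf{y} = (\mathbf{x}, \mathbf{z})$.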

If the missing data $\mathbf{Z}$ were discrete, we'd get a summation over its possible values instead of an integral.

Now the EM algorithm boils down to the following two steps:

  1. E-Step: Compute $Q(\theta|\theta^{(t)})$ (the expectation of the joint log-likelihood function, given the parameter ($\theta$) estimate at iteration $t$).
  2. M-Step: Maximize $Q(\theta|\theta^{(t)})$ with respect to $\theta$. Set $\theta^{(t+1)}$ (theta at iteration $t+1$) as the maximizer of $Q$.
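The two steps can be sketched as a generic driver loop. This assumes a scalar parameter and an E step that can be summarized by whatever the M step needs (here, a single expected value). `run_em`, `e_step`, and `m_step` are hypothetical names. The usage plugs in the exponential setup from Example 1 below ($y_1 = 5$ observed, $y_2$ missing).

```python
def run_em(e_step, m_step, theta, tol=1e-10, max_iter=1000):
    # Alternate E and M steps until successive (scalar) estimates stop changing
    for _ in range(max_iter):
        new_theta = m_step(e_step(theta))  # theta^(t+1) = argmax Q(. | theta^(t))
        if abs(new_theta - theta) < tol:
            return new_theta
        theta = new_theta
    return theta

# Usage with the Example 1 setup: Y1, Y2 ~ Exp(theta), y1 = 5 observed, y2 missing
Y1 = 5.0

def e_step(theta_t):
    # E[Y2 | y1, theta_t] = 1 / theta_t, the only missing-data quantity Q needs
    return 1.0 / theta_t

def m_step(expected_y2):
    # Maximizer of 2 log(theta) - theta * (y1 + E[Y2])
    return 2.0 / (Y1 + expected_y2)

print(run_em(e_step, m_step, theta=5.0))  # approaches 0.2
```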

At this point you will ask: what the fuck is E? And how the fuck did we go from E to an integral? Yes, it looks intimidating if you have no background in statistics, but I will explain what it means with an example, because it's so silly.

II

When you have mastered numbers, you will in fact no longer be reading numbers, any more than you read words when reading books. You will be reading meanings. (Harold Geneen, “Managing”)

If it were up to me, I'd reduce the formal definition of the EM algorithm to like three lines and avoid all math jargon, cuz it did less than 1% to build my intuition. Let's just dive into the examples and get dirty.

Example 1.

This is probably going to be the only detailed example. The rest of the examples will assume that the reader has understood and internalized these steps.

$Y_1, Y_2 \overset{\text{i.i.d.}}{\sim} \text{Exp}(\theta)$ with $y_1 = 5$ observed but $y_2$ missing.

Let's break this down. Read the above statement as "$Y_1$ and $Y_2$ are independent and identically distributed (i.i.d.) and follow (the ~ symbol) an exponential distribution with rate $\theta$".


Expectation E[X]
MEAN! AVERAGE! or more formally: "Expectation, also known as the expected value or the mean. It's a measure of the central tendency of a random variable, providing a weighted average of all possible values that the random variable can take on, with the weights being the probabilities of those values. 🥰"
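A tiny sketch of "weighted average of all possible values", assuming a fair six-sided die. `Fraction` keeps the arithmetic exact.

```python
from fractions import Fraction

# Fair six-sided die: every face k has probability 1/6
probs = {k: Fraction(1, 6) for k in range(1, 7)}

# Expectation = probability-weighted average of all possible values
expected = sum(k * p for k, p in probs.items())
print(expected)  # 7/2
```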

Now the first step is to get the probability density function of an exponential distribution. The probability density function (PDF) for an exponential distribution is given by:

$$f(x|\lambda) = \lambda e^{-\lambda x} \quad \text{for } x \geq 0$$

And its expectation is given by:

$$E(X) = \frac{1}{\lambda}$$

The likelihood function $L(\theta)$ for a set of independent and identically distributed (i.i.d.) observations $x_1, x_2, \ldots, x_n$ is the product of the individual PDFs (literally just slap a big product sign in front of the distribution):

$$L(\theta | x_1, x_2, \ldots, x_n) = \prod_{i=1}^{n} f(x_i|\theta)$$

Substituting the exponential PDF (with $\theta$ playing the role of $\lambda$) into the likelihood function gives (this is so dumb, we just added a product sign):

$$L(\theta | x_1, x_2, \ldots, x_n) = \prod_{i=1}^{n} \theta e^{-\theta x_i} = \theta^n e^{-\theta \sum_{i=1}^{n} x_i}$$

Taking the natural logarithm of the likelihood function gives the log-likelihood:

$$\ln L(\theta | x_1, x_2, \ldots, x_n) = n \ln(\theta) - \theta \sum_{i=1}^{n} x_i$$

So now, coming back to our original problem, can you guess what the complete-data log-likelihood is? It's just this!

$$\log L(\theta | \mathbf{y}) = \log f_Y(\mathbf{y}|\theta) = 2\log\theta - \theta y_1 - \theta y_2.$$

BOOM! (We just added the log-likelihoods for $Y_1$ and $Y_2$.) Now, remember that the E-step involves computing the Q function. What is that? How do we do that? Well, you keep the observed $y_1 = 5$ as is and replace the missing $y_2$ with its expected value given the observed data and the current estimate $\theta^{(t)}$. (This shortcut works here because the log-likelihood is linear in $y_2$.) Thus

$$Q(\theta|\theta^{(t)}) = 2\log\theta - 5\theta - \frac{\theta}{\theta^{(t)}}$$

since $E\{Y_2|y_1,\theta^{(t)}\} = E\{Y_2|\theta^{(t)}\} = 1/\theta^{(t)}$ follows from independence.
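To convince yourself that replacing $y_2$ with its expectation really gives $Q(\theta|\theta^{(t)}) = 2\log\theta - 5\theta - \theta/\theta^{(t)}$, here is a Monte Carlo sketch. It draws the missing $Y_2$ from $\text{Exp}(\theta^{(t)})$, averages the complete-data log-likelihood, and compares the result with the closed form. The helper names, seed, and sample size are arbitrary choices. Note that `random.expovariate` takes the rate, so its mean is $1/\theta^{(t)}$.

```python
import math
import random

def q_closed_form(theta, theta_t, y1=5.0):
    # Q(theta | theta_t) = 2 log(theta) - y1*theta - theta/theta_t
    return 2 * math.log(theta) - y1 * theta - theta / theta_t

def q_monte_carlo(theta, theta_t, y1=5.0, n=200_000, seed=0):
    # Average the complete-data log-likelihood over draws of the missing Y2
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        y2 = rng.expovariate(theta_t)  # Y2 | y1, theta_t ~ Exp(theta_t) by independence
        total += 2 * math.log(theta) - theta * y1 - theta * y2
    return total / n

print(q_closed_form(0.3, 0.5), q_monte_carlo(0.3, 0.5))  # should agree closely
```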

Ok what about the M step?
The maximizer of $Q(\theta|\theta^{(t)})$ is the root of $\frac{2}{\theta} - 5 - \frac{1}{\theta^{(t)}} = 0$. Thus $\theta^{(t+1)} = \frac{2\theta^{(t)}}{5\theta^{(t)} + 1}$, which converges quickly to $\hat{\theta} = 0.2$. The code to implement this is as simple as:

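import math

y1 = 5.0  # observed value; y2 is missing

def Q(theta, theta_t):
    # E step result: Q(theta | theta_t) = 2 log(theta) - y1*theta - theta/theta_t
    return 2 * math.log(theta) - y1 * theta - theta / theta_t

def update(theta_t):
    # M step: root of dQ/dtheta = 2/theta - y1 - 1/theta_t = 0
    return 2 * theta_t / (y1 * theta_t + 1)

def loglik(theta):
    # observed-data log-likelihood log f(y1 | theta); EM never decreases it
    return math.log(theta) - y1 * theta

tol = 2e-09
theta = 5.0  # initial guess
for i in range(100):
    new_theta = update(theta)
    print(i, new_theta, Q(new_theta, theta), loglik(new_theta))
    err = (new_theta - theta) ** 2
    theta = new_theta
    if err <= tol:  # stop once successive estimates barely change
        break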

Note that EM is a maximization procedure: each iteration never decreases the observed-data log-likelihood $\log L(\theta^{(t)}|y_1)$, so it climbs toward its maximum at $\hat{\theta} = 0.2$. Woohoo!

Example 2.

Okay, from here on the language will get extremely terse: no bullshit, focusing only on the most important findings in EM. Examples from this section onwards will be harder to keep up with, but they will build a more general understanding of how EM works and how it can be applied.

This example is covered in Computational Statistics by G. H. Givens and J. A. Hoeting (colostate.edu).

Wing color in the peppered moth is determined by a single gene with three possible alleles, which we denote C, I, and T. C is dominant to I, and T is recessive to I. Thus genotypes CC, CI, and CT result in the carbonaria phenotype, which has solid black coloring. The TT genotype results in the typica phenotype with light-colored wings. The II and IT genotypes produce an intermediate insularia phenotype with mottled wings.

Key points:

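$$E[N_{CC} \mid n_C, n_I, n_T, \mathbf{p}^{(t)}] = n_{CC}^{(t)} = \frac{n_C \left(p_C^{(t)}\right)^2}{\left(p_C^{(t)}\right)^2 + 2p_C^{(t)}p_I^{(t)} + 2p_C^{(t)}p_T^{(t)}}$$

$$(n_C, n_I, n_T, n_U) = (n_{CC} + n_{CI} + n_{CT},\; n_{II} + n_{IT},\; n_{TT},\; N_{II} + N_{IT} + N_{TT})$$

$$E[N_{IT} \mid n_C, n_I, n_T, n_U, \mathbf{p}^{(t)}] = n_{IT}^{(t)} = n_I \frac{2p_I^{(t)}p_T^{(t)}}{\left(p_I^{(t)}\right)^2 + 2p_I^{(t)}p_T^{(t)}} + n_U \frac{2p_I^{(t)}p_T^{(t)}}{\left(p_I^{(t)}\right)^2 + 2p_I^{(t)}p_T^{(t)} + \left(p_T^{(t)}\right)^2}$$

	- Here $n_U$ counts moths known only to be insularia or typica (their genotype counts $N_{II}, N_{IT}, N_{TT}$ are unobserved). Notice how, in the second term, these $n_U$ moths are split using all the genotype frequencies that produce insularia or typica: $\left(p_I^{(t)}\right)^2 + 2p_I^{(t)}p_T^{(t)} + \left(p_T^{(t)}\right)^2$.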
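The key points above can be turned into a complete EM loop. This is a minimal Python sketch, assuming Hardy-Weinberg genotype probabilities and the standard allele-counting M step, where each allele frequency is its expected allele count divided by $2n$ and $n$ is the total number of moths. The counts are hypothetical, chosen only to exercise the code, and `moth_em` / `observed_loglik` are my own names.

```python
import math

def observed_loglik(p, n_c, n_i, n_t, n_u):
    # Multinomial log-likelihood of the phenotype counts (constant term dropped)
    p_c, p_i, p_t = p
    return (n_c * math.log(p_c**2 + 2*p_c*p_i + 2*p_c*p_t)
            + n_i * math.log(p_i**2 + 2*p_i*p_t)
            + n_t * math.log(p_t**2)
            + n_u * math.log((p_i + p_t)**2))

def moth_em(n_c, n_i, n_t, n_u, p=(1/3, 1/3, 1/3), iters=100):
    n = n_c + n_i + n_t + n_u  # number of moths; each carries two alleles
    trace = [observed_loglik(p, n_c, n_i, n_t, n_u)]
    for _ in range(iters):
        p_c, p_i, p_t = p
        # E step: expected genotype counts given phenotypes and current p
        den_c = p_c**2 + 2*p_c*p_i + 2*p_c*p_t
        den_i = p_i**2 + 2*p_i*p_t
        den_u = (p_i + p_t)**2  # insularia-or-typica moths
        n_cc = n_c * p_c**2 / den_c
        n_ci = n_c * 2*p_c*p_i / den_c
        n_ct = n_c * 2*p_c*p_t / den_c
        n_ii = n_i * p_i**2 / den_i + n_u * p_i**2 / den_u
        n_it = n_i * 2*p_i*p_t / den_i + n_u * 2*p_i*p_t / den_u
        n_tt = n_t + n_u * p_t**2 / den_u
        # M step: allele counting on the completed genotype table
        p = ((2*n_cc + n_ci + n_ct) / (2*n),
             (2*n_ii + n_it + n_ci) / (2*n),
             (2*n_tt + n_ct + n_it) / (2*n))
        trace.append(observed_loglik(p, n_c, n_i, n_t, n_u))
    return p, trace

p_hat, trace = moth_em(n_c=80, n_i=200, n_t=340, n_u=100)  # hypothetical counts
print(p_hat, trace[-1])
```

Printing `trace` shows the observed-data log-likelihood never going down from one iteration to the next, which is the EM monotonicity property in action.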

III

Terence Tao views mathematics as climbing peaks, trying to reach various goals. "Some peaks are just out of your reach", he adds. I'm fond of this take, as it implies that there exists a terrain of mathematics (perhaps too grand for the human brain to comprehend). We are simply blind travellers on this vast and largely hidden super-structure. "Intuition" can be thought of as a viewport onto this super-structure. Sadly, because humanity does not YET have a hive-mind, every new generation has to reconstruct this viewport over the course of several years. What hurts even more is that every person's viewport spawns at some random co-ordinates on this super-structure; someone might be cursed with a viewport that shows flat terrain all the way to the horizon (people with poor math intuition), while others might have a viewport onto a region that gives them a more zoomed-out, larger view of this super-structure. In my case, I was cursed with a viewport onto flat terrain; however, because of brown-family peer pressure, I was forced to take higher-level math. In just a couple of months, I had "an intelligence explosion" and could suddenly "see" more of math.

I've been collecting a bunch of tricks to effectively expand my viewport. Before we get dirtier with EM math, I'll share one controversial trick. This trick has to do with the "sensation of understanding". It feels weird to type it out, but I experience the moment of "I understand X" almost as a clear sensation, similar to taste or touch. One interesting personal revelation is that the sensation of understanding is strongest when I'm able to successfully place a concept in my mind relative to other concepts. We'll come back to this, but let me quickly tell you about the time I was trying to understand a complicated math proof and couldn't tell how the author went from step 1 to step 2. Because I was running late for an exam, I just made up a shitty + vulgar mnemonic to remember the steps. Then the most interesting thing happened when I studied another proof a while later: I was able to follow the steps of this new, unseen proof using the same shitty + vulgar mnemonic as a proxy for the real underlying intuition. It produced a very similar sensation to that of understanding the move from step 1 to step 2.

So my biggest takeaway was that if you don't understand how some part of math works, just make up vulgar assumptions (vulgar is easier to recall). There is an occlusion in your viewport and you gotta move on. I promise your vulgar assumptions eventually get replaced with the real intuition (if you spend enough time playing around with math). Now for this next part, don't focus too much on reasoning or intuitive understanding. Just follow along and get used to it.

EM in Exponential Families


Sufficient Statistics
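A statistic that captures all the information in the sample data that is needed for inference about the parameter of the probability density that generated the data. Often denoted by $s(\mathbf{y})$; in EM you will see its conditional expectation at iteration $t$ written as $\mathbf{s}^{(t)}$.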
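For the exponential model from Example 1, the log-likelihood $n\log\theta - \theta\sum_i y_i$ depends on the data only through $n$ and $\sum_i y_i$, so $s(\mathbf{y}) = \sum_i y_i$ is sufficient. Here is a quick sketch with two hypothetical samples that share $n$ and $\sum_i y_i$ (the values are made up).

```python
import math

def exp_loglik(theta, ys):
    # n*log(theta) - theta*sum(ys): the data enter only through n and sum(ys)
    return sum(math.log(theta) - theta * y for y in ys)

a = [1.0, 2.0, 3.0]  # hypothetical sample, sum = 6
b = [0.5, 0.5, 5.0]  # different values, same n = 3 and same sum = 6
for theta in (0.1, 0.5, 2.0):
    print(theta, exp_loglik(theta, a), exp_loglik(theta, b))  # equal up to rounding
```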

Steps:

Missing Info Principle

We start with the following:

$$\log f_X(\mathbf{x}|\theta) = Q(\theta|\theta^{(t)}) - H(\theta|\theta^{(t)})$$

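where $H(\theta|\theta^{(t)}) = E\{\log f_{Z|X}(\mathbf{Z}|\mathbf{x}, \theta) \mid \mathbf{x}, \theta^{(t)}\}$. Read as "H is the expected log-likelihood of the hidden/missing data given the observed data, with the expectation taken under the parameter estimate at iteration $t$".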

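This intuitively makes sense: since $f_Y(\mathbf{y}|\theta) = f_X(\mathbf{x}|\theta)\, f_{Z|X}(\mathbf{z}|\mathbf{x},\theta)$, we have $\log f_X(\mathbf{x}|\theta) = \log f_Y(\mathbf{y}|\theta) - \log f_{Z|X}(\mathbf{z}|\mathbf{x},\theta)$. Taking the expectation of both sides given $\mathbf{x}$ and $\theta^{(t)}$ leaves the left side unchanged (it does not involve $\mathbf{Z}$) and turns the right side into $Q - H$. In words: the expectation for the complete data minus the expectation for the missing data gives the log-likelihood of the observed data.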


Information
In stats, information or Fisher information is an attribute of a random variable. It reveals how much information the random variable has about a parameter. It is given as the negative of Expectation of the second order partial derivative of a log-likelihood function 😭 OR: $$ \Large I(\theta) = -E\left[ \frac{\partial^2}{\partial \theta^2} \log L(\theta) \right] $$

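Now, writing $l(\theta|\mathbf{x}) = \log f_X(\mathbf{x}|\theta)$ and using primes for derivatives with respect to $\theta$, taking the information (negative second derivative) of both sides we get:

$$-l''(\theta|\mathbf{x}) = -Q''(\theta|\omega)\big|_{\omega=\theta} + H''(\theta|\omega)\big|_{\omega=\theta}$$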

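Finally, the missing info principle says that the observed information equals the complete information minus the missing information:

$$\hat{\imath}_X(\theta) = \hat{\imath}_Y(\theta) - \hat{\imath}_{Z|X}(\theta)$$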
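The principle can be checked numerically on Example 1, where everything is available in closed form. There, $\log f_X = \log\theta - 5\theta$ and $Q(\theta|\theta^{(t)}) = 2\log\theta - 5\theta - \theta/\theta^{(t)}$. For $H$, the missing $Y_2$ given $y_1$ is $\text{Exp}(\theta)$ and its expectation under $\theta^{(t)}$ is $1/\theta^{(t)}$, so $H(\theta|\theta^{(t)}) = \log\theta - \theta/\theta^{(t)}$. This sketch assumes central finite differences are accurate enough here; the helper names are mine.

```python
import math

Y1 = 5.0  # observed value from Example 1

def second_derivative(f, theta, h=1e-4):
    # Central finite-difference approximation of f''(theta)
    return (f(theta + h) - 2 * f(theta) + f(theta - h)) / h ** 2

def observed(theta):
    # l(theta | x) = log f_X(x | theta) = log(theta) - theta*y1
    return math.log(theta) - theta * Y1

def q_func(theta, theta_t=0.2):
    # Q(theta | theta_t), with omega = theta_t fixed at the estimate
    return 2 * math.log(theta) - theta * Y1 - theta / theta_t

def h_func(theta, theta_t=0.2):
    # H(theta | theta_t) = E[log theta - theta*Z], Z ~ Exp(theta_t)
    return math.log(theta) - theta / theta_t

theta_hat = 0.2
i_x = -second_derivative(observed, theta_hat)   # observed information
i_y = -second_derivative(q_func, theta_hat)     # complete information
i_zx = -second_derivative(h_func, theta_hat)    # missing information
print(i_x, i_y, i_zx)  # roughly 25, 50, 25
```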

Supplemented EM (SEM) algorithm

This is used to find the variance of the parameter estimates.


EM Mapping
The map that takes the parameters at iteration $t$ to the parameters at iteration $t+1$, i.e. one E step followed by one M step, so that $\theta^{(t+1)} = \psi(\theta^{(t)})$. Denoted by $\psi$.

Just remember this:

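$$\widehat{\text{var}}\{\hat{\theta}\} = \hat{\imath}_Y(\hat{\theta})^{-1}\left(\mathbf{I} + \psi'(\hat{\theta})^T\left[\mathbf{I} - \psi'(\hat{\theta})^T\right]^{-1}\right)$$

where $\psi'(\hat{\theta})$ is the Jacobian of the EM mapping at the converged estimate $\hat{\theta}$ and $\mathbf{I}$ is the identity matrix.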
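Here is a one-dimensional sketch of the SEM idea on Example 1, where the EM map is $\psi(\theta) = 2\theta/(5\theta + 1)$. It estimates $\psi'(\hat\theta)$ with a finite difference, plugs it into the formula, and compares the result with the inverse of the observed information $1/\hat\theta^2$, which is $1/25 = 0.04$ at $\hat\theta = 0.2$. This is a simplification that assumes a scalar parameter, not the full multi-parameter SEM procedure, and `delta` is an arbitrary step size.

```python
Y1 = 5.0  # observed value from Example 1

def em_map(theta):
    # psi: one E step followed by one M step for Example 1
    return 2 * theta / (Y1 * theta + 1)

theta_hat = 0.2               # fixed point of em_map (the MLE)
i_y = 2 / theta_hat ** 2      # complete information: -Q''(theta | theta_hat) = 2 / theta^2

# Estimate psi'(theta_hat) by one EM step from a slightly perturbed start
delta = 1e-6
dpsi = (em_map(theta_hat + delta) - theta_hat) / delta

# 1-D version of the SEM variance formula
var_hat = (1 / i_y) * (1 + dpsi / (1 - dpsi))
print(dpsi, var_hat)  # about 0.5 and 0.04, matching 1 / (observed information 25)
```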

IV

The previous sections cover enough basics about EM. This part covers some hard-core EM application examples. This is an exercise in modelling: taking real-world problems and casting them into a concrete model.

Example 3.

You are given a dataset consisting of $N$ data points assumed to be generated from a mixture of two Gaussian distributions with unknown means $\mu_1, \mu_2$ and known, equal variances $\sigma^2$. The mixing coefficient $\pi$ indicates the probability that a given data point comes from the first Gaussian distribution, and $(1-\pi)$ is the probability for the second distribution. The goal is to estimate the parameters $\mu_1$, $\mu_2$, and $\pi$ using the Expectation-Maximization (EM) algorithm.

If you have followed the EM-algorithm steps so far, then it should be obvious what needs to be done. To outline the steps roughly:

  1. Start with the probability density function for GMM.
  2. Put a $\prod$ symbol before it (cuz i.i.d.) and boom, this is our likelihood function.
  3. Take a log-transform to get the log-likelihood function.
  4. Compute the Q function as the expected value of the log-likelihood function.
  5. Compute the M step by maximizing the Q function w.r.t. the parameters. Seems easy, right?

In the case of Gaussian Mixture Models, step 3 gives us the log of a sum, which is truly disastrous: unlike the log of a product, the log of a sum does not break apart into simpler terms. We use some awesome-sauce tricks to bypass this. Let me elaborate a little more:

Some important notations:

- $x_i$: data point $i$
- $z_i$: cluster assignment for point $i$
- $\mu_k$: center of cluster $k$
- $\Sigma_k$: spread (covariance) of cluster $k$
- $w_k$: proportion of data in cluster $k$ (mixture weights)

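The formula for a (multivariate, $p$-dimensional) normal distribution looks as follows:

$$p(X = x \mid \mu, \Sigma) = \frac{1}{(2\pi)^{p/2} |\Sigma|^{1/2}} \exp\left(-\frac{1}{2}(x - \mu)^T \Sigma^{-1} (x - \mu)\right).$$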

Likelihood for GMM

reference

The likelihood is given by:

$$\text{likelihood} = P(\{X_1, \ldots, X_n\} = \{x_1, \ldots, x_n\} \mid \mathbf{w}, \boldsymbol{\mu}, \boldsymbol{\Sigma}),$$

where

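$$\mathbf{w} = [w_1, \ldots, w_K], \quad \boldsymbol{\mu} = [\mu_1, \ldots, \mu_K], \quad \boldsymbol{\Sigma} = [\Sigma_1, \ldots, \Sigma_K].$$

$$\begin{aligned}
\text{likelihood}(\theta) &= \prod_i P(X_i = x_i \mid \theta) \\
&= \prod_i \sum_{k=1}^{K} P(X_i = x_i \mid z_i = k, \theta)\, P(z_i = k \mid \theta) \quad \text{(law of total probability)} \\
&= \prod_i \sum_{k=1}^{K} \mathcal{N}(x_i; \mu_k, \Sigma_k)\, w_k.
\end{aligned}$$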

On the second and third lines above, the sum is over possible cluster assignments $k$ for point $i$. Taking the log,

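$$\begin{aligned}
\log \text{likelihood}(\theta) &= \log \prod_i \sum_k P(X_i = x_i \mid z_i = k, \theta)\, P(z_i = k \mid \theta) \\
&= \sum_i \log \sum_k P(X_i = x_i \mid z_i = k, \theta)\, P(z_i = k \mid \theta).
\end{aligned}$$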

And there you go! You have the god-forsaken log of a sum that we cannot eliminate. BUT WE TOTALLY CAN ELIMINATE IT WITH SOME MATH WIZARDRY! USE JENSEN'S INEQUALITY MF.

Using Jensen's inequality, we can elegantly handle that summation symbol. I highly suggest you follow this source.

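$$\begin{aligned}
\log \text{likelihood}(\theta) &= \sum_i \log \mathbb{E}_{z}\left[\frac{P(X_i = x_i, Z_i = k \mid \theta)}{P(Z_i = k \mid x_i, \theta_t)}\right] \\
&\geq \sum_i \mathbb{E}_{z} \log\left[\frac{P(X_i = x_i, Z_i = k \mid \theta)}{P(Z_i = k \mid x_i, \theta_t)}\right] \quad \text{(Jensen's inequality)} \\
&= \sum_i \sum_k P(Z_i = k \mid x_i, \theta_t) \log\left[\frac{P(X_i = x_i, Z_i = k \mid \theta)}{P(Z_i = k \mid x_i, \theta_t)}\right] =: A(\theta, \theta_t).
\end{aligned}$$

Here $\mathbb{E}_z$ averages over $k \sim P(Z_i = k \mid x_i, \theta_t)$, and the inequality holds because $\log$ is concave.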

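$A(\cdot, \theta_t)$ is called the auxiliary function. It is a lower bound on the log-likelihood that touches it at $\theta = \theta_t$, so any $\theta$ that increases $A$ also increases the log-likelihood. Now back to our original problem. The weights $P(Z_i = k \mid x_i, \theta_t)$ inside the auxiliary function are exactly the responsibilities we compute in the E-step.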
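Here is a numerical sketch of the Jensen bound, assuming a 1-D two-component mixture with unit variances. The data, parameter values, and helper names (`joint`, `log_lik`, `aux`) are hypothetical. The auxiliary function should sit at or below the log-likelihood for any $\theta$ and match it exactly at $\theta = \theta_t$.

```python
import math

def normal_pdf(x, mu):
    # Unit-variance Gaussian density N(x | mu, 1)
    return math.exp(-(x - mu) ** 2 / 2) / math.sqrt(2 * math.pi)

def joint(x, k, theta):
    # P(X = x, Z = k | theta) with theta = (weights, means)
    w, mu = theta
    return w[k] * normal_pdf(x, mu[k])

def log_lik(xs, theta):
    # sum_i log sum_k P(X_i = x_i, Z_i = k | theta)
    return sum(math.log(sum(joint(x, k, theta) for k in range(2))) for x in xs)

def aux(xs, theta, theta_t):
    # A(theta, theta_t) = sum_i sum_k q_ik * log(P(x_i, k | theta) / q_ik)
    total = 0.0
    for x in xs:
        norm = sum(joint(x, k, theta_t) for k in range(2))
        for k in range(2):
            q = joint(x, k, theta_t) / norm  # P(Z_i = k | x_i, theta_t)
            total += q * math.log(joint(x, k, theta) / q)
    return total

xs = [-2.1, -1.7, 0.3, 2.8, 3.4]     # hypothetical data
theta_t = ([0.5, 0.5], [-1.0, 1.0])  # current estimate
theta = ([0.3, 0.7], [-2.0, 3.0])    # some other candidate
print(aux(xs, theta, theta_t), "<=", log_lik(xs, theta))
print(aux(xs, theta_t, theta_t), "==", log_lik(xs, theta_t))
```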

E-step:

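In the E-step, we compute the responsibility $\gamma(z_{i1})$ that the first Gaussian distribution has for generating each data point $x_i$. This is given by:

$$\gamma(z_{i1}) = \frac{\pi\, \mathcal{N}(x_i \mid \mu_1, \sigma^2)}{\pi\, \mathcal{N}(x_i \mid \mu_1, \sigma^2) + (1 - \pi)\, \mathcal{N}(x_i \mid \mu_2, \sigma^2)},$$

where $\mathcal{N}(x \mid \mu, \sigma^2)$ is the probability density function of the Gaussian distribution:

$$\mathcal{N}(x \mid \mu, \sigma^2) = \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(x - \mu)^2}{2\sigma^2}\right).$$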

M-step:

In the M-step, we update the parameters $\mu_1$, $\mu_2$, and $\pi$ using the responsibilities computed in the E-step.

  1. Update $\mu_1$ and $\mu_2$:

The new estimates for $\mu_1$ and $\mu_2$ are calculated by taking the weighted average of all data points, weighted by their responsibilities:

$$\mu_1^{\text{new}} = \frac{\sum_{i=1}^{N} \gamma(z_{i1})\, x_i}{\sum_{i=1}^{N} \gamma(z_{i1})}, \qquad \mu_2^{\text{new}} = \frac{\sum_{i=1}^{N} (1 - \gamma(z_{i1}))\, x_i}{\sum_{i=1}^{N} (1 - \gamma(z_{i1}))}.$$

  2. Update $\pi$:

The new estimate for $\pi$ is the average responsibility that the first Gaussian distribution has for generating all data points:

$$\pi^{\text{new}} = \frac{1}{N} \sum_{i=1}^{N} \gamma(z_{i1}).$$

The EM algorithm iteratively applies these E-step and M-step computations until the parameter estimates converge, that is, until the changes in the parameter estimates ($\mu_1$, $\mu_2$, and $\pi$) between successive iterations are sufficiently small. Each iteration uses the current parameter estimates to compute the responsibilities in the E-step, and then updates the parameters based on these responsibilities in the M-step. No iteration ever decreases the likelihood of the observed data, so the estimates climb toward a (possibly local) maximum.
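The E- and M-steps above fit in a short loop. This is a minimal plain-Python sketch for the 1-D, known-variance case described in the problem. The simulated data, starting values, stopping rule, and helper names (`normal_pdf`, `log_likelihood`, `em_two_gaussians`) are my own assumptions for illustration.

```python
import math
import random

def normal_pdf(x, mu, sigma2):
    # N(x | mu, sigma^2) for a 1-D Gaussian
    return math.exp(-(x - mu) ** 2 / (2 * sigma2)) / math.sqrt(2 * math.pi * sigma2)

def log_likelihood(xs, mu1, mu2, pi, sigma2):
    # Observed-data log-likelihood of the two-component mixture
    return sum(math.log(pi * normal_pdf(x, mu1, sigma2)
                        + (1 - pi) * normal_pdf(x, mu2, sigma2)) for x in xs)

def em_two_gaussians(xs, mu1, mu2, pi, sigma2, max_iter=500, tol=1e-10):
    n = len(xs)
    history = [log_likelihood(xs, mu1, mu2, pi, sigma2)]
    for _ in range(max_iter):
        # E step: gamma_i = responsibility of component 1 for x_i
        gamma = []
        for x in xs:
            a = pi * normal_pdf(x, mu1, sigma2)
            b = (1 - pi) * normal_pdf(x, mu2, sigma2)
            gamma.append(a / (a + b))
        # M step: responsibility-weighted means, average responsibility
        g_sum = sum(gamma)
        mu1 = sum(g * x for g, x in zip(gamma, xs)) / g_sum
        mu2 = sum((1 - g) * x for g, x in zip(gamma, xs)) / (n - g_sum)
        pi = g_sum / n
        history.append(log_likelihood(xs, mu1, mu2, pi, sigma2))
        if history[-1] - history[-2] < tol:  # stop when the log-likelihood stalls
            break
    return mu1, mu2, pi, history

# Hypothetical data: 300 points around -2 and 700 points around 3, unit variance
rng = random.Random(0)
xs = [rng.gauss(-2, 1) for _ in range(300)] + [rng.gauss(3, 1) for _ in range(700)]
mu1, mu2, pi, history = em_two_gaussians(xs, mu1=-1.0, mu2=1.0, pi=0.5, sigma2=1.0)
print(mu1, mu2, pi)
```

Swapping in a different starting point is a quick way to see why initialization matters. With identical starting means, every responsibility equals $\pi$, so the two components stay identical forever.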

Example 4.

You are given a dataset from a series of experiments where two coins, A and B, are flipped, but the identity of the coin used in each experiment is unknown. Each experiment consists of a fixed number of coin flips, and only the number of heads observed is recorded.

Objective: Your task is to use the Expectation-Maximization (EM) algorithm to estimate the bias (probability of landing heads) of each coin, denoted as $\theta_A$ and $\theta_B$.

Given Data:

  1. Number of experiments: $M$
  2. Number of flips in the $i$-th experiment: $N_i$
  3. Number of heads observed in the $i$-th experiment: $X_i$
  4. Initial bias estimates: $\theta_A^{(0)}$ and $\theta_B^{(0)}$

Questions:

  1. Derive the E-Step of the EM algorithm for this problem. Calculate the posterior probability that coin A was used in the $i$-th experiment, given the observed data and current bias estimates.
  2. Derive the M-Step. Update the bias estimates $\theta_A$ and $\theta_B$ based on the posterior probabilities computed in the E-Step.
  3. Explain how the EM algorithm iteratively improves the estimates of $\theta_A$ and $\theta_B$ and discuss the convergence criterion for the algorithm.

Solution:

  1. E-Step: In the E-Step, we calculate the expected value of the latent variable $Z_{iA}$, which indicates whether coin A was used in the $i$-th experiment. Assuming each coin is equally likely to be picked for an experiment (the binomial coefficients $\binom{N_i}{X_i}$ then cancel), the posterior probability that coin A was used, given the observed data $X_i$ and the current parameter estimates, is: $$ \large P(Z_{iA} = 1 | X_i; \theta_A, \theta_B) = \frac{\theta_A^{X_i} (1 - \theta_A)^{N_i - X_i}}{\theta_A^{X_i} (1 - \theta_A)^{N_i - X_i} + \theta_B^{X_i} (1 - \theta_B)^{N_i - X_i}} $$

  2. M-Step: In the M-Step, we update the parameters $\theta_A$ and $\theta_B$ to maximize the expected complete-data log-likelihood. The updates are given by: $$ \large \theta_A^{(new)} = \frac{\sum_{i=1}^M P(Z_{iA} = 1 | X_i; \theta_A, \theta_B) X_i}{\sum_{i=1}^M P(Z_{iA} = 1 | X_i; \theta_A, \theta_B) N_i} $$ $$ \large \theta_B^{(new)} = \frac{\sum_{i=1}^M (1 - P(Z_{iA} = 1 | X_i; \theta_A, \theta_B)) X_i}{\sum_{i=1}^M (1 - P(Z_{iA} = 1 | X_i; \theta_A, \theta_B)) N_i} $$

  3. Iterative Improvement and Convergence: The EM algorithm iteratively applies the E-Step and M-Step until the changes in the parameter estimates $\theta_A$ and $\theta_B$ become negligible, indicating convergence. A common alternative criterion is the change in the log-likelihood, which EM guarantees will never decrease from one iteration to the next. When the increase falls below a predefined threshold, the algorithm is considered to have converged (possibly to a local rather than a global maximum, so the starting values matter).
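Here is a plain-Python sketch of the E- and M-steps above. It assumes, as in the E-step formula, that each coin is equally likely to be chosen; the heads counts and starting values are hypothetical. It also records the observed-data log-likelihood so you can check that it never decreases.

```python
import math

def coin_em(heads, flips, theta_a, theta_b, max_iter=200, tol=1e-10):
    history = []  # observed-data log-likelihood before each update
    for _ in range(max_iter):
        # E step: posterior that coin A produced experiment i (equal prior on A and B)
        w_a, ll = [], 0.0
        for x, n in zip(heads, flips):
            like_a = theta_a ** x * (1 - theta_a) ** (n - x)
            like_b = theta_b ** x * (1 - theta_b) ** (n - x)
            w_a.append(like_a / (like_a + like_b))
            ll += math.log(math.comb(n, x) * 0.5 * (like_a + like_b))
        history.append(ll)
        # M step: responsibility-weighted fraction of heads for each coin
        new_a = (sum(w * x for w, x in zip(w_a, heads))
                 / sum(w * n for w, n in zip(w_a, flips)))
        new_b = (sum((1 - w) * x for w, x in zip(w_a, heads))
                 / sum((1 - w) * n for w, n in zip(w_a, flips)))
        done = abs(new_a - theta_a) + abs(new_b - theta_b) < tol
        theta_a, theta_b = new_a, new_b
        if done:
            break
    return theta_a, theta_b, history

heads = [5, 9, 8, 4, 7]  # hypothetical heads per experiment (X_i)
flips = [10] * 5         # hypothetical flips per experiment (N_i)
theta_a, theta_b, history = coin_em(heads, flips, theta_a=0.6, theta_b=0.5)
print(theta_a, theta_b)
```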


#optimization #expectation-maximization #machine-learning #statistics