Waste-free Sequential Monte Carlo
- tags
- Sampling
Algorithm from this article.
- How many MCMC steps are needed for optimal performance? (The number is always set arbitrarily.)
- The number of steps should be set adaptively, especially in situations like tempering.
- Intermediate steps are wasted; waste-free SMC reuses the intermediate steps as new particles.
Principle #
If there are \(N\) particles, we resample only \(M = N/P\) of them. Each resampled particle is then moved \(P-1\) times by the MCMC kernel, and every iterate (including the starting point) is kept, which yields a new sample of size \(N\).
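As a concrete (made-up) example, with \(N = 12\) and \(P = 3\) only \(M = 4\) particles survive resampling, and the four chains of length three are flattened back into twelve particles:

```python
import numpy as np

N, P = 12, 3          # total particles, chain length
M = N // P            # number of resampled particles (chains)

rng = np.random.default_rng(0)
resampled = rng.normal(size=(M,))          # M starting points

# Each chain stores its starting point plus P - 1 MCMC iterates;
# the moves are faked here with a deterministic shift.
chains = np.stack([resampled + 0.1 * p for p in range(P)])  # shape (P, M)

# Flattening the chains recovers a sample of the original size N.
particles = chains.reshape(N)              # shape (12,)
```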
How can we integrate both processes in the same algorithm?
The implementation of sequential Monte Carlo starts with a resampling step:
resampling_index = resampling_fn(weights, rng_key, num_resampled)
particles = jax.tree_map(lambda x: x[resampling_index], particles)
We can resample anywhere between \(1\) and \(N\) particles, so passing \(M\) as `num_resampled` is all the resampling step needs.
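For illustration, a minimal multinomial `resampling_fn` matching the signature above could look like this (the helper name is mine, not from any library):

```python
import jax
import jax.numpy as jnp


def multinomial_resampling(weights, rng_key, num_resampled):
    """Draw `num_resampled` particle indices with probability `weights`."""
    return jax.random.choice(
        rng_key, weights.shape[0], shape=(num_resampled,), p=weights
    )


weights = jnp.array([0.1, 0.2, 0.3, 0.4])
idx = multinomial_resampling(weights, jax.random.PRNGKey(0), 2)
```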
Waste-free SMC may be recast as a standard SMC sampler that propagates and reweights particles that are Markov chains of length \(P\).
Indeed we have \(M\) chains of length \(P\).
The new particles are \(z \in \mathcal{Z} = \Xi^P\). The potential functions are defined as:
\begin{equation} G^{wf}(z) = \frac{1}{P} \sum_{p=1}^{P} G_{t}(z[p]) \end{equation}
and the initial distribution
\begin{equation} \nu^{wf}(\mathrm{d}z) = \prod_{p=1}^{P} \nu(\mathrm{d}z[p]) \end{equation}
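In log space the chain-level potential \(G^{wf}\) is a log-mean-exp of the per-component log-potentials; a small sketch (the one-dimensional chain layout is an assumption):

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def waste_free_log_potential(log_G):
    """log G^wf(z) = log((1/P) * sum_p G_t(z[p])) from per-component log-potentials."""
    P = log_G.shape[0]
    return logsumexp(log_G) - jnp.log(P)


log_G = jnp.log(jnp.array([1.0, 2.0, 3.0]))
log_G_wf = waste_free_log_potential(log_G)  # log(2.0), the mean of 1, 2, 3
```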
And the transition kernel:
\begin{equation}
M_{t}^{wf}(z_{t-1}, \mathrm{d}z_{t}) = \sum_{q=1}^{P} \frac{G_{t-1}(z_{t-1}[q])}{\sum_{p=1}^{P} G_{t-1}(z_{t-1}[p])}\, \delta_{z_{t-1}[q]}(\mathrm{d}z_{t}[1]) \prod_{p=2}^{P} M_{t}(z_{t}[p-1], \mathrm{d}z_{t}[p])
\end{equation}
- Choose one chain of length \(P\) with probability \(\propto \sum_{p} G_{t-1}(z_{t-1}[p])\) (this is the resampling step across the \(M\) chains);
- Choose one component \(q\) within that chain with probability \(\propto G_{t-1}(z_{t-1}[q])\) to be the starting point of the next chain;
- Run the MCMC kernel \(P-1\) times from this starting point, and repeat.
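The second step (picking the starting component within a chain) can be sketched with `jax.random.categorical` on the log-potentials; the names below are illustrative:

```python
import jax
import jax.numpy as jnp


def pick_start(rng_key, log_G_chain):
    """Select index q with probability proportional to G_{t-1}(z[q])."""
    return jax.random.categorical(rng_key, log_G_chain)


# Components 0 and 1 carry zero potential, so component 2 is always chosen.
log_G_chain = jnp.array([-jnp.inf, -jnp.inf, 0.0])
q = pick_start(jax.random.PRNGKey(0), log_G_chain)  # q == 2
```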
This leads to Algorithm 2 in the paper:
- sampled_idx <- resample(M, weights)
- sampled_particles = particles[sampled_idx]
- traces = jax.scan(jax.vmap(mcmc)(keys, sampled_particles), P)
- particles = traces.flatten()
- weights = G(particles)
- return particles, weights
The equivalent algorithm 1 would be:
- sampled_particles = particles
- traces = jax.scan(jax.vmap(mcmc)(keys, sampled_particles), P)
- particles = traces[-1, :]
- weights = G(particles)
- return particles, weights
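The only difference between the two listings is what is kept from `traces`. A toy random-walk Metropolis kernel (the standard-normal target and step size are arbitrary choices, not from the paper) makes the shapes concrete:

```python
import jax
import jax.numpy as jnp


def rw_metropolis_step(rng_key, x):
    """One random-walk Metropolis step targeting a standard normal."""
    key_prop, key_acc = jax.random.split(rng_key)
    proposal = x + 0.5 * jax.random.normal(key_prop, x.shape)
    log_ratio = 0.5 * (x**2 - proposal**2)
    accept = jnp.log(jax.random.uniform(key_acc, x.shape)) < log_ratio
    return jnp.where(accept, proposal, x)


M, P = 4, 3                      # chains, chain length
particles = jnp.zeros((M,))      # resampled starting points


def body(xs, key):
    new = jax.vmap(rw_metropolis_step)(jax.random.split(key, M), xs)
    return new, new


keys = jax.random.split(jax.random.PRNGKey(0), P - 1)
last, traces = jax.lax.scan(body, particles, keys)

# Algorithm 1 (vanilla): keep only the final state of each chain.
vanilla_particles = last                                   # shape (M,)

# Algorithm 2 (waste-free): keep every iterate, including the start.
waste_free_particles = jnp.concatenate(
    [particles[None], traces]
).reshape(M * P)                                           # shape (M * P,)
```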
Code #
Let us now refactor the original `smc` code:
def smc(
    mcmc_kernel_factory: Callable,
    mcmc_state_generator: Callable,
    resampling_fn: Callable,
    num_mcmc_iterations: int,
    is_waste_free: bool = False,
):
    if is_waste_free:
        num_mcmc_iterations = num_mcmc_iterations - 1

    def kernel(
        rng_key: PRNGKey,
        state: SMCState,
        logprob_fn: Callable,
        log_weight_fn: Callable,
    ) -> Tuple[SMCState, SMCInfo]:
        weights, particles = state
        scan_key, resampling_key = jax.random.split(rng_key, 2)
        num_particles = weights.shape[0]

        if is_waste_free:
            sub_num_particles, remainder = divmod(
                num_particles, num_mcmc_iterations + 1
            )
            if remainder > 0:
                raise ValueError(
                    "`num_mcmc_iterations + 1` must be a divisor "
                    f"of `num_particles`, {num_mcmc_iterations + 1} and "
                    f"{num_particles} were given"
                )
        else:
            sub_num_particles = num_particles

        resampling_index = resampling_fn(weights, resampling_key, sub_num_particles)
        particles = jax.tree_map(lambda x: x[resampling_index], particles)

        # First advance the particles using the MCMC kernel
        mcmc_kernel = mcmc_kernel_factory(logprob_fn)

        def mcmc_body_fn(carry, curr_key):
            curr_particles, n_accepted = carry
            keys = jax.random.split(curr_key, sub_num_particles)
            new_particles, mcmc_info = jax.vmap(mcmc_kernel, in_axes=(0, 0))(
                keys, curr_particles
            )
            n_accepted = n_accepted + mcmc_info.is_accepted
            return (new_particles, n_accepted), new_particles

        mcmc_state = jax.vmap(mcmc_state_generator, in_axes=(0, None))(
            particles, logprob_fn
        )
        keys = jax.random.split(scan_key, num_mcmc_iterations)
        (proposed_states, total_accepted), proposed_states_history = jax.lax.scan(
            mcmc_body_fn, (mcmc_state, jnp.zeros((sub_num_particles,))), keys
        )
        acceptance_rate = jnp.mean(total_accepted / num_mcmc_iterations)

        if is_waste_free:
            # Prepend the starting points to the chain history, then flatten
            # the (P, M, ...) chains into N = P * M particles.
            initial_position, tree_def = jax.tree_flatten(mcmc_state.position)
            chains_history, _ = jax.tree_flatten(proposed_states_history.position)
            position_history = [
                jnp.concatenate([jnp.expand_dims(elem1, 0), elem2])
                for elem1, elem2 in zip(initial_position, chains_history)
            ]
            position_history = jax.tree_unflatten(tree_def, position_history)
            proposed_particles = jax.tree_map(
                lambda z: jnp.reshape(z, (num_particles,) + z.shape[2:]),
                position_history,
            )
        else:
            proposed_particles = proposed_states.position

        # Reweight the particles with the potential function
        log_weights = jax.vmap(log_weight_fn, in_axes=(0,))(proposed_particles)
        weights, log_likelihood_increment = _normalize(log_weights)

        state = SMCState(weights, proposed_particles)
        info = SMCInfo(resampling_index, log_likelihood_increment, acceptance_rate)
        return state, info

    return kernel
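The concatenate-and-reshape logic of the waste-free branch can be checked in isolation on plain arrays (the shapes below are hypothetical):

```python
import jax.numpy as jnp

num_mcmc_iterations, M = 2, 4                  # P - 1 moves, M resampled particles
initial = jnp.arange(M, dtype=jnp.float32)     # starting points, shape (M,)
history = jnp.ones((num_mcmc_iterations, M))   # MCMC iterates, shape (P - 1, M)

# Prepend the starting points, then flatten (P, M) into N = P * M particles.
full_history = jnp.concatenate([initial[None], history])
proposed_particles = jnp.reshape(
    full_history, (M * (num_mcmc_iterations + 1),)
)  # shape (12,)
```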
First, separate vanilla and waste-free SMC: