Waste-free Sequential Monte Carlo

tags
Sampling

Algorithm from this article.

  • How many MCMC steps are needed for optimal performance? (usually set arbitrarily)

  • The number of steps should be set adaptively, especially for situations like tempering.

  • Intermediate steps are wasted.

    Waste-free SMC uses intermediate steps as new particles.

Principle #

If there are \(N\) particles, we only resample \(M = N/P\) of them. Each resampled particle is then moved \(P-1\) times, and every iterate is kept, forming a new sample of size \(N\).
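For concreteness, a toy shape check (arbitrary numbers, a sketch only):

import jax.numpy as jnp

N, P = 1000, 10
M = N // P                    # M = 100 resampled particles / chains
# each chain: 1 resampled particle + (P - 1) MCMC iterates = P particles
iterates = jnp.zeros((P, M))  # placeholder for the P x M iterates
assert iterates.size == N     # together they form the new sample of size N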

How can we support both the standard and the waste-free variants in the same implementation?

The implementation of sequential MC starts with a resampling step:

resampling_index = resampling_fn(weights, rng_key, num_resampled)
particles = jax.tree_map(lambda x: x[resampling_index], particles)
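As an illustration, a minimal multinomial resampling_fn might look as follows (a sketch only; a real library would offer several schemes):

import jax

def multinomial_resampling(weights, rng_key, num_resampled):
    # draw num_resampled ancestor indices with probabilities given by weights
    num_particles = weights.shape[0]
    return jax.random.choice(
        rng_key, num_particles, shape=(num_resampled,), p=weights
    )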

We can resample anywhere between \(1\) and \(N\) particles: since the number of resampled particles is an explicit argument, the waste-free variant simply asks for \(M\) instead of \(N\).

Waste-free SMC may be recast as a standard SMC sampler that propagates and reweights particles that are Markov chains of length \(P\).

Indeed we have \(M\) chains of length \(P\).

A particle is now a whole chain \(z \in \mathcal{Z} = \Xi^P\). The potential function is defined as:

\begin{equation} G_{t}^{wf}(z) = \frac{1}{P} \sum_{p=1}^{P} G_{t}(z[p]) \end{equation}

and the initial distribution

\begin{equation} \nu^{wf}(\mathrm{d}z) = \prod_{p=1}^{P} \nu(\mathrm{d}z[p]) \end{equation}
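Both extended-space quantities are mechanical to implement. A sketch, where G_t and sample_nu are hypothetical stand-ins for the potential \(G_t\) and a sampler from \(\nu\):

import jax
import jax.numpy as jnp

def G_wf(G_t, z):
    # z has shape (P, ...): one chain of length P
    return jnp.mean(jax.vmap(G_t)(z))

def sample_nu_wf(rng_key, sample_nu, P):
    # the P components of an initial chain are drawn independently from nu
    keys = jax.random.split(rng_key, P)
    return jax.vmap(sample_nu)(keys)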

And the transition kernel:

\begin{equation}
 M_{t}^{wf}(z_{t-1}, \mathrm{d}z_{t}) = \sum_{p=1}^{P} \frac{G_{t-1}(z_{t-1}[p])}{\sum_{q=1}^{P} G_{t-1}(z_{t-1}[q])} \, \delta_{z_{t-1}[p]}(\mathrm{d}z_{t}[1]) \prod_{k=2}^{P} M_{t}(z_{t}[k-1], \mathrm{d}z_{t}[k])
\end{equation}
In words, a new chain \(z_t\) is generated as follows (steps 1 and 2 together amount to an ordinary resampling step, as the sketch below illustrates):

  1. Choose one chain of length \(P\) with probability \(\propto \sum_{p} G_{t-1}(z_{t-1}[p])\);
  2. Choose one component \(q\) of that chain with probability \(\propto G_{t-1}(z_{t-1}[q])\) to be the starting point of the new chain;
  3. Move this starting point \(P-1\) times with the MCMC kernel \(M_t\), keeping every iterate, and repeat for each of the \(M\) chains.
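Marginally, steps 1 and 2 select a component with probability proportional to its potential among all \(N = M \times P\) current iterates, so together they are just a flat resampling step of size \(M\). A minimal sketch (hypothetical names and shapes):

import jax

def choose_starting_points(rng_key, potentials, M, P):
    # potentials has shape (M, P): G_{t-1} evaluated at every component of
    # every chain. Flattening and resampling M indices is equivalent to
    # picking a chain (prob proportional to its total potential) and then a
    # component within it (prob proportional to its own potential).
    flat = potentials.reshape(M * P)
    return jax.random.choice(
        rng_key, M * P, shape=(M,), p=flat / flat.sum()
    )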

This leads to Algorithm 2 of the paper, in pseudocode:

  1. sampled_idx = resample(M, weights)
  2. sampled_particles = particles[sampled_idx]
  3. traces = jax.scan(jax.vmap(mcmc)(keys, sampled_particles), P)
  4. particles = traces.flatten()
  5. weights = G(particles)
  6. return particles, weights

The equivalent Algorithm 1 (standard SMC) resamples all \(N\) particles and keeps only the last iterate of each chain:

  1. sampled_idx = resample(N, weights)
  2. sampled_particles = particles[sampled_idx]
  3. traces = jax.scan(jax.vmap(mcmc)(keys, sampled_particles), P)
  4. particles = traces[-1, :]
  5. weights = G(particles)
  6. return particles, weights
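The delicate part of the waste-free version is steps 3 and 4: running the \(M\) chains with jax.lax.scan while keeping every iterate, then flattening the \((P, M)\)-shaped history back into \(N\) particles. A minimal sketch, with a dummy Gaussian random walk standing in for a real MCMC kernel:

import jax
import jax.numpy as jnp

def mcmc_step(key, x):
    # hypothetical stand-in for a proper MCMC kernel
    return x + 0.1 * jax.random.normal(key, x.shape)

def move_and_collect(rng_key, starts, P):
    def body(xs, key):
        keys = jax.random.split(key, xs.shape[0])
        new_xs = jax.vmap(mcmc_step)(keys, xs)
        return new_xs, new_xs

    keys = jax.random.split(rng_key, P - 1)
    _, history = jax.lax.scan(body, starts, keys)      # (P - 1, M, ...)
    traces = jnp.concatenate([starts[None], history])  # (P, M, ...)
    return traces.reshape((-1,) + starts.shape[1:])    # (N, ...)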

Code #

Let us now refactor the original SMC code:

from typing import Callable, Tuple

import jax
import jax.numpy as jnp

# PRNGKey, SMCState, SMCInfo and _normalize are defined elsewhere in the
# library and are used here as-is.


def smc(
    mcmc_kernel_factory: Callable,
    mcmc_state_generator: Callable,
    resampling_fn: Callable,
    num_mcmc_iterations: int,
    is_waste_free: bool = False,
):

    if is_waste_free:
        # a chain of length P is the resampled particle plus P - 1 iterates
        num_mcmc_iterations = num_mcmc_iterations - 1

    def kernel(
        rng_key: PRNGKey,
        state: SMCState,
        logprob_fn: Callable,
        log_weight_fn: Callable,
    ) -> Tuple[SMCState, SMCInfo]:

        weights, particles = state
        scan_key, resampling_key = jax.random.split(rng_key, 2)

        num_particles = weights.shape[0]
        if is_waste_free:
            sub_num_particles, remainder = divmod(
                num_particles, num_mcmc_iterations + 1
            )
            if remainder > 0:
                raise ValueError(
                    "The chain length must be a divisor of "
                    f"`num_particles`; {num_mcmc_iterations + 1} and "
                    f"{num_particles} were given."
                )
        else:
            sub_num_particles = num_particles

        resampling_index = resampling_fn(weights, resampling_key, sub_num_particles)
        particles = jax.tree_map(lambda x: x[resampling_index], particles)

        # Advance the resampled particles using the MCMC kernel
        mcmc_kernel = mcmc_kernel_factory(logprob_fn)

        def mcmc_body_fn(carry, curr_key):
            curr_particles, n_accepted = carry
            keys = jax.random.split(curr_key, sub_num_particles)
            new_particles, mcmc_info = jax.vmap(mcmc_kernel, in_axes=(0, 0))(
                keys, curr_particles
            )
            n_accepted = n_accepted + mcmc_info.is_accepted
            return (new_particles, n_accepted), new_particles

        mcmc_state = jax.vmap(mcmc_state_generator, in_axes=(0, None))(
            particles, logprob_fn
        )

        keys = jax.random.split(scan_key, num_mcmc_iterations)
        (proposed_states, total_accepted), proposed_states_history = jax.lax.scan(
            mcmc_body_fn, (mcmc_state, jnp.zeros((sub_num_particles,))), keys
        )
        acceptance_rate = jnp.mean(total_accepted / num_mcmc_iterations)

        if is_waste_free:
            # Prepend the starting positions to the history so every iterate
            # becomes a particle, then flatten (P, M, ...) into N = M * P
            initial_position, tree_def = jax.tree_flatten(mcmc_state.position)
            chains_history, _ = jax.tree_flatten(proposed_states_history.position)

            position_history = [
                jnp.concatenate([jnp.expand_dims(elem1, 0), elem2])
                for elem1, elem2 in zip(initial_position, chains_history)
            ]
            position_history = jax.tree_unflatten(tree_def, position_history)
            proposed_particles = jax.tree_map(
                lambda z: jnp.reshape(z, (num_particles,) + z.shape[2:]),
                position_history,
            )
        else:
            # standard SMC only keeps the last iterate of each chain
            proposed_particles = proposed_states.position

        # Reweight the particles with the next potential function
        log_weights = jax.vmap(log_weight_fn, in_axes=(0,))(proposed_particles)
        weights, log_likelihood_increment = _normalize(log_weights)

        state = SMCState(weights, proposed_particles)
        info = SMCInfo(resampling_index, log_likelihood_increment, acceptance_rate)
        return state, info

    return kernel
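A hypothetical usage sketch; the MCMC factory and state generator names below are placeholders, not actual library objects:

kernel = smc(
    mcmc_kernel_factory=rw_kernel_factory,    # hypothetical factory
    mcmc_state_generator=rw_state_generator,  # hypothetical state builder
    resampling_fn=multinomial_resampling,     # e.g. the sketch above
    num_mcmc_iterations=10,                   # chain length P = 10
    is_waste_free=True,                       # keep all M * P iterates
)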

First, separate vanilla and waste-free SMC: