How-to Guides#

TODO: provide worked examples of specific tasks.

Creating your own model#

Create a new subclass of pypfilt.model.Model and implement the following methods:

  • field_types(), which defines the structure of the particle state vector;

  • init(), which initialises the particle state vectors at the start of a simulation; and

  • update(), which updates the particle state vectors at each time-step.

If you want to allow any elements in the particle state vectors to be smoothed (“regularisation”), the model must also implement the following method:

  • can_smooth(), which identifies the particle state vector elements that can be smoothed by the particle filter.

Note

If your model class contains any internal variables that are not stored in the particle state vectors, you should also implement the resume_from_cache() method so that these variables can be properly initialised when beginning a simulation from a saved state.

It is preferable to store all necessary model variables in the particle state vectors and avoid this complication.

For example, shown below is the code for the Gaussian random walk model (GaussianWalk) provided in the pypfilt.examples.simple module:

class GaussianWalk(pypfilt.Model):
    r"""
    A Gaussian random walk.

    .. math::

       x_t &= x_{t-1} + X_t \\
       X_t &\sim N(\mu = 0, \sigma = 1)

    The initial values :math:`x_0` are defined by the prior distribution for
    ``"x"``:

    .. code-block:: toml

       [prior]
       x = { name = "uniform", args.loc = 10.0, args.scale = 10.0 }
    """

    def field_types(self, ctx):
        return [('x', np.dtype(float))]

    def update(self, ctx, time_step, is_fs, prev, curr):
        """Perform a single time-step."""
        rnd = ctx.component['random']['model']
        step = rnd.normal(loc=0, scale=1, size=curr.shape)
        curr['x'] = prev['x'] + step
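
To see what update() does to the state vectors without running a full simulation, the sketch below applies the same step to a plain NumPy structured array. It doesn't use pypfilt at all; the helper name `step_walk`, the ensemble size, and the PRNG seed are assumptions made purely for illustration.

```python
import numpy as np

# The same structure that field_types() describes: one float field per particle.
dtype = np.dtype([('x', np.dtype(float))])


def step_walk(rng, prev):
    """Return a new state vector array after one Gaussian random-walk step."""
    curr = np.zeros(prev.shape, dtype=prev.dtype)
    step = rng.normal(loc=0, scale=1, size=curr.shape)
    curr['x'] = prev['x'] + step
    return curr


# Five particles that all start at x = 10.
rng = np.random.default_rng(12345)
prev = np.zeros(5, dtype=dtype)
prev['x'] = 10.0
curr = step_walk(rng, prev)
```

Each particle receives its own independent step, because the random draw has the same shape as the state vector array.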

Creating your own observation model#

Inherit from the pypfilt.obs.Univariate class to create observation models that are based on Scipy probability distributions.

For example, shown below is the code for the Gaussian random walk observation model provided in the pypfilt.examples.simple module:

class GaussianObs(pypfilt.obs.Univariate):
    r"""
    A Gaussian observation model for the GaussianWalk model.

    .. math::

       \mathcal{L}(y_t \mid x_t) \sim N(\mu = x_t, \sigma = s)

    The observation model has one parameter: the standard deviation :math:`s`,
    whose value is defined by the ``"parameters.sdev"`` setting:

    .. code-block:: toml

       [observations.x]
       model = "pypfilt.examples.simple.GaussianObs"
       parameters.sdev = 0.2
    """

    def distribution(self, ctx, snapshot):
        expected = snapshot.state_vec['x']
        sdev = self.settings['parameters']['sdev']
        return scipy.stats.norm(loc=expected, scale=sdev)

Note

You can override any of the pypfilt.obs.Obs methods. This can be useful if you want to support incomplete observations (override log_llhd()), or need to support custom file formats when loading observations (override from_file()).
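
As a rough illustration of what a log-likelihood calculation involves for the GaussianObs model above, the sketch below evaluates the Gaussian log-density of a single observation against each particle's expected value. It uses only the standard library; the function name `gaussian_log_llhd` is hypothetical and isn't part of pypfilt.

```python
import math


def gaussian_log_llhd(obs_value, expected, sdev):
    """Return the log-density of N(expected[i], sdev) at obs_value for each i."""
    norm_const = -0.5 * math.log(2 * math.pi * sdev ** 2)
    return [
        norm_const - 0.5 * ((obs_value - mu) / sdev) ** 2
        for mu in expected
    ]


# Three particles with different values of x, and one observation y = 1.0.
log_llhds = gaussian_log_llhd(1.0, [0.8, 1.0, 1.5], sdev=0.2)
```

The particle whose expected value matches the observation receives the largest log-likelihood, and particles further from the observation are penalised quadratically.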

Defining prior distributions#

Reading prior samples from external data files#

Prior samples can be read from external data files. This allows you to use arbitrarily-complex model prior distributions, which can incorporate features such as correlations between model parameters.

Samples can be read from space-delimited text files using pypfilt.io.read_table():

[prior]
x = { external = true, table = "input-file.ssv", column = "x" }

Samples can also be read from HDF5 datasets, by specifying the file name, dataset path, and column name:

[prior]
x.external = true
x.hdf5 = "prior-samples.hdf5"
x.dataset = "data/prior-samples"
x.column = "x"
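
One way to produce such a file is sketched below. It assumes that the space-delimited format consists of a header row of column names followed by one row per sample, which you should check against the pypfilt.io.read_table() documentation. The second column `y`, the helper name `write_prior_samples`, and the chosen distributions are hypothetical.

```python
import csv
import io
import random


def write_prior_samples(stream, n_samples, rng):
    """Write correlated samples of x and y as a space-delimited table."""
    writer = csv.writer(stream, delimiter=' ', lineterminator='\n')
    writer.writerow(['x', 'y'])
    for _ in range(n_samples):
        x = rng.uniform(10.0, 20.0)
        # Make y depend on x, so that the two parameters are correlated.
        y = 0.5 * x + rng.gauss(0.0, 0.1)
        writer.writerow([x, y])


# Write to an in-memory buffer; use open('input-file.ssv', 'w') for a real file.
buffer = io.StringIO()
write_prior_samples(buffer, n_samples=100, rng=random.Random(2024))
lines = buffer.getvalue().splitlines()
```

Because each row holds one joint sample, reading `x` and `y` from the same file preserves the correlation between them.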

Provide models with lookup tables#

Define partition-specific model parameters#

There are a number of ways to define model parameters that take different values in each partition of the ensemble. Here are a few examples:

  • Define a unique value for each partition in the model table:

    [model]
    values_alpha = [10, 50, 100]
    

    The model can use these values to initialise the appropriate field(s) of the state vectors:

    def init(self, ctx, vec):
        for (ix, partition) in enumerate(ctx.get_setting(['filter', 'partition'], [])):
            part_ixs = partition['slice']
            value = ctx.settings['model']['values_alpha'][ix]
            vec['alpha'][part_ixs] = value
    
  • Define a probability distribution for each partition by defining a separate parameter for each partition in the prior table:

    [prior]
    alpha_1 = { name = "uniform", args.loc = 0, args.scale = 1}
    alpha_2 = { name = "uniform", args.loc = 0, args.scale = 10}
    

    The model can use these values to initialise the appropriate field(s) of the state vectors:

    def init(self, ctx, vec):
        for (ix, partition) in enumerate(ctx.get_setting(['filter', 'partition'], [])):
            part_ixs = partition['slice']
            # The prior parameters are numbered from 1 (alpha_1, alpha_2, ...).
            param_name = 'alpha_{}'.format(ix + 1)
            values = ctx.data['prior'][param_name]
            count = len(vec['alpha'][part_ixs])
            vec['alpha'][part_ixs] = values[:count]
    
  • Create a separate lookup table for each partition, and associate each table with a partition by adding a setting to the model table:

    [model]
    partition_lookup_tables = ['first_partition', 'second_partition']
    
    [lookup_tables]
    # Define the 'first_partition' and 'second_partition' tables here.
    

    The model can sample values from each lookup table and assign them to the appropriate subsets of the ensemble:

    def init(self, ctx, vec):
        start = ctx.get_setting(['time', 'sim_start'])
        for (ix, partition) in enumerate(ctx.get_setting(['filter', 'partition'], [])):
            part_ixs = partition['slice']
            table_name = ctx.settings['model']['partition_lookup_tables'][ix]
            values = ctx.component['lookup'][table_name].lookup(start)
            count = len(vec['parameter'][part_ixs])
            vec['parameter'][part_ixs] = values[:count]
    
  • Construct appropriate prior samples for these parameters that account for the partition sizes and ordering, and read these samples from external files (as described above).
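
The slice-based assignment used in these examples can be tried out on a plain NumPy structured array. In this sketch the `partitions` list is a hypothetical stand-in for the partition settings, and it assumes only that each partition provides a `'slice'` entry, as in the examples above.

```python
import numpy as np

# Hypothetical partition settings: each entry has a 'slice' into the ensemble.
partitions = [
    {'slice': slice(0, 4)},
    {'slice': slice(4, 6)},
    {'slice': slice(6, 10)},
]
values_alpha = [10, 50, 100]

# An ensemble of ten particles with a single field, alpha.
vec = np.zeros(10, dtype=[('alpha', np.float64)])
for (ix, partition) in enumerate(partitions):
    part_ixs = partition['slice']
    # Assign the partition-specific value to every particle in this partition.
    vec['alpha'][part_ixs] = values_alpha[ix]
```

Indexing a field with a slice returns a view, so the assignment modifies the original state vector array in place.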

Provide observation models with lookup tables#

An observation model may allow some of its parameters to be defined in lookup tables, rather than as fixed values. This allows parameters to differ between particles and to vary over time.

Note

The observation model must support using lookup tables for parameter values. The example shown here uses the epifx.obs.PopnCounts observation model, which allows the observation probability to be defined in a lookup table.

To use a lookup table for an observation model parameter, the scenario must:

  • Define the lookup table by giving it a name (in this example, "pr_obs") and identifying the data file to use (in this example, "pr-obs.ssv");

  • Notify pypfilt that each particle should be associated with a column from this lookup table by setting sample_values = true (see the example below); and

  • Notify the observation model that it should use this lookup table for the observation probability, by providing it as an argument to the observation model’s constructor (in this example, the argument name is "pr_obs_lookup").

These steps are illustrated in the following TOML excerpt, which shows only the relevant lines:

[scenario.test.lookup_tables]
pr_obs = { file = "pr-obs.ssv", sample_values = true }

[scenario.test.observations.cases]
model = "epifx.obs.PopnCounts"
pr_obs_lookup = "pr_obs"

You can then retrieve values from the lookup table, and index them to select the appropriate value for each particle:

def get_values(ctx, snapshot, table_name):
    table = ctx.component['lookup'][table_name]
    ixs = snapshot.vec['lookup'][table_name]
    values = table.lookup(snapshot.time)[ixs]
    return values
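
The indexing step relies on NumPy integer-array indexing. The sketch below shows the idea with made-up numbers: `row_values` stands in for the values returned by a lookup at a single time, and `ixs` stands in for the column index associated with each particle.

```python
import numpy as np

# One row of a hypothetical lookup table with three value columns.
row_values = np.array([0.2, 0.5, 0.8])
# The column index associated with each of six particles.
ixs = np.array([0, 2, 1, 1, 0, 2])
# Integer-array indexing returns one value per particle.
values = row_values[ixs]
```

The result has one element per particle, and particles that share a column index share the same value.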

Using lookup tables in your own models#

Time-varying inputs can be provided in a lookup table, and retrieved by the simulation model when updating the particle state vectors at each time-step. The example below shows how to retrieve the current value(s) for a parameter alpha from the lookup table called “alpha_table”.

def update(self, ctx, time_step, is_fs, prev, curr):
    """Perform a single time-step."""
    # Retrieve the current value(s) for the parameter alpha.
    alpha = ctx.component['lookup']['alpha_table'].lookup(time_step.end)
    # NOTE: now update the particle state vectors.
    pass

This table must be defined in the scenario definition:

[lookup_tables]
alpha_table = "alpha-values.ssv"

Note

You may want to avoid hard-coding the table name, and make the table name a model setting.

The scenario definition would need to include this new setting:

[model]
alpha_lookup = "some_table_name"

And the simulation model would use this setting in its update() method:

def update(self, ctx, time_step, is_fs, prev, curr):
    """Perform a single time-step."""
    # Retrieve the current value(s) for the parameter alpha.
    table_name = ctx.settings['model']['alpha_lookup']
    alpha = ctx.component['lookup'][table_name].lookup(time_step.end)
    # NOTE: now update the particle state vectors.
    pass
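
To make the idea of a time-varying input concrete, here is a minimal stand-alone sketch of a step-wise lookup. It assumes that a lookup returns the values from the most recent row at or before the requested time. The class name `StepLookup` is hypothetical and this is not pypfilt's implementation.

```python
import bisect


class StepLookup:
    """Return the values from the latest row at or before a given time."""

    def __init__(self, times, rows):
        # `times` must be sorted in increasing order, one entry per row.
        self.times = list(times)
        self.rows = list(rows)

    def lookup(self, when):
        ix = bisect.bisect_right(self.times, when) - 1
        if ix < 0:
            raise ValueError('No values defined at time {}'.format(when))
        return self.rows[ix]


# The value of alpha changes at days 0, 7, and 14.
table = StepLookup(times=[0, 7, 14], rows=[[0.1], [0.3], [0.2]])
alpha_day_10 = table.lookup(10)
```

Under this assumption, the value stays constant between the times listed in the table and changes only when a new row takes effect.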

Using lookup tables in your own observation models#

Defining and using parameters#

Particle filter: resampling#

Particle filter: post-regularisation#

Creating summary tables#

Resampling before calculating summary statistics#

It can be useful to resample the particle ensemble before calculating summary statistics. For example, the SimulatedObs table resamples the particles each time that it generates simulated observations during the estimation pass. In contrast, it only resamples the particles at the beginning of the forecasting pass, to ensure that the simulated observations reflect individual particle trajectories. These resampling steps are performed by using pypfilt.resample.resample_weights() to return an array of resampled particle indices.

Note

This resampling only affects the simulated observations, and does not affect the particle ensemble. As a consequence, it does not perform steps such as post-regularisation.

You can use this same approach in your own summary tables:

  • Create a pseudo-random number generator (PRNG) to use with resample_weights();

  • Ensure the PRNG state is saved to, and restored from, cache files by implementing the save_state() and load_state() methods;

  • Resample the particles at each time during the estimation pass; and

  • Resample the particles at the start of the forecasting pass and retain these indices for the entire forecasting pass (e.g., store them in self.__sample_ixs as shown below).

import numpy as np
from pypfilt.cache import load_rng_states, save_rng_states
from pypfilt.io import time_field
from pypfilt.resample import resample_weights
from pypfilt.summary import Table

class MeanX(Table):
    """Record the mean value of ``x`` at each time unit."""

    def field_types(self, ctx, obs_list, name):
        # Create a PRNG for resampling the particles.
        prng_seed = ctx.settings['filter'].get('prng_seed')
        self.__rnd = np.random.default_rng(prng_seed)
        # Return the field types.
        return [time_field('fs_time'), time_field('time'), ('value', np.float64)]

    def n_rows(self, ctx, start_date, end_date, n_days, forecasting):
        # Reset the sample indices at the start of each simulation.
        self.__sample_ixs = None
        self.__forecasting = forecasting
        # Generate one row at each day.
        return n_days

    def load_state(self, ctx, group):
        """Restore the state of each PRNG from the cache."""
        load_rng_states(group, 'prng_states', {'resample': self.__rnd})

    def save_state(self, ctx, group):
        """Save the current state of each PRNG to the cache."""
        save_rng_states(group, 'prng_states', {'resample': self.__rnd})

    def add_rows(self, ctx, fs_time, window, insert_fn):
        """Record the mean value of ``x`` at each time unit."""
        for snapshot in window:
            if self.__sample_ixs is None:
                (sample_ixs, _weight) = resample_weights(snapshot.weights, self.__rnd)
                if self.__forecasting:
                    # Save these indices for use over the entire forecast.
                    self.__sample_ixs = sample_ixs
            else:
                sample_ixs = self.__sample_ixs

            # Now summarise the snapshot, using `sample_ixs` for indexing.
            x_values = snapshot.state_vec['x'][sample_ixs]
            row = (fs_time, snapshot.time, np.mean(x_values))
            insert_fn(row)
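
If you are unsure what an array of resampled particle indices looks like, the following stand-alone sketch implements a generic stratified resampler using only the standard library. It is an illustration of the general technique and not the pypfilt.resample.resample_weights() implementation; the function name `stratified_resample` is hypothetical.

```python
import itertools
import random


def stratified_resample(weights, rng):
    """Return resampled particle indices, one draw from each of N strata."""
    n = len(weights)
    total = sum(weights)
    cum_weights = list(itertools.accumulate(w / total for w in weights))
    # Guard against round-off in the final cumulative weight.
    cum_weights[-1] = 1.0
    indices = []
    ix = 0
    for k in range(n):
        # Draw one point uniformly from the k-th stratum [k/n, (k+1)/n).
        point = (k + rng.random()) / n
        while cum_weights[ix] < point:
            ix += 1
        indices.append(ix)
    return indices


weights = [0.05, 0.05, 0.6, 0.3]
sample_ixs = stratified_resample(weights, random.Random(1))
```

Particles with large weights appear several times in the returned indices, and particles with small weights may not appear at all. Indexing a state vector array with these indices yields an equally-weighted sample.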

Creating summary monitors#

Load summary tables as Pandas data frames#

You can load summary tables with pypfilt.io.load_dataset(), which returns these tables as NumPy structured arrays. This can easily be extended to return summary tables as Pandas data frames:

import h5py
import pandas as pd
import pypfilt

def load_dataframe(source, dataset, time_scale):
    """
    Load a pypfilt summary table and convert it into a Pandas dataframe.

    :param source: The path to the HDF5 file.
    :param dataset: The HDF5 path to the summary table.
    :param time_scale: The simulation time scale.
    """
    with h5py.File(source, 'r') as f:
        table = pypfilt.io.load_dataset(time_scale, f[dataset])

    return pd.DataFrame(table)

You can then use this function to read summary tables from simulation output files:

output_file = 'output.hdf5'
dataset = '/tables/forecasts'
time_scale = pypfilt.Datetime()
forecasts = load_dataframe(output_file, dataset, time_scale)

Implement a continuous-time Markov chain (CTMC) model#

The pypfilt.examples.sir.SirCtmc model implementation is shown below.

class SirCtmc(Model):
    """
    A continuous-time Markov chain implementation of the SIR model.

    The model settings must include the following keys:

    * ``population_size``: The number of individuals in the population.
    """

    def field_types(self, ctx):
        """
        Define the state vector structure.
        """
        return [
            # Model state variables.
            ('S', np.int64),
            ('I', np.int64),
            ('R', np.int64),
            # Model parameters.
            ('R0', np.float64),
            ('gamma', np.float64),
            # Next event details.
            ('next_event', np.int64),
            ('next_time', np.float64),
        ]

    def can_smooth(self):
        """
        The fields that can be smoothed by the post-regularisation filter.
        """
        # Return the continuous model parameters.
        return {'R0', 'gamma'}

    def init(self, ctx, vec):
        """
        Initialise the state vectors.
        """
        # Initialise the model state variables.
        population = ctx.settings['model']['population_size']
        vec['S'] = population - 1
        vec['I'] = 1
        vec['R'] = 0

        # Initialise the model parameters.
        prior = ctx.data['prior']
        vec['R0'] = prior['R0']
        vec['gamma'] = prior['gamma']

        # Select the first event for each particle.
        vec['next_time'] = 0
        vec['next_event'] = 0
        self.select_next_event(ctx, vec, stop_time=0)

    def update(self, ctx, time_step, is_forecast, prev, curr):
        """
        Update the state vectors to account for all events that occur up to,
        and including, ``time_step.end``.
        """
        # Copy the state vectors, and update the current state.
        curr[:] = prev[:]

        # Simulate events for active particles.
        active = self.active_particles(curr, time_step.end)
        while any(active):
            # Simulate infection events.
            infections = np.logical_and(active, curr['next_event'] == 0)
            curr['S'][infections] -= 1
            curr['I'][infections] += 1

            # Simulate recovery events.
            recoveries = np.logical_and(active, curr['next_event'] == 1)
            curr['I'][recoveries] -= 1
            curr['R'][recoveries] += 1

            # Identify which particles are still active.
            self.select_next_event(ctx, curr, stop_time=time_step.end)
            active = self.active_particles(curr, time_step.end)

    def active_particles(self, vec, stop_time):
        """
        Return a Boolean array that identifies the particles that still have
        infectious individuals and whose next event occurs no later than
        ``stop_time``.
        """
        return np.logical_and(
            vec['next_time'] <= stop_time,
            vec['I'] > 0,
        )

    def select_next_event(self, ctx, vec, stop_time):
        """
        Calculate the next event time and event type for each active particle.
        """
        active = self.active_particles(vec, stop_time)
        if not any(active):
            return

        # Extract state variables and parameters for the active particles.
        S = vec['S'][active]
        I = vec['I'][active]
        R0 = vec['R0'][active]
        gamma = vec['gamma'][active]
        N = ctx.settings['model']['population_size']

        # Calculate the mean rate of infection and recovery events.
        s_to_i_rate = R0 * gamma * S * I / (N - 1)
        i_to_r_rate = gamma * I
        rate_sum = s_to_i_rate + i_to_r_rate

        # Select the time of the next event.
        rng = ctx.component['random']['model']
        dt = -np.log(rng.random(S.shape)) / rate_sum
        vec['next_time'][active] += dt

        # Select the event type: False for infection and True for recovery.
        threshold = rng.random(S.shape) * rate_sum
        recovery_event = threshold > s_to_i_rate
        vec['next_event'][active] = recovery_event.astype(np.int64)
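
The event-selection logic may be easier to follow for a single trajectory. The sketch below simulates one SIR realisation with the same rates, using only the standard library. The function name `simulate_sir_ctmc`, the parameter values, and the seed are assumptions for illustration, and the code does not depend on pypfilt.

```python
import math
import random


def simulate_sir_ctmc(population, r0, gamma, stop_time, rng):
    """Simulate one SIR trajectory and return the final (S, I, R, time)."""
    s, i, r = population - 1, 1, 0
    time = 0.0
    while i > 0:
        s_to_i_rate = r0 * gamma * s * i / (population - 1)
        i_to_r_rate = gamma * i
        rate_sum = s_to_i_rate + i_to_r_rate
        # Exponentially distributed waiting time until the next event.
        dt = -math.log(1.0 - rng.random()) / rate_sum
        if time + dt > stop_time:
            break
        time += dt
        # Choose the event type in proportion to its rate.
        if rng.random() * rate_sum >= s_to_i_rate:
            i, r = i - 1, r + 1
        else:
            s, i = s - 1, i + 1
    return (s, i, r, time)


final_s, final_i, final_r, final_time = simulate_sir_ctmc(
    population=100, r0=2.5, gamma=0.25, stop_time=30.0, rng=random.Random(3)
)
```

The SirCtmc model performs the same calculation for every particle at once, and stores the pending event in the state vector so that it survives from one time-step to the next.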

Implement a discrete-time Markov chain (DTMC) model#

The pypfilt.examples.sir.SirDtmc model implementation is shown below.

class SirDtmc(Model):
    """
    A discrete-time Markov chain implementation of the SIR model.

    The model settings must include the following keys:

    * ``population_size``: The number of individuals in the population.
    """

    def field_types(self, ctx):
        """
        Define the state vector structure.
        """
        return [
            # Model state variables.
            ('S', np.int64),
            ('I', np.int64),
            ('R', np.int64),
            # Model parameters.
            ('R0', np.float64),
            ('gamma', np.float64),
        ]

    def can_smooth(self):
        """
        The fields that can be smoothed by the post-regularisation filter.
        """
        # Return the continuous model parameters.
        return {'R0', 'gamma'}

    def init(self, ctx, vec):
        """
        Initialise the state vectors.
        """
        # Initialise the model state variables.
        population = ctx.settings['model']['population_size']
        vec['S'] = population - 1
        vec['I'] = 1
        vec['R'] = 0

        # Initialise the model parameters.
        prior = ctx.data['prior']
        vec['R0'] = prior['R0']
        vec['gamma'] = prior['gamma']

    def update(self, ctx, time_step, is_forecast, prev, curr):
        """
        Update the state vectors.
        """
        rng = ctx.component['random']['model']
        beta = prev['R0'] * prev['gamma']
        denom = ctx.settings['model']['population_size'] - 1

        # Calculate the rate at which *an individual* leaves S.
        s_out_rate = time_step.dt * beta * prev['I'] / denom
        # Select the number of infections.
        s_out = rng.binomial(prev['S'], -np.expm1(-s_out_rate))

        # Calculate the rate at which *an individual* leaves I.
        i_out_rate = time_step.dt * prev['gamma']
        # Select the number of recoveries.
        i_out = rng.binomial(prev['I'], -np.expm1(-i_out_rate))

        # Update the state variables.
        curr['S'] = prev['S'] - s_out
        curr['I'] = prev['I'] + s_out - i_out
        curr['R'] = prev['R'] + i_out

        # Copy the model parameters.
        curr['R0'] = prev['R0']
        curr['gamma'] = prev['gamma']
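
The expression `-np.expm1(-rate)` converts a rate into the probability that an individual leaves the compartment during the time-step, which is 1 - exp(-rate). The sketch below checks this with the standard library for the recovery rate, which is constant over time. The function name `leave_probability` and the numbers are illustrative assumptions.

```python
import math


def leave_probability(rate, dt):
    """Return the probability that an individual leaves within ``dt``."""
    # Equivalent to 1 - exp(-rate * dt), but accurate for small rates.
    return -math.expm1(-rate * dt)


# With gamma = 0.25 per day, compare a full-day step to a quarter-day step.
p_day = leave_probability(0.25, 1.0)
p_quarter = leave_probability(0.25, 0.25)
# Four quarter-day steps give the same overall probability as one full day.
p_four_quarters = 1 - (1 - p_quarter) ** 4
```

This consistency across step sizes holds for a constant rate such as the recovery rate. The infection rate depends on the current number of infectious individuals, so the results for S will still vary with the size of the time-step.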

Implement an ordinary differential equation (ODE) model#

The pypfilt.examples.sir.SirOdeEuler model implementation is shown below. For simplicity, it uses the forward Euler method.

class SirOdeEuler(Model):
    """
    An ordinary differential equation implementation of the SIR model, which
    uses the forward Euler method.

    The model settings must include the following keys:

    * ``population_size``: The number of individuals in the population.
    """

    def field_types(self, ctx):
        """
        Define the state vector structure.
        """
        return [
            # Model state variables.
            ('S', np.float64),
            ('I', np.float64),
            ('R', np.float64),
            # Model parameters.
            ('R0', np.float64),
            ('gamma', np.float64),
        ]

    def can_smooth(self):
        """
        The fields that can be smoothed by the post-regularisation filter.
        """
        # Return the continuous model parameters.
        return {'R0', 'gamma'}

    def init(self, ctx, vec):
        """
        Initialise the state vectors.
        """
        # Initialise the model state variables.
        population = ctx.settings['model']['population_size']
        vec['S'] = population - 1
        vec['I'] = 1
        vec['R'] = 0

        # Initialise the model parameters.
        prior = ctx.data['prior']
        vec['R0'] = prior['R0']
        vec['gamma'] = prior['gamma']

    def update(self, ctx, time_step, is_forecast, prev, curr):
        """
        Update the state vectors.
        """
        # Calculate the flow rates out of S and I.
        beta = prev['R0'] * prev['gamma']
        N = ctx.settings['model']['population_size']
        s_out = time_step.dt * beta * prev['I'] * prev['S'] / N
        i_out = time_step.dt * prev['gamma'] * prev['I']

        # Update the state variables.
        curr['S'] = prev['S'] - s_out
        curr['I'] = prev['I'] + s_out - i_out
        curr['R'] = prev['R'] + i_out

        # Copy the model parameters.
        curr['R0'] = prev['R0']
        curr['gamma'] = prev['gamma']
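
For a single set of parameter values, the same forward Euler scheme can be written in a few lines of plain Python. This is a sketch for experimentation only; the function name `sir_euler` and the parameter values are assumptions.

```python
def sir_euler(population, r0, gamma, dt, n_steps):
    """Integrate the SIR equations with the forward Euler method."""
    s, i, r = population - 1.0, 1.0, 0.0
    beta = r0 * gamma
    for _ in range(n_steps):
        # Flows out of S and I over one time-step of length dt.
        s_out = dt * beta * i * s / population
        i_out = dt * gamma * i
        s, i, r = s - s_out, i + s_out - i_out, r + i_out
    return (s, i, r)


# Simulate 50 days with 10 steps per day.
s_end, i_end, r_end = sir_euler(
    population=1000, r0=2.5, gamma=0.25, dt=0.1, n_steps=500
)
```

Every flow that leaves one compartment enters another, so the total population is conserved up to floating-point round-off.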

The pypfilt.examples.sir.SirOdeRk model implementation is shown below. It uses a SciPy initial value problem (IVP) solver for ODEs.

Note

The SirOdeRk model derives from the OdeModel class, and requires fewer time-steps (time.steps_per_unit) than the SirOdeEuler model, because the ODE solver can divide each time-step into as many smaller steps as required.

class SirOdeRk(OdeModel):
    """
    An ordinary differential equation implementation of the SIR model, which
    uses the explicit Runge-Kutta method of order 5(4).

    The model settings must include the following keys:

    * ``population_size``: The number of individuals in the population.
    """

    def field_types(self, ctx):
        """
        Define the state vector structure.
        """
        return [
            # Model state variables.
            ('S', np.float64),
            ('I', np.float64),
            ('R', np.float64),
            # Model parameters.
            ('R0', np.float64),
            ('gamma', np.float64),
        ]

    def can_smooth(self):
        """
        The fields that can be smoothed by the post-regularisation filter.
        """
        # Return the continuous model parameters.
        return {'R0', 'gamma'}

    def init(self, ctx, vec):
        """
        Initialise the state vectors.
        """
        # Initialise the model state variables.
        self.population = ctx.settings['model']['population_size']
        vec['S'] = self.population - 1
        vec['I'] = 1
        vec['R'] = 0

        # Initialise the model parameters.
        prior = ctx.data['prior']
        vec['R0'] = prior['R0']
        vec['gamma'] = prior['gamma']

        # Define the integration method.
        self.method = 'RK45'

    def d_dt(self, time, xt, ctx, is_forecast):
        """
        The right-hand side of the system.
        """
        s_out = xt['R0'] * xt['gamma'] * xt['I'] * xt['S'] / self.population
        i_out = xt['gamma'] * xt['I']
        d_dt = np.zeros(xt.shape, dtype=xt.dtype)
        d_dt['S'] = -s_out
        d_dt['I'] = s_out - i_out
        d_dt['R'] = i_out
        return d_dt
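
For reference, the d_dt() method above evaluates the right-hand side of the following system, where \(N\) denotes the population_size setting:

```latex
\begin{aligned}
\frac{dS}{dt} &= -\frac{R_0 \, \gamma \, S \, I}{N} \\
\frac{dI}{dt} &= \frac{R_0 \, \gamma \, S \, I}{N} - \gamma I \\
\frac{dR}{dt} &= \gamma I
\end{aligned}
```

The parameters \(R_0\) and \(\gamma\) have zero derivative, which is why d_dt() leaves their entries at zero.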

Implement a stochastic differential equation (SDE) model#

The pypfilt.examples.sir.SirSde model implementation is shown below. It uses the Euler-Maruyama method.

class SirSde(Model):
    """
    A stochastic differential equation implementation of the SIR model.

    The model settings must include the following keys:

    * ``population_size``: The number of individuals in the population.
    """

    def field_types(self, ctx):
        """
        Define the state vector structure.
        """
        return [
            # Model state variables.
            ('S', np.float64),
            ('I', np.float64),
            ('R', np.float64),
            # Model parameters.
            ('R0', np.float64),
            ('gamma', np.float64),
        ]

    def can_smooth(self):
        """
        The fields that can be smoothed by the post-regularisation filter.
        """
        # Return the continuous model parameters.
        return {'R0', 'gamma'}

    def init(self, ctx, vec):
        """
        Initialise the state vectors.
        """
        # Initialise the model state variables.
        population = ctx.settings['model']['population_size']
        vec['S'] = population - 1
        vec['I'] = 1
        vec['R'] = 0

        # Initialise the model parameters.
        prior = ctx.data['prior']
        vec['R0'] = prior['R0']
        vec['gamma'] = prior['gamma']

    def update(self, ctx, time_step, is_forecast, prev, curr):
        """
        Update the state vectors.
        """
        rng = ctx.component['random']['model']
        beta = prev['R0'] * prev['gamma']
        N = ctx.settings['model']['population_size']
        size = prev.shape

        # Calculate the mean flows out of S and I.
        s_mean = time_step.dt * beta * prev['I'] * prev['S'] / N
        i_mean = time_step.dt * prev['gamma'] * prev['I']
        # Sample the stochastic term for each flow.
        s_stoch = np.sqrt(s_mean) * rng.normal(size=size)
        i_stoch = np.sqrt(i_mean) * rng.normal(size=size)
        # Calculate the stochastic flows out of S and I, ensuring that all
        # compartments remain non-negative.
        s_out = np.clip(s_mean + s_stoch, a_min=0, a_max=prev['S'])
        i_out = np.clip(i_mean + i_stoch, a_min=0, a_max=prev['I'])

        # Update the state variables.
        curr['S'] = prev['S'] - s_out
        curr['I'] = prev['I'] + s_out - i_out
        curr['R'] = prev['R'] + i_out

        # Copy the model parameters.
        curr['R0'] = prev['R0']
        curr['gamma'] = prev['gamma']
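
The same Euler-Maruyama step can be written for a single trajectory with the standard library, which makes the roles of the mean flow, the noise term, and the clipping easier to see. The function name `sir_sde_step`, the parameter values, and the seed are illustrative assumptions.

```python
import math
import random


def sir_sde_step(s, i, r, beta, gamma, population, dt, rng):
    """Perform one Euler-Maruyama step and return the new (S, I, R)."""
    s_mean = dt * beta * i * s / population
    i_mean = dt * gamma * i
    # Each flow has noise with variance equal to its mean.
    s_flow = s_mean + math.sqrt(s_mean) * rng.gauss(0.0, 1.0)
    i_flow = i_mean + math.sqrt(i_mean) * rng.gauss(0.0, 1.0)
    # Clip the flows so that no compartment becomes negative.
    s_out = min(max(s_flow, 0.0), s)
    i_out = min(max(i_flow, 0.0), i)
    return (s - s_out, i + s_out - i_out, r + i_out)


# Simulate 20 days with 10 steps per day.
rng = random.Random(7)
state = (999.0, 1.0, 0.0)
for _ in range(200):
    state = sir_sde_step(*state, beta=0.625, gamma=0.25, population=1000,
                         dt=0.1, rng=rng)
```

Clipping each flow to the range between zero and the size of its source compartment keeps every compartment non-negative, at the cost of slightly distorting the noise near the boundaries.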

Implement an observation model that relies on past state#

The pypfilt.examples.sir.SirObs model implementation is shown below. The expected value is derived from the decrease in susceptible individuals over the observation period \(\Delta\).

class SirObs(Univariate):
    r"""
    A binomial observation model for the example SIR models.

    .. math::

       \mathcal{L}(y_t \mid x_t) &\sim B(n, p)

       n &= S(t-\Delta) - S(t)

    :param obs_unit: A descriptive name for the data.
    :param settings: The observation model settings dictionary.

    The settings dictionary should contain the following keys:

    * ``observation_period``: The observation period :math:`\Delta`.

    For example, for daily observations that capture 80% of new infections:

    .. code-block:: toml

       [observations.cases]
       model = "pypfilt.examples.sir.SirObs"
       observation_period = 1
       parameters.p = 0.8
    """

    def new_infections(self, ctx, snapshot):
        r"""
        Return the number of new infections :math:`S(t-\Delta) - S(t)` that
        occurred during the observation period :math:`\Delta` for each
        particle.
        """
        period = self.settings['observation_period']
        prev = snapshot.back_n_units_state_vec(period)
        new_infs = prev['S'] - snapshot.state_vec['S']
        # Round continuous values to the nearest integer.
        if not np.issubdtype(new_infs.dtype, np.int64):
            new_infs = new_infs.round().astype(np.int64)
        if np.any(new_infs < 0):
            raise ValueError('Negative number of new infections')
        return new_infs

    def distribution(self, ctx, snapshot):
        """
        Return the observation distribution for each particle.
        """
        prob = ctx.settings['observations'][self.unit]['parameters']['p']
        infections = self.new_infections(ctx, snapshot)
        return scipy.stats.binom(n=infections, p=prob)
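
To see how the number of new infections in each particle affects the likelihood of an observation, the sketch below evaluates the binomial log-probability directly with the standard library. The function name `binom_log_pmf` and the example values are hypothetical, and it assumes that 0 < p < 1.

```python
import math


def binom_log_pmf(k, n, p):
    """Return log P(K = k) for K ~ Binomial(n, p), or -inf if impossible."""
    if k < 0 or k > n:
        return -math.inf
    return (math.log(math.comb(n, k))
            + k * math.log(p)
            + (n - k) * math.log1p(-p))


# Each particle predicts a different number of new infections; 8 were observed.
new_infections = [5, 10, 20]
log_llhds = [binom_log_pmf(8, n, 0.8) for n in new_infections]
```

A particle with fewer new infections than observed cases has zero likelihood, and a particle with far more infections than expected for the observed count is strongly penalised.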

Record only a subset of simulation model time-steps#

For simulation models that require very small time-steps, it may be desirable to avoid recording the particle states at each time-step. This can be achieved by using a large time-step and dividing each time-step into a number of “mini-steps” that will not be recorded, by applying the pypfilt.model.ministeps() decorator to the simulation model’s update() method.

This decorator can be used in two ways.

  1. Providing a default number of mini-steps, which can be overridden by the scenario settings:

    class TestModelPredef(pypfilt.model.Model):
        """
        A simple monotonic deterministic process, where the number of mini-steps
        is predefined by the decorator.
        """
    
        def field_types(self, ctx):
            return [('x', np.float64)]
    
        @pypfilt.model.ministeps(50)
        def update(self, ctx, time_step, is_fs, prev, curr):
            curr['x'] = prev['x'] + time_step.dt
            logger = logging.getLogger(__name__)
            logger.debug('Incrementing x by {}'.format(time_step.dt))
    
  2. Requiring the number of mini-steps to be defined in the simulation settings, by providing no default number of mini-steps:

    class TestModelSetting(pypfilt.model.Model):
        """
        A simple monotonic deterministic process, where the number of mini-steps
        must be defined in the scenario settings.
        """
    
        def field_types(self, ctx):
            return [('x', np.float64)]
    
        @pypfilt.model.ministeps()
        def update(self, ctx, time_step, is_fs, prev, curr):
            curr['x'] = prev['x'] + time_step.dt
            logger = logging.getLogger(__name__)
            logger.debug('Incrementing x by {}'.format(time_step.dt))
    

In both cases, the number of mini-steps can be specified in the scenario definition:

[time]
mini_steps_per_step = 100
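
The general idea behind mini-steps can be sketched with an ordinary Python decorator. This is a simplified illustration and not the pypfilt.model.ministeps() implementation: the decorator name `with_ministeps` is hypothetical, and the update function takes a single number in place of the particle state vectors.

```python
import functools


def with_ministeps(n_mini):
    """Divide each call to an update function into ``n_mini`` smaller steps."""
    def decorator(update_fn):
        @functools.wraps(update_fn)
        def wrapper(x, dt):
            for _ in range(n_mini):
                # Only the state after the final mini-step is returned.
                x = update_fn(x, dt / n_mini)
            return x
        return wrapper
    return decorator


@with_ministeps(50)
def decay(x, dt):
    """Apply one forward Euler step of dx/dt = -x."""
    return x - dt * x


x_after_one_step = decay(1.0, 1.0)
```

The caller sees a single step of size 1.0, while the update function only ever sees steps of size 0.02. This keeps the numerical method accurate without recording the intermediate states.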