Skip to content

inference

leaguedata.inference

numpyro_model(markov_model, observed_data)

Function that is used as a model in NumPyro to perform inference on the Discrete Markov Chain model.

PARAMETER DESCRIPTION
markov_model

The Discrete Markov Chain model to use.

TYPE: DTMCModel

observed_data

The observed data to use for inference.

TYPE: array

Source code in leaguedata/inference.py
def numpyro_model(markov_model, observed_data):
    """
    Function that is used as a model in NumPyro to perform inference on the Discrete Markov Chain model.

    Parameters:
        markov_model (DTMCModel): The Discrete Markov Chain model to use.
        observed_data (jnp.array): The observed data to use for inference.
    """

    if not markov_model.is_bernoulli:
        proba = numpyro.sample('proba',
                               dist.Uniform(low=jnp.zeros(2 ** markov_model.n), high=jnp.ones(2 ** markov_model.n)))
    else:
        proba = numpyro.sample('proba', dist.Uniform(low=0, high=1)) * jnp.ones(2 ** markov_model.n)

    transition_matrix = markov_model.build_transition_matrix(proba)

    def transition_fn(_, x):
        return tfd.Categorical(probs=transition_matrix[x])

    encoded_history = np.apply_along_axis(markov_model.binary_serie_to_categorical, 1, observed_data)

    likelihood_dist = tfd.MarkovChain(
        initial_state_prior=tfd.Categorical(probs=markov_model.uniform_prior),
        transition_fn=transition_fn,
        num_steps=encoded_history.shape[1]
    )

    numpyro.sample('likelihood', likelihood_dist, obs=encoded_history)