Prior model implementation in BAT.jl

In this notebook, the implementation of the prior model is demonstrated, starting with a simple two-component model and scaling up to the full 9-component model.

using Distributions, StatsBase, LinearAlgebra
using Plots, SpecialFunctions, Printf, Random, ValueShapes
using BAT, DensityInterface
const sf = SpecialFunctions;

gr(fmt=:png);
rng = MersenneTwister(42)
Random.MersenneTwister(42)

Simple two-component model: gluons

First, we start with a simpler two-component model and show all the steps explicitly for clarity.

Forward model

The gluon distributions are parametrised by

\[x g(x) = A_{g1} x^{\lambda_{g1}}(1-x)^{K_g} + A_{g2} x^{\lambda_{g2}}(1-x)^5.\]

We also want to impose

\[\int_0^1 x g(x) dx = A_{g1} B(\lambda_{g1}+1, K_g+1) + A_{g2} B(\lambda_{g2}+1, 1+5) = 1,\]

where B(·, ·) is the Beta function.

Start by defining some useful functions:

function xg1x(x, λ_g1, K_g, θ_1)
    # First gluon component: A_g1 x^λ_g1 (1-x)^K_g, with A_g1 fixed such that
    # the component integrates to the momentum fraction θ_1
    A_g1 = θ_1 / sf.beta(λ_g1 + 1, K_g + 1)
    return A_g1 * x^λ_g1 * (1 - x)^K_g
end

function xg2x(x, λ_g2, K_q, θ_2)
    # Second gluon component: A_g2 x^λ_g2 (1-x)^K_q, normalised to θ_2
    A_g2 = θ_2 / sf.beta(λ_g2 + 1, K_q + 1)
    return A_g2 * x^λ_g2 * (1 - x)^K_q
end

function xgx(x, λ_g1, λ_g2, K_g, K_q, θ)
    # Total gluon distribution: sum of both components with weights θ = [θ_1, θ_2]
    xg1 = xg1x(x, λ_g1, K_g, θ[1])
    xg2 = xg2x(x, λ_g2, K_q, θ[2])
    return xg1 + xg2
end
xgx (generic function with 1 method)

Choose true values for the high-level parameters and show what the resulting model looks like.

θ = [0.5, 0.5]
λ_g1 = 0.5 # rand(rng, Uniform(0, 1))
λ_g2 = -0.7 # rand(rng, Uniform(-1, 0))
K_g = 3 # rand(rng, Uniform(2, 10))
K_q = 5
truths = (θ=θ, λ_g1=λ_g1, λ_g2=λ_g2, K_g=K_g, K_q=K_q);

A_g1 = θ[1] / sf.beta(λ_g1 + 1, K_g + 1)
A_g2 = θ[2] / sf.beta(λ_g2 + 1, K_q + 1);

Check integral = 1

total = A_g1 * sf.beta(λ_g1 + 1, K_g + 1) + A_g2 * sf.beta(λ_g2 + 1, K_q + 1)
print("Integral = ", total)
Integral = 1.0
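As an independent cross-check, we can also integrate x g(x) numerically. This is a minimal sketch assuming the QuadGK package is available (it is not among the imports above):

using QuadGK  # assumed extra dependency, not part of the original imports

integral, _ = quadgk(x -> xgx(x, λ_g1, λ_g2, K_g, K_q, θ), 0, 1)
print("Numerical integral = ", integral)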

Plot true model

x_grid = range(0, stop=1, length=50)

xg1 = A_g1 * x_grid .^ λ_g1 .* (1 .- x_grid) .^ K_g
xg2 = A_g2 * x_grid .^ λ_g2 .* (1 .- x_grid) .^ K_q

plot(x_grid, [xg1x(x, λ_g1, K_g, θ[1]) for x in x_grid],
    alpha=0.7, label="x g1(x)", lw=3, color="green")
plot!(x_grid, [xg2x(x, λ_g2, K_q, θ[2]) for x in x_grid],
    alpha=0.7, label="x g2(x)", lw=3, color="blue")
plot!(x_grid, [xgx(x, λ_g1, λ_g2, K_g, K_q, θ) for x in x_grid],
    alpha=0.7, label="x g1(x) + x g2(x)", lw=3, color="red")
plot!(xlabel="x")

Now, for the purposes of testing the prior implementation, sample some data from this distribution. The expected counts in each bin are approximated by the function evaluated at the bin centre, multiplied by the bin width and an overall normalisation N, and the observed counts are drawn from a Poisson distribution with that expectation. Then, plot the model and data to compare.

bins = 0.0:0.05:1.0
bin_widths = bins[2:end] - bins[1:end-1]
bin_centers = (bins[1:end-1] + bins[2:end]) / 2

N = 1000
nbins = size(bin_centers)[1]

expected_counts = zeros(nbins)
observed_counts = zeros(Integer, nbins)
for i in 1:nbins
    xg = xgx(bin_centers[i], λ_g1, λ_g2, K_g, K_q, θ) * N
    expected_counts[i] = bin_widths[i] * xg
    observed_counts[i] = rand(rng, Poisson(expected_counts[i]))
end

plot(bin_centers, [xgx(x, λ_g1, λ_g2, K_g, K_q, θ) for x in bin_centers] .* bin_widths * N,
    alpha=0.7, label="Expected", lw=3, color="red")
scatter!(bin_centers, observed_counts, lw=3, label="Observed", color="black")

Store the data in a simple dict to pass to the likelihood later.

data = Dict()
data["N"] = N
data["bin_centers"] = bin_centers;
data["observed_counts"] = observed_counts;
data["bin_widths"] = bin_widths;

Fit

To fit this example data, we choose a prior over our hyperparameters θ, λ_g1, λ_g2, K_g and K_q.

We choose a Dirichlet prior for the component weights θ, and look at some samples to understand what this implies.

dirichlet = Dirichlet([1, 1])
test = rand(rng, dirichlet, 1000)
plot(append!(Histogram(0:0.1:1), test[1, :]))
plot!(append!(Histogram(0:0.1:1), test[2, :]))

So we see that with a flat Dirichlet([1, 1]) prior, all weight combinations are equally likely: each component's weight is marginally uniform on [0, 1], subject to the constraint θ₁ + θ₂ = 1.
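As an illustrative aside (not used in the fit below), an asymmetric choice of concentration parameters such as Dirichlet([5, 1]) would instead favour giving most of the weight to the first component:

dirichlet_asym = Dirichlet([5, 1])
test_asym = rand(rng, dirichlet_asym, 1000)
plot(append!(Histogram(0:0.1:1), test_asym[1, :]), label="component 1")
plot!(append!(Histogram(0:0.1:1), test_asym[2, :]), label="component 2")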

Now we can define the prior:

\[\mathrm{Prior} = P(\theta)\, P(\lambda_{g1})\, P(\lambda_{g2})\, P(K_g)\, P(K_q).\]

prior = NamedTupleDist(
    θ=Dirichlet([1, 1]),
    λ_g1=Uniform(0, 1), #Truncated(Normal(0.7, 0.1), 0, 1)
    λ_g2=Uniform(-1, 0), #Truncated(Normal(-0.7, 0.1), -1, 0),
    K_g=Uniform(2, 10), # Truncated(Normal(1, 2), 2, 10),
    K_q=Uniform(4, 6),
);

Define a simple Poisson likelihood that describes the test data generated above.

likelihood = let d = data, f = xgx

    observed_counts = d["observed_counts"]
    bin_centers = d["bin_centers"]
    bin_widths = d["bin_widths"]
    N = d["N"]

    logfuncdensity(function (params)
        function bin_log_likelihood(i)
            xg = f(
                bin_centers[i], params.λ_g1, params.λ_g2, params.K_g, params.K_q, params.θ
            )
            expected_counts = bin_widths[i] * xg * N
            logpdf(Poisson(expected_counts), observed_counts[i])
        end

        idxs = eachindex(observed_counts)
        ll_value = bin_log_likelihood(idxs[1])
        for i in idxs[2:end]
            ll_value += bin_log_likelihood(i)
        end

        return ll_value
    end)
end
LogFuncDensity(Main.var"#9#10"{Int64, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Vector{Integer}, typeof(Main.xgx)}(1000, StepRangeLen(0.05, 0.0, 20), 0.025:0.05:0.975, Integer[191, 111, 86, 66, 66, 36, 54, 40, 35, 33, 21, 11, 5, 4, 3, 3, 2, 1, 0, 0], Main.xgx))

The prior and likelihood can be passed to BAT.jl via the PosteriorDensity. We can then sample this density using bat_sample().

posterior = PosteriorDensity(likelihood, prior);
samples = bat_sample(
    posterior,
    MCMCSampling(mcalg=MetropolisHastings(), nsteps=10^4, nchains=4)
).result;
[ Info: Initializing new RNG of type Random123.Philox4x{UInt64, 10}
[ Info: MCMCChainPoolInit: trying to generate 4 viable MCMC chain(s).
[ Info: Selected 4 MCMC chain(s).
[ Info: Begin tuning of 4 MCMC chain(s).
[ Info: MCMC Tuning cycle 1 finished, 4 chains, 0 tuned, 0 converged.
[ Info: MCMC Tuning cycle 2 finished, 4 chains, 1 tuned, 0 converged.
[ Info: MCMC Tuning cycle 3 finished, 4 chains, 4 tuned, 4 converged.
[ Info: MCMC tuning of 4 chains successful after 3 cycle(s).
[ Info: Running post-tuning stabilization steps for 4 MCMC chain(s).

The SampledDensity gives a quick overview of the results.

SampledDensity(posterior, samples)
SampledMeasure(objectid = 0x5b94602a29c73dd6, varshape = NamedTupleShape((:θ, :λ_g1, :λ_g2, :K_g, :K_q)}(…))

Visualise results

We can (roughly) check how well the fit reconstructs the truth with a simple comparison.

x_grid = range(0, stop=1, length=50)
sub_samples = bat_sample(samples, OrderedResampling(nsamples=200)).result

plot()
for i in eachindex(sub_samples)
    s = sub_samples[i].v
    xg = [xgx(x, s.λ_g1, s.λ_g2, s.K_g, s.K_q, s.θ) for x in bin_centers]
    plot!(bin_centers, xg .* bin_widths * N, alpha=0.1, lw=3,
        color="darkorange", label="",)
end

xg = [xgx(x, λ_g1, λ_g2, K_g, K_q, θ) for x in bin_centers]
plot!(bin_centers, xg .* bin_widths * N, alpha=0.7,
    label="Expected", lw=3, color="red")

scatter!(bin_centers, observed_counts, lw=3, label="Observed", color="black")

plot!(xlabel="x")

In the above plot, the red line represents the truth, and the set of fainter lines represent samples from the posterior.
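Instead of overlaying individual posterior draws, the same information can be summarised as per-bin credible intervals. A minimal sketch using the resampled draws from above (quantile is re-exported by StatsBase, which is already loaded):

preds = [[xgx(x, s.v.λ_g1, s.v.λ_g2, s.v.K_g, s.v.K_q, s.v.θ) * w * N
    for (x, w) in zip(bin_centers, bin_widths)] for s in sub_samples]
pred_matrix = hcat(preds...)  # nbins × n_samples matrix of predicted counts
lower = [quantile(pred_matrix[i, :], 0.16) for i in 1:nbins]
median_pred = [quantile(pred_matrix[i, :], 0.5) for i in 1:nbins]
upper = [quantile(pred_matrix[i, :], 0.84) for i in 1:nbins]

plot(bin_centers, median_pred, ribbon=(median_pred .- lower, upper .- median_pred),
    alpha=0.4, lw=2, color="darkorange", label="68% credible band")
scatter!(bin_centers, observed_counts, label="Observed", color="black")
plot!(xlabel="x")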

We can also look at marginal distributions for different parameters...

plot(
    samples, :(λ_g1),
    mean=true, std=true,
    nbins=50,
)
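We can also print summary statistics for all parameters and compare them to the truth. This is a sketch; mode, mean and std on the sample vector are provided by BAT.jl in the version used here:

println("Truth: ", truths)
println("Mode:  ", mode(samples))
println("Mean:  ", mean(samples))
println("Std:   ", std(samples))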

Full model with all components

We can now extend this approach to the full 9 components. There is less documentation here, as it follows the two-component case above.

Forward model

Here, each component is parametrised in decreasing order of importance. We also use the helper functions provided by PartonDensity to keep things tidy.

using PartonDensity
weights = [1, 0.5, 0.3, 0.2, 0.1, 0.1, 0.1]
λ_u = 0.7;
K_u = 4.0;
λ_d = 0.5;
K_d = 6.0;
θ = get_θ_val(rng, λ_u, K_u, λ_d, K_d, weights)
pdf_params = ValencePDFParams(λ_u=λ_u, K_u=K_u, λ_d=λ_d, K_d=K_d,
    λ_g1=0.7, λ_g2=-0.4, K_g=6.0, λ_q=-0.5, K_q=5.0,
    θ=θ)
ValencePDFParams{Float64, Vector{Float64}}
  param_type: Int64 1
  λ_u: Float64 0.7
  K_u: Float64 4.0
  λ_d: Float64 0.5
  K_d: Float64 6.0
  λ_g1: Float64 0.7
  λ_g2: Float64 -0.4
  K_g: Float64 6.0
  λ_q: Float64 -0.5
  K_q: Float64 5.0
  θ: Array{Float64}((7,)) [0.08994884796048978, 0.57869090326387, 0.007268984779449628, 0.005171754145412688, 3.2808103389797114e-9, 0.006638116704494623, 6.881110871676874e-7]

Sanity check

int_xtotx(pdf_params)
0.9999999999999999

Plot true model

plot_input_pdfs(pdf_params)

Generate example data

bins = 0.0:0.05:1.0
bin_widths = bins[2:end] - bins[1:end-1]
bin_centers = (bins[1:end-1] + bins[2:end]) / 2

N = 1000
nbins = size(bin_centers)[1]

expected_counts = zeros(nbins)
observed_counts = zeros(Integer, nbins)
for i in 1:nbins
    xt = xtotx(bin_centers[i], pdf_params) * N
    expected_counts[i] = bin_widths[i] * xt
    observed_counts[i] = rand(rng, Poisson(expected_counts[i]))
end

Plot data and expectation

plot(bin_centers, [xtotx(x, pdf_params) for x in bin_centers] .* bin_widths * N,
    alpha=0.7, label="Expected", lw=3, color="red")
scatter!(bin_centers, observed_counts, lw=3, label="Observed", color="black")

Store the data

data = Dict()
data["N"] = N
data["bin_centers"] = bin_centers;
data["observed_counts"] = observed_counts;
data["bin_widths"] = bin_widths;

Fit

Prior

prior = NamedTupleDist(
    θ=Dirichlet(weights),
    λ_u=Truncated(Normal(pdf_params.λ_u, 0.5), 0, 1), #  Uniform(0, 1),
    K_u=Truncated(Normal(pdf_params.K_u, 1), 2, 10),
    λ_d=Truncated(Normal(pdf_params.λ_d, 0.5), 0, 1), # Uniform(0, 1),
    K_d=Truncated(Normal(pdf_params.K_d, 1), 2, 10),
    λ_g1=Truncated(Normal(pdf_params.λ_g1, 1), 0, 1),
    λ_g2=Truncated(Normal(pdf_params.λ_g2, 1), -1, 0),
    K_g=Truncated(Normal(pdf_params.K_g, 1), 2, 10),
    λ_q=Truncated(Normal(pdf_params.λ_q, 0.1), -1, 0),
    K_q=Truncated(Normal(pdf_params.K_q, 0.5), 3, 7)
);

Likelihood

likelihood = let d = data, f = PartonDensity.xtotx_valence

    observed_counts = d["observed_counts"]
    bin_centers = d["bin_centers"]
    bin_widths = d["bin_widths"]
    N = d["N"]

    logfuncdensity(function (params)
        function bin_log_likelihood(i)
            xt = f(bin_centers[i], params.λ_u, params.K_u, params.λ_d, params.K_d,
                params.λ_g1, params.λ_g2, params.K_g, params.λ_q, params.K_q, Vector(params.θ))
            expected_counts = bin_widths[i] * xt * N
            if expected_counts < 0
                expected_counts = 1e-3
            end
            logpdf(Poisson(expected_counts), observed_counts[i])
        end

        idxs = eachindex(observed_counts)
        ll_value = bin_log_likelihood(idxs[1])
        for i in idxs[2:end]
            ll_value += bin_log_likelihood(i)
        end

        return ll_value
    end)
end
LogFuncDensity(Main.var"#18#19"{Int64, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Vector{Integer}, typeof(PartonDensity.xtotx_valence)}(1000, StepRangeLen(0.05, 0.0, 20), 0.025:0.05:0.975, Integer[277, 186, 140, 89, 90, 66, 34, 38, 33, 15, 12, 4, 1, 3, 1, 0, 0, 0, 0, 0], PartonDensity.xtotx_valence))

Run fit

posterior = PosteriorDensity(likelihood, prior);
samples = bat_sample(posterior, MCMCSampling(mcalg=MetropolisHastings(), nsteps=10^4, nchains=2)).result;
[ Info: Initializing new RNG of type Random123.Philox4x{UInt64, 10}
[ Info: MCMCChainPoolInit: trying to generate 2 viable MCMC chain(s).
[ Info: Selected 2 MCMC chain(s).
[ Info: Begin tuning of 2 MCMC chain(s).
[ Info: MCMC Tuning cycle 1 finished, 2 chains, 0 tuned, 0 converged.
[ Info: MCMC Tuning cycle 2 finished, 2 chains, 1 tuned, 0 converged.
[ Info: MCMC Tuning cycle 3 finished, 2 chains, 2 tuned, 0 converged.
[ Info: MCMC Tuning cycle 4 finished, 2 chains, 2 tuned, 0 converged.
[ Info: MCMC Tuning cycle 5 finished, 2 chains, 1 tuned, 0 converged.
[ Info: MCMC Tuning cycle 6 finished, 2 chains, 2 tuned, 0 converged.
[ Info: MCMC Tuning cycle 7 finished, 2 chains, 2 tuned, 0 converged.
[ Info: MCMC Tuning cycle 8 finished, 2 chains, 2 tuned, 2 converged.
[ Info: MCMC tuning of 2 chains successful after 8 cycle(s).
[ Info: Running post-tuning stabilization steps for 2 MCMC chain(s).
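As in the two-component case, the SampledDensity gives a quick overview of the result:

SampledDensity(posterior, samples)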

Visualise results

x_grid = range(0, stop=1, length=50)
sub_samples = bat_sample(samples, OrderedResampling(nsamples=200)).result

plot()
for i in eachindex(sub_samples)
    s = sub_samples[i].v
    xt = [PartonDensity.xtotx_valence(x, s.λ_u, s.K_u, s.λ_d, s.K_d,
        s.λ_g1, s.λ_g2, s.K_g, s.λ_q, s.K_q, Vector(s.θ)) for x in bin_centers]
    plot!(bin_centers, xt .* bin_widths * N, alpha=0.1, lw=3,
        color="darkorange", label="")
end
xt = [xtotx(x, pdf_params) for x in bin_centers]
plot!(bin_centers, xt .* bin_widths * N, alpha=0.7, label="Expected", lw=3, color="red")

scatter!(bin_centers, observed_counts, lw=3, label="Observed", color="black")
plot!(xlabel="x")
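As before, marginal distributions of individual parameters can be inspected, for example for λ_u:

plot(
    samples, :(λ_u),
    mean=true, std=true,
    nbins=50,
)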

These first results are promising. We can also try changing the input parameters and priors to explore the performance of the fit.
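For example, a minimal sketch of such a variation (a hypothetical choice of flat priors, not part of the original setup) simply redefines the prior and reruns the sampling:

prior_flat = NamedTupleDist(
    θ=Dirichlet(weights),
    λ_u=Uniform(0, 1),
    K_u=Uniform(2, 10),
    λ_d=Uniform(0, 1),
    K_d=Uniform(2, 10),
    λ_g1=Uniform(0, 1),
    λ_g2=Uniform(-1, 0),
    K_g=Uniform(2, 10),
    λ_q=Uniform(-1, 0),
    K_q=Uniform(3, 7),
);

posterior_flat = PosteriorDensity(likelihood, prior_flat);
samples_flat = bat_sample(
    posterior_flat,
    MCMCSampling(mcalg=MetropolisHastings(), nsteps=10^4, nchains=2)
).result;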


This page was generated using Literate.jl.