Prior model implementation in BAT.jl
In this notebook, we demonstrate the implementation of the prior model, starting with a simple two-component model and scaling up to the full 9-component model.
using Distributions, StatsBase, LinearAlgebra
using Plots, SpecialFunctions, Printf, Random, ValueShapes
using BAT, DensityInterface
const sf = SpecialFunctions;
gr(fmt=:png);
rng = MersenneTwister(42)
Random.MersenneTwister(42)
Simple two-component model: gluons
First we start with the simpler two-component model and show all the steps explicitly for clarity.
Forward model
The gluon distributions are parametrised by
\[x g(x) = A_{g1} x^{\lambda_{g1}}(1-x)^{K_g} + A_{g2} x^{\lambda_{g2}}(1-x)^{K_q}.\]
We also want to impose
\[\int_0^1 x g(x) \, dx = A_{g1} B(\lambda_{g1}+1, K_g+1) + A_{g2} B(\lambda_{g2}+1, K_q+1) = 1,\]
where B(., .) is the Beta function.
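This constraint is satisfied automatically if we introduce momentum weights θ_1 + θ_2 = 1 and set
\[A_{g1} = \frac{\theta_1}{B(\lambda_{g1}+1, K_g+1)}, \qquad A_{g2} = \frac{\theta_2}{B(\lambda_{g2}+1, K_q+1)},\]
so that each component carries momentum fraction θ_i. This is exactly what the helper functions below implement.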
Start by defining some useful functions:
function xg1x(x, λ_g1, K_g, θ_1)
A_g1 = θ_1 / sf.beta(λ_g1 + 1, K_g + 1)
return A_g1 * x^λ_g1 * (1 - x)^K_g
end
function xg2x(x, λ_g2, K_q, θ_2)
A_g2 = θ_2 / sf.beta(λ_g2 + 1, K_q + 1)
return A_g2 * x^λ_g2 * (1 - x)^K_q
end
function xgx(x, λ_g1, λ_g2, K_g, K_q, θ)
xg1 = xg1x(x, λ_g1, K_g, θ[1])
xg2 = xg2x(x, λ_g2, K_q, θ[2])
return xg1 + xg2
end
xgx (generic function with 1 method)
Choose true values for the high-level parameters and show what the resulting model looks like.
θ = [0.5, 0.5]
λ_g1 = 0.5 # rand(rng, Uniform(0, 1))
λ_g2 = -0.7 # rand(rng, Uniform(-1, 0))
K_g = 3 # rand(rng, Uniform(2, 10))
K_q = 5
truths = (θ=θ, λ_g1=λ_g1, λ_g2=λ_g2, K_g=K_g, K_q=K_q);
A_g1 = θ[1] / sf.beta(λ_g1 + 1, K_g + 1)
A_g2 = θ[2] / sf.beta(λ_g2 + 1, K_q + 1);
Check integral = 1
total = A_g1 * sf.beta(λ_g1 + 1, K_g + 1) + A_g2 * sf.beta(λ_g2 + 1, K_q + 1)
print("Integral = ", total)
Integral = 1.0
Plot true model
x_grid = range(0, stop=1, length=50)
xg1 = A_g1 * x_grid .^ λ_g1 .* (1 .- x_grid) .^ K_g
xg2 = A_g2 * x_grid .^ λ_g2 .* (1 .- x_grid) .^ K_q
plot(x_grid, [xg1x(x, λ_g1, K_g, θ[1]) for x in x_grid],
alpha=0.7, label="x g1(x)", lw=3, color="green")
plot!(x_grid, [xg2x(x, λ_g2, K_q, θ[2]) for x in x_grid],
alpha=0.7, label="x g2(x)", lw=3, color="blue")
plot!(x_grid, [xgx(x, λ_g1, λ_g2, K_g, K_q, θ) for x in x_grid],
alpha=0.7, label="x g1(x) + x g2(x)", lw=3, color="red")
plot!(xlabel="x")
Now, to test the prior implementation, we sample some data from this distribution, assuming the data are produced by integrating the function over bins and multiplying by a normalisation factor. Then, we plot the model and data to compare.
bins = 0.0:0.05:1.0
bin_widths = bins[2:end] - bins[1:end-1]
bin_centers = (bins[1:end-1] + bins[2:end]) / 2
N = 1000
nbins = size(bin_centers)[1]
expected_counts = zeros(nbins)
observed_counts = zeros(Integer, nbins)
for i in 1:nbins
xg = xgx(bin_centers[i], λ_g1, λ_g2, K_g, K_q, θ) * N
expected_counts[i] = bin_widths[i] * xg
observed_counts[i] = rand(rng, Poisson(expected_counts[i]))
end
plot(bin_centers, [xgx(x, λ_g1, λ_g2, K_g, K_q, θ) for x in bin_centers] .* bin_widths * N,
alpha=0.7, label="Expected", lw=3, color="red")
scatter!(bin_centers, observed_counts, lw=3, label="Observed", color="black")
Store the data in a simple dict to pass to the likelihood later.
data = Dict()
data["N"] = N
data["bin_centers"] = bin_centers;
data["observed_counts"] = observed_counts;
data["bin_widths"] = bin_widths;
Fit
To fit this example data, we choose a prior over our hyperparameters θ, λ_g1, λ_g2, K_g and K_q.
For the weights θ, a Dirichlet prior is a sensible choice; let's look at some samples to understand what this implies.
dirichlet = Dirichlet([1, 1])
test = rand(rng, dirichlet, 1000)
plot(append!(Histogram(0:0.1:1), test[1, :]))
plot!(append!(Histogram(0:0.1:1), test[2, :]))
We see that with a flat Dirichlet all weight values are equally likely for both components, subject to the constraint that the weights sum to one.
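As a quick check of that constraint, we can verify that every draw lies on the simplex (a minimal sketch using the samples drawn above):
all(sum(test, dims=1) .≈ 1) # each column of test is one draw; the two weights always sum to one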
Now we can define the prior:
\[\mathrm{Prior} = P(\theta) \, P(\lambda_{g1}) \, P(\lambda_{g2}) \, P(K_g) \, P(K_q).\]
prior = NamedTupleDist(
θ=Dirichlet([1, 1]),
λ_g1=Uniform(0, 1), #Truncated(Normal(0.7, 0.1), 0, 1)
λ_g2=Uniform(-1, 0), #Truncated(Normal(-0.7, 0.1), -1, 0),
K_g=Uniform(2, 10), # Truncated(Normal(1, 2), 2, 10),
K_q=Uniform(4, 6),
);
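Since a NamedTupleDist behaves like an ordinary distribution, a quick way to check the prior implementation is to draw from it directly (a minimal sketch; the field access shown assumes the standard ValueShapes behaviour):
prior_draw = rand(rng, prior) # a NamedTuple with fields θ, λ_g1, λ_g2, K_g and K_q
prior_draw.λ_g1 # e.g. access a single drawn parameter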
Define a simple Poisson likelihood that describes the test data generated above.
likelihood = let d = data, f = xgx
observed_counts = d["observed_counts"]
bin_centers = d["bin_centers"]
bin_widths = d["bin_widths"]
N = data["N"]
logfuncdensity(function (params)
function bin_log_likelihood(i)
xg = f(
bin_centers[i], params.λ_g1, params.λ_g2, params.K_g, params.K_q, params.θ
)
expected_counts = bin_widths[i] * xg * N
logpdf(Poisson(expected_counts), observed_counts[i])
end
idxs = eachindex(observed_counts)
ll_value = bin_log_likelihood(idxs[1])
for i in idxs[2:end]
ll_value += bin_log_likelihood(i)
end
return ll_value
end)
end
LogFuncDensity(Main.var"#9#10"{Int64, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Vector{Integer}, typeof(Main.xgx)}(1000, StepRangeLen(0.05, 0.0, 20), 0.025:0.05:0.975, Integer[191, 111, 86, 66, 66, 36, 54, 40, 35, 33, 21, 11, 5, 4, 3, 3, 2, 1, 0, 0], Main.xgx))
The prior and likelihood can be passed to BAT.jl via the PosteriorDensity. We can then sample this density using bat_sample().
posterior = PosteriorDensity(likelihood, prior);
samples = bat_sample(
posterior,
MCMCSampling(mcalg=MetropolisHastings(), nsteps=10^4, nchains=4)
).result;
[ Info: Initializing new RNG of type Random123.Philox4x{UInt64, 10}
[ Info: MCMCChainPoolInit: trying to generate 4 viable MCMC chain(s).
[ Info: Selected 4 MCMC chain(s).
[ Info: Begin tuning of 4 MCMC chain(s).
[ Info: MCMC Tuning cycle 1 finished, 4 chains, 0 tuned, 0 converged.
[ Info: MCMC Tuning cycle 2 finished, 4 chains, 1 tuned, 0 converged.
[ Info: MCMC Tuning cycle 3 finished, 4 chains, 4 tuned, 4 converged.
[ Info: MCMC tuning of 4 chains successful after 3 cycle(s).
[ Info: Running post-tuning stabilization steps for 4 MCMC chain(s).
The SampledDensity gives a quick overview of the results.
SampledDensity(posterior, samples)
SampledMeasure(objectid = 0x5b94602a29c73dd6, varshape = NamedTupleShape{(:θ, :λ_g1, :λ_g2, :K_g, :K_q)}(…))
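Numerical summaries can also be printed directly from the samples, using the standard accessors from the BAT.jl tutorial (a quick sketch):
println("Mode: $(mode(samples))")
println("Mean: $(mean(samples))")
println("Stddev: $(std(samples))")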
Visualise results
We can (roughly) check how well the fit reconstructs the truth with a simple comparison.
x_grid = range(0, stop=1, length=50)
sub_samples = bat_sample(samples, OrderedResampling(nsamples=200)).result
plot()
for i in eachindex(sub_samples)
s = sub_samples[i].v
xg = [xgx(x, s.λ_g1, s.λ_g2, s.K_g, s.K_q, s.θ) for x in bin_centers]
plot!(bin_centers, xg .* bin_widths * N, alpha=0.1, lw=3,
color="darkorange", label="",)
end
xg = [xgx(x, λ_g1, λ_g2, K_g, K_q, θ) for x in bin_centers]
plot!(bin_centers, xg .* bin_widths * N, alpha=0.7,
label="Expected", lw=3, color="red")
scatter!(bin_centers, observed_counts, lw=3, label="Observed", color="black")
plot!(xlabel="x")
In the above plot, the red line represents the truth, and the set of fainter lines represent samples from the posterior.
We can also look at marginal distributions for different parameters...
plot(
samples, :(λ_g1),
mean=true, std=true,
nbins=50,
)
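Assuming the standard BAT.jl plotting recipes, pairs of parameters can also be viewed as two-dimensional marginals, for example:
plot(
    samples, (:(λ_g1), :(K_g)),
    nbins=50,
)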
Full model with all components
We can now extend this approach to the full 9 components. There is not much documentation here, as the steps follow the two-component case above.
Forward model
Here, each component is parametrised in decreasing order of importance. We also use the helper functions provided by PartonDensity to keep things tidy.
using PartonDensity
weights = [1, 0.5, 0.3, 0.2, 0.1, 0.1, 0.1]
λ_u = 0.7;
K_u = 4.0;
λ_d = 0.5;
K_d = 6.0;
θ = get_θ_val(rng, λ_u, K_u, λ_d, K_d, weights)
pdf_params = ValencePDFParams(λ_u=λ_u, K_u=K_u, λ_d=λ_d, K_d=K_d,
λ_g1=0.7, λ_g2=-0.4, K_g=6.0, λ_q=-0.5, K_q=5.0,
θ=θ)
ValencePDFParams{Float64, Vector{Float64}}
param_type: Int64 1
λ_u: Float64 0.7
K_u: Float64 4.0
λ_d: Float64 0.5
K_d: Float64 6.0
λ_g1: Float64 0.7
λ_g2: Float64 -0.4
K_g: Float64 6.0
λ_q: Float64 -0.5
K_q: Float64 5.0
θ: Array{Float64}((7,)) [0.08994884796048978, 0.57869090326387, 0.007268984779449628, 0.005171754145412688, 3.2808103389797114e-9, 0.006638116704494623, 6.881110871676874e-7]
Sanity check
int_xtotx(pdf_params)
0.9999999999999999
Plot true model
plot_input_pdfs(pdf_params)
Generate example data
bins = 0.0:0.05:1.0
bin_widths = bins[2:end] - bins[1:end-1]
bin_centers = (bins[1:end-1] + bins[2:end]) / 2
N = 1000
nbins = size(bin_centers)[1]
expected_counts = zeros(nbins)
observed_counts = zeros(Integer, nbins)
for i in 1:nbins
xt = xtotx(bin_centers[i], pdf_params) * N
expected_counts[i] = bin_widths[i] * xt
observed_counts[i] = rand(rng, Poisson(expected_counts[i]))
end
Plot data and expectation
plot(bin_centers, [xtotx(x, pdf_params) for x in bin_centers] .* bin_widths * N,
alpha=0.7, label="Expected", lw=3, color="red")
scatter!(bin_centers, observed_counts, lw=3, label="Observed", color="black")
Store the data
data = Dict()
data["N"] = N
data["bin_centers"] = bin_centers;
data["observed_counts"] = observed_counts;
data["bin_widths"] = bin_widths;
Fit
Prior
prior = NamedTupleDist(
θ=Dirichlet(weights),
λ_u=Truncated(Normal(pdf_params.λ_u, 0.5), 0, 1), # Uniform(0, 1),
K_u=Truncated(Normal(pdf_params.K_u, 1), 2, 10),
λ_d=Truncated(Normal(pdf_params.λ_d, 0.5), 0, 1), # Uniform(0, 1),
K_d=Truncated(Normal(pdf_params.K_d, 1), 2, 10),
λ_g1=Truncated(Normal(pdf_params.λ_g1, 1), 0, 1),
λ_g2=Truncated(Normal(pdf_params.λ_g2, 1), -1, 0),
K_g=Truncated(Normal(pdf_params.K_g, 1), 2, 10),
λ_q=Truncated(Normal(pdf_params.λ_q, 0.1), -1, 0),
K_q=Truncated(Normal(pdf_params.K_q, 0.5), 3, 7)
);
Likelihood
likelihood = let d = data, f = PartonDensity.xtotx_valence
observed_counts = d["observed_counts"]
bin_centers = d["bin_centers"]
bin_widths = d["bin_widths"]
N = data["N"]
logfuncdensity(function (params)
function bin_log_likelihood(i)
xt = f(bin_centers[i], params.λ_u, params.K_u, params.λ_d, params.K_d,
params.λ_g1, params.λ_g2, params.K_g, params.λ_q, params.K_q, Vector(params.θ))
expected_counts = bin_widths[i] * xt * N
if expected_counts < 0
expected_counts = 1e-3
end
logpdf(Poisson(expected_counts), observed_counts[i])
end
idxs = eachindex(observed_counts)
ll_value = bin_log_likelihood(idxs[1])
for i in idxs[2:end]
ll_value += bin_log_likelihood(i)
end
return ll_value
end)
end
LogFuncDensity(Main.var"#18#19"{Int64, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Vector{Integer}, typeof(PartonDensity.xtotx_valence)}(1000, StepRangeLen(0.05, 0.0, 20), 0.025:0.05:0.975, Integer[277, 186, 140, 89, 90, 66, 34, 38, 33, 15, 12, 4, 1, 3, 1, 0, 0, 0, 0, 0], PartonDensity.xtotx_valence))
Run fit
posterior = PosteriorDensity(likelihood, prior);
samples = bat_sample(posterior, MCMCSampling(mcalg=MetropolisHastings(), nsteps=10^4, nchains=2)).result;
[ Info: Initializing new RNG of type Random123.Philox4x{UInt64, 10}
[ Info: MCMCChainPoolInit: trying to generate 2 viable MCMC chain(s).
[ Info: Selected 2 MCMC chain(s).
[ Info: Begin tuning of 2 MCMC chain(s).
[ Info: MCMC Tuning cycle 1 finished, 2 chains, 0 tuned, 0 converged.
[ Info: MCMC Tuning cycle 2 finished, 2 chains, 1 tuned, 0 converged.
[ Info: MCMC Tuning cycle 3 finished, 2 chains, 2 tuned, 0 converged.
[ Info: MCMC Tuning cycle 4 finished, 2 chains, 2 tuned, 0 converged.
[ Info: MCMC Tuning cycle 5 finished, 2 chains, 1 tuned, 0 converged.
[ Info: MCMC Tuning cycle 6 finished, 2 chains, 2 tuned, 0 converged.
[ Info: MCMC Tuning cycle 7 finished, 2 chains, 2 tuned, 0 converged.
[ Info: MCMC Tuning cycle 8 finished, 2 chains, 2 tuned, 2 converged.
[ Info: MCMC tuning of 2 chains successful after 8 cycle(s).
[ Info: Running post-tuning stabilization steps for 2 MCMC chain(s).
Visualise results
x_grid = range(0, stop=1, length=50)
sub_samples = bat_sample(samples, OrderedResampling(nsamples=200)).result
plot()
for i in eachindex(sub_samples)
s = sub_samples[i].v
xt = [PartonDensity.xtotx_valence(x, s.λ_u, s.K_u, s.λ_d, s.K_d,
s.λ_g1, s.λ_g2, s.K_g, s.λ_q, s.K_q, Vector(s.θ)) for x in bin_centers]
plot!(bin_centers, xt .* bin_widths * N, alpha=0.1, lw=3,
color="darkorange", label="")
end
xt = [xtotx(x, pdf_params) for x in bin_centers]
plot!(bin_centers, xt .* bin_widths * N, alpha=0.7, label="Expected", lw=3, color="red")
scatter!(bin_centers, observed_counts, lw=3, label="Observed", color="black")
plot!(xlabel="x")
These first results are promising. We can also try changing the input parameters and priors to explore the performance of the fit.
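For instance, one could rebuild the inputs with a different weighting and different shape parameters, using only the helpers already shown above, and then repeat the data generation and fit (a sketch; the specific values are arbitrary):
alt_weights = [1, 0.7, 0.2, 0.2, 0.1, 0.1, 0.1] # arbitrary alternative weighting of the components
alt_θ = get_θ_val(rng, λ_u, K_u, λ_d, K_d, alt_weights)
alt_pdf_params = ValencePDFParams(λ_u=λ_u, K_u=K_u, λ_d=λ_d, K_d=K_d,
    λ_g1=0.5, λ_g2=-0.6, K_g=7.0, λ_q=-0.4, K_q=5.0, θ=alt_θ)
plot_input_pdfs(alt_pdf_params)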
This page was generated using Literate.jl.