From 4ba2714c95d19cdb772d0753e2e33b96c2ccefcb Mon Sep 17 00:00:00 2001 From: Kenta Sato Date: Mon, 16 Oct 2023 17:02:39 +0200 Subject: [PATCH 1/3] k-dimensional OU process --- src/continuous.jl | 27 ++++++++++++++++++--------- src/randomvariable.jl | 7 ++++--- test/runtests.jl | 19 +++++++++++++++++++ 3 files changed, 41 insertions(+), 12 deletions(-) diff --git a/src/continuous.jl b/src/continuous.jl index a5d04ad..a5c7c60 100644 --- a/src/continuous.jl +++ b/src/continuous.jl @@ -1,11 +1,14 @@ -#OU process -struct OrnsteinUhlenbeckDiffusion{T <: Real} <: GaussianStateProcess +# OU process +struct OrnsteinUhlenbeckDiffusion{T} <: GaussianStateProcess mean::T volatility::T reversion::T end -OrnsteinUhlenbeckDiffusion(mean::Real, volatility::Real, reversion::Real) = OrnsteinUhlenbeckDiffusion(float.(promote(mean, volatility, reversion))...) +function OrnsteinUhlenbeckDiffusion(mean::Real, volatility::Real, reversion::Real) + μ, σ, θ = float.(promote(mean, volatility, reversion)) + return OrnsteinUhlenbeckDiffusion{typeof(μ)}(μ, σ, θ) +end OrnsteinUhlenbeckDiffusion(mean::T) where T <: Real = OrnsteinUhlenbeckDiffusion(mean,T(1.0),T(0.5)) @@ -13,19 +16,25 @@ var(model::OrnsteinUhlenbeckDiffusion) = (model.volatility^2) / (2 * model.rever eq_dist(model::OrnsteinUhlenbeckDiffusion) = Normal(model.mean,sqrt(var(model))) -function forward(process::OrnsteinUhlenbeckDiffusion, x_s::AbstractArray, s::Real, t::Real) +# These are for nested broadcasting +elmwisesqrt(x) = sqrt.(x) +elmwiseinv(x) = inv.(x) +elmwisemul(x, y) = x .* y +elmwisediv(x, y) = x ./ y + +function forward(process::OrnsteinUhlenbeckDiffusion{T}, x_s::AbstractArray{T}, s::Real, t::Real) where T μ, σ, θ = process.mean, process.volatility, process.reversion - mean = @. exp(-(t - s) * θ) * (x_s - μ) + μ + mean = elmwisemul.((exp.(-(t - s) * θ),), x_s .- (μ,)) .+ (μ,) var = similar(mean) - var .= ((1 - exp(-2(t - s) * θ)) * σ^2) / 2θ + fill!(var, @. ((1 - exp(-2(t - s) * θ)) * σ^2) / 2θ) return GaussianVariables(mean, var) end -function backward(process::OrnsteinUhlenbeckDiffusion, x_t::AbstractArray, s::Real, t::Real) +function backward(process::OrnsteinUhlenbeckDiffusion{T}, x_t::AbstractArray{T}, s::Real, t::Real) where T μ, σ, θ = process.mean, process.volatility, process.reversion - mean = @. exp((t - s) * θ) * (x_t - μ) + μ + mean = elmwisemul.((exp.((t - s) * θ),), x_t .- (μ,)) .+ (μ,) var = similar(mean) - var .= -(σ^2 / 2θ) + (σ^2 * exp(2(t - s) * θ)) / 2θ + fill!(var, @. -(σ^2 / 2θ) + (σ^2 * exp(2(t - s) * θ)) / 2θ) return (μ = mean, σ² = var) end diff --git a/src/randomvariable.jl b/src/randomvariable.jl index 89892a0..92689ad 100644 --- a/src/randomvariable.jl +++ b/src/randomvariable.jl @@ -9,11 +9,12 @@ end Base.size(X::GaussianVariables) = size(X.μ) -sample(rng::AbstractRNG, X::GaussianVariables{T}) where T = randn(rng, T, size(X)) .* .√X.σ² .+ X.μ +sample(rng::AbstractRNG, X::GaussianVariables{T}) where T = + elmwisemul.(randn(rng, T, size(X)), elmwisesqrt.(X.σ²)) .+ X.μ function combine(X::GaussianVariables, lik) - σ² = @. inv(inv(X.σ²) + inv(lik.σ²)) - μ = @. σ² * (X.μ / X.σ² + lik.μ / lik.σ²) + σ² = elmwiseinv.(elmwiseinv.(X.σ²) .+ elmwiseinv.(lik.σ²)) + μ = elmwisemul.(σ², elmwisediv.(X.μ, X.σ²) .+ elmwisediv.(lik.μ, lik.σ²)) return GaussianVariables(μ, σ²) end diff --git a/test/runtests.jl b/test/runtests.jl index c773432..8034628 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -74,6 +74,16 @@ end x = one(QuatRotation{Float32}) t = 0.29999998f0 @test sampleforward(diffusion, t, [x]) isa Vector + + # three-dimensional diffusion + μ = @SVector [0.0, 0.0, 0.0] + θ = @SVector [1.0, 1.0, 1.0] + σ = @SVector [0.5, 0.5, 0.5] + x_0 = [zero(μ), zero(μ)] + diffusion = OrnsteinUhlenbeckDiffusion(μ, σ, θ) + x_t = sampleforward(diffusion, 1.0, x_0) + @test x_t isa typeof(x_0) + @test size(x_t) == size(x_0) end @testset "Discrete Diffusions" begin @@ -175,6 +185,15 @@ end x = samplebackward((x, t) -> x + randn(size(x)), process, [1/8, 1/4, 1/2, 1/1], x_t) @test size(x) == size(x_t) @test x isa Matrix + + μ = @SVector [0.0, 0.0, 0.0] + θ = @SVector [1.0, 1.0, 1.0] + σ = @SVector [0.5, 0.5, 0.5] + x_t = randn(typeof(μ), 4, 10) + process = OrnsteinUhlenbeckDiffusion(μ, σ, θ) + x = samplebackward((x, t) -> x + randn(eltype(x), size(x)), process, [1/8, 1/4, 1/2, 1/1], x_t) + @test size(x) == size(x_t) + @test x isa Matrix end @testset "Masked Diffusion" begin From 608e6de9076dbbb2439b4b554fbc9c71d01f8bbb Mon Sep 17 00:00:00 2001 From: Kenta Sato Date: Fri, 20 Oct 2023 19:35:27 +0000 Subject: [PATCH 2/3] fix --- src/continuous.jl | 22 +++++++++++----------- src/loss.jl | 10 ++++++++++ src/randomvariable.jl | 12 ++++++------ test/runtests.jl | 37 +++++++++++++++++++++---------------- 4 files changed, 48 insertions(+), 33 deletions(-) diff --git a/src/continuous.jl b/src/continuous.jl index a5c7c60..13c8339 100644 --- a/src/continuous.jl +++ b/src/continuous.jl @@ -1,5 +1,5 @@ # OU process -struct OrnsteinUhlenbeckDiffusion{T} <: GaussianStateProcess +struct OrnsteinUhlenbeckDiffusion{T <: Real} <: GaussianStateProcess mean::T volatility::T reversion::T @@ -17,24 +17,24 @@ var(model::OrnsteinUhlenbeckDiffusion) = (model.volatility^2) / (2 * model.rever eq_dist(model::OrnsteinUhlenbeckDiffusion) = Normal(model.mean,sqrt(var(model))) # These are for nested broadcasting -elmwisesqrt(x) = sqrt.(x) -elmwiseinv(x) = inv.(x) +elmwiseadd(x, y) = x .+ y +elmwisesub(x, y) = x .- y elmwisemul(x, y) = x .* y elmwisediv(x, y) = x ./ y -function forward(process::OrnsteinUhlenbeckDiffusion{T}, x_s::AbstractArray{T}, s::Real, t::Real) where T +function forward(process::OrnsteinUhlenbeckDiffusion, x_s::AbstractArray, s::Real, t::Real) μ, σ, θ = process.mean, process.volatility, process.reversion - mean = elmwisemul.((exp.(-(t - s) * θ),), x_s .- (μ,)) .+ (μ,) - var = similar(mean) - fill!(var, @. ((1 - exp(-2(t - s) * θ)) * σ^2) / 2θ) + # exp(-(t - s) * θ) * (x_s .- μ) .+ μ + mean = elmwiseadd.(elmwisemul.(exp(-(t - s) * θ), elmwisesub.(x_s, μ)), μ) + var = ((1 - exp(-2(t - s) * θ)) * σ^2) / 2θ return GaussianVariables(mean, var) end -function backward(process::OrnsteinUhlenbeckDiffusion{T}, x_t::AbstractArray{T}, s::Real, t::Real) where T +function backward(process::OrnsteinUhlenbeckDiffusion, x_t::AbstractArray, s::Real, t::Real) μ, σ, θ = process.mean, process.volatility, process.reversion - mean = elmwisemul.((exp.((t - s) * θ),), x_t .- (μ,)) .+ (μ,) - var = similar(mean) - fill!(var, @. -(σ^2 / 2θ) + (σ^2 * exp(2(t - s) * θ)) / 2θ) + # @. exp((t - s) * θ) * (x_t - μ) + μ + mean = elmwiseadd.(elmwisemul.(exp((t - s) * θ), elmwisesub.(x_t, μ)), μ) + var = -(σ^2 / 2θ) + (σ^2 * exp(2(t - s) * θ)) / 2θ return (μ = mean, σ² = var) end diff --git a/src/loss.jl b/src/loss.jl index 6b38f91..8c371e4 100644 --- a/src/loss.jl +++ b/src/loss.jl @@ -37,6 +37,16 @@ function standardloss( return scaledloss(loss(x̂, parent(x)), maskedindices(x), (t -> scaler(p, t)).(t)) end +function standardloss( + p::OrnsteinUhlenbeckDiffusion, + t::Union{Real,AbstractVector{<:Real}}, + x̂::AbstractArray{<: SVector}, x::AbstractArray{<: SVector}; + scaler=defaultscaler) + loss(x̂, x) = norm.(x̂ .- x).^2 + # ugly syntax but scaler.(p, t) is not differentiable with Zygote.jl for some reason + return scaledloss(loss(x̂, parent(x)), maskedindices(x), (t -> scaler(p, t)).(t)) +end + defaultscaler(p::RotationDiffusion, t::Real) = sqrt(1 - exp(-t * p.rate * 5)) function standardloss( diff --git a/src/randomvariable.jl b/src/randomvariable.jl index 92689ad..4b277b3 100644 --- a/src/randomvariable.jl +++ b/src/randomvariable.jl @@ -1,19 +1,19 @@ # Random Variables # ---------------- -struct GaussianVariables{T, A <: AbstractArray{T}} +struct GaussianVariables{A, B} # μ and σ² must have the same size - μ::A # mean - σ²::A # variance + μ::A # mean (array) + σ²::B # variance (scalar) end Base.size(X::GaussianVariables) = size(X.μ) -sample(rng::AbstractRNG, X::GaussianVariables{T}) where T = - elmwisemul.(randn(rng, T, size(X)), elmwisesqrt.(X.σ²)) .+ X.μ +sample(rng::AbstractRNG, X::GaussianVariables) = + elmwisemul.(randn(rng, eltype(X.μ), size(X)), √X.σ²) .+ X.μ function combine(X::GaussianVariables, lik) - σ² = elmwiseinv.(elmwiseinv.(X.σ²) .+ elmwiseinv.(lik.σ²)) + σ² = inv(inv(X.σ²) + inv(lik.σ²)) μ = elmwisemul.(σ², elmwisediv.(X.μ, X.σ²) .+ elmwisediv.(lik.μ, lik.σ²)) return GaussianVariables(μ, σ²) end diff --git a/test/runtests.jl b/test/runtests.jl index 8034628..3890d27 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -76,11 +76,8 @@ end @test sampleforward(diffusion, t, [x]) isa Vector # three-dimensional diffusion - μ = @SVector [0.0, 0.0, 0.0] - θ = @SVector [1.0, 1.0, 1.0] - σ = @SVector [0.5, 0.5, 0.5] - x_0 = [zero(μ), zero(μ)] - diffusion = OrnsteinUhlenbeckDiffusion(μ, σ, θ) + diffusion = OrnsteinUhlenbeckDiffusion(0.0) + x_0 = fill(zero(SVector{3, Float64}), 2) x_t = sampleforward(diffusion, 1.0, x_0) @test x_t isa typeof(x_0) @test size(x_t) == size(x_0) @@ -186,11 +183,8 @@ end @test size(x) == size(x_t) @test x isa Matrix - μ = @SVector [0.0, 0.0, 0.0] - θ = @SVector [1.0, 1.0, 1.0] - σ = @SVector [0.5, 0.5, 0.5] - x_t = randn(typeof(μ), 4, 10) - process = OrnsteinUhlenbeckDiffusion(μ, σ, θ) + process = OrnsteinUhlenbeckDiffusion(0.0) + x_t = randn(SVector{3, Float64}, 4, 10) x = samplebackward((x, t) -> x + randn(eltype(x), size(x)), process, [1/8, 1/4, 1/2, 1/1], x_t) @test size(x) == size(x_t) @test x isa Matrix @@ -263,8 +257,8 @@ end end @testset "Loss" begin - p = OrnsteinUhlenbeckDiffusion(0.0, 1.0, 0.5) - x_0 = randn(5, 10) + p = OrnsteinUhlenbeckDiffusion(0.0) + x_0 = zeros(5, 10) t = rand(10) @test standardloss(p, t, x_0, x_0) == 0 x = rand(5, 10) @@ -272,14 +266,25 @@ end # unmasked elements don't contribute to the loss x = copy(x_0) - m = x_0 .< 0 + m = rand(size(x)...) .< 0.5 + x[.!m] .= 1 x_0 = mask(x_0, m) - x[.!m] .= 0 @test standardloss(p, t, x, x_0) == 0 + @test standardloss(p, t, x, parent(x_0)) > 0 - # but masked elements do - x[m] .= 0 + p = OrnsteinUhlenbeckDiffusion(0.0) + x_0 = fill(zero(SVector{3, Float64}), 10) + t = rand(10) + @test standardloss(p, t, x_0, x_0) == 0 + x = [rand(SVector{3, Float64}) for _ in eachindex(x_0)] @test standardloss(p, t, x, x_0) > 0 + + x = copy(x_0) + m = rand(size(x)...) .< 0.5 + x[.!m] .= (ones(SVector{3, Float64}),) + x_0 = mask(x_0, m) + @test standardloss(p, t, x, x_0) == 0 + @test standardloss(p, t, x, parent(x_0)) > 0 end @testset "Autodiff" begin From 855a847ac9b808bd54427f5cca7da7e78ec5b4ec Mon Sep 17 00:00:00 2001 From: Kenta Sato Date: Fri, 20 Oct 2023 19:37:59 +0000 Subject: [PATCH 3/3] revert minor changes --- src/continuous.jl | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/continuous.jl b/src/continuous.jl index 13c8339..8907694 100644 --- a/src/continuous.jl +++ b/src/continuous.jl @@ -5,10 +5,7 @@ struct OrnsteinUhlenbeckDiffusion{T <: Real} <: GaussianStateProcess reversion::T end -function OrnsteinUhlenbeckDiffusion(mean::Real, volatility::Real, reversion::Real) - μ, σ, θ = float.(promote(mean, volatility, reversion)) - return OrnsteinUhlenbeckDiffusion{typeof(μ)}(μ, σ, θ) -end +OrnsteinUhlenbeckDiffusion(mean::Real, volatility::Real, reversion::Real) = OrnsteinUhlenbeckDiffusion(float.(promote(mean, volatility, reversion))...) OrnsteinUhlenbeckDiffusion(mean::T) where T <: Real = OrnsteinUhlenbeckDiffusion(mean,T(1.0),T(0.5)) @@ -24,7 +21,7 @@ elmwisediv(x, y) = x ./ y function forward(process::OrnsteinUhlenbeckDiffusion, x_s::AbstractArray, s::Real, t::Real) μ, σ, θ = process.mean, process.volatility, process.reversion - # exp(-(t - s) * θ) * (x_s .- μ) .+ μ + # exp(-(t - s) * θ) * (x_s - μ) + μ mean = elmwiseadd.(elmwisemul.(exp(-(t - s) * θ), elmwisesub.(x_s, μ)), μ) var = ((1 - exp(-2(t - s) * θ)) * σ^2) / 2θ return GaussianVariables(mean, var)