diff --git a/ext/DistributionsChainRulesCoreExt/multivariate/dirichlet.jl b/ext/DistributionsChainRulesCoreExt/multivariate/dirichlet.jl index 5aa0d727d5..90868bf685 100644 --- a/ext/DistributionsChainRulesCoreExt/multivariate/dirichlet.jl +++ b/ext/DistributionsChainRulesCoreExt/multivariate/dirichlet.jl @@ -9,6 +9,9 @@ function ChainRulesCore.frule((_, Δalpha)::Tuple{Any,Any}, ::Type{DT}, alpha::A return d, Δd end +ChainRulesCore.frule(::ChainRulesCore.RuleConfig, Δ, pdf::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}} = + ChainRulesCore.frule(Δ, pdf, alpha, check_args=check_args) + function ChainRulesCore.rrule(::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}} d = DT(alpha; check_args=check_args) digamma_alpha0 = SpecialFunctions.digamma(d.alpha0) @@ -33,6 +36,9 @@ function ChainRulesCore.frule((_, Δd, Δx)::Tuple{Any,Any,Any}, ::typeof(Distri return Ω, ΔΩ end +ChainRulesCore.frule(::ChainRulesCore.RuleConfig, Δ, pdf::typeof(Distributions._logpdf), d::Dirichlet, x::AbstractVector{<:Real}) = + ChainRulesCore.frule(Δ, pdf, d, x) + function ChainRulesCore.rrule(::typeof(Distributions._logpdf), d::T, x::AbstractVector{<:Real}) where {T<:Dirichlet} Ω = Distributions._logpdf(d, x) isfinite_Ω = isfinite(Ω) diff --git a/ext/DistributionsChainRulesCoreExt/univariate/continuous/uniform.jl b/ext/DistributionsChainRulesCoreExt/univariate/continuous/uniform.jl index 0461329577..cbcdb32eda 100644 --- a/ext/DistributionsChainRulesCoreExt/univariate/continuous/uniform.jl +++ b/ext/DistributionsChainRulesCoreExt/univariate/continuous/uniform.jl @@ -12,6 +12,9 @@ function ChainRulesCore.frule((_, Δd, _), ::typeof(logpdf), d::Uniform, x::Real return Ω, ΔΩ end +ChainRulesCore.frule(::ChainRulesCore.RuleConfig, Δ, pdf::typeof(logpdf), d::Uniform, x::Real) = + ChainRulesCore.frule(Δ, pdf, d, x) + function ChainRulesCore.rrule(::typeof(logpdf), d::Uniform, x::Real) # Compute log probability a, b = params(d) diff --git a/ext/DistributionsChainRulesCoreExt/univariate/discrete/poissonbinomial.jl b/ext/DistributionsChainRulesCoreExt/univariate/discrete/poissonbinomial.jl index aa27a9fd94..fdc9be3785 100644 --- a/ext/DistributionsChainRulesCoreExt/univariate/discrete/poissonbinomial.jl +++ b/ext/DistributionsChainRulesCoreExt/univariate/discrete/poissonbinomial.jl @@ -8,6 +8,8 @@ for f in (:poissonbinomial_pdf, :poissonbinomial_pdf_fft) A = Distributions.poissonbinomial_pdf_partialderivatives(p) return y, A' * Δp end + ChainRulesCore.frule(::ChainRulesCore.RuleConfig, Δ, pdf::typeof(Distributions.$f), p::AbstractVector{<:Real}) = + ChainRulesCore.frule(Δ, pdf, p) function ChainRulesCore.rrule(::typeof(Distributions.$f), p::AbstractVector{<:Real}) y = Distributions.$f(p) A = Distributions.poissonbinomial_pdf_partialderivatives(p)