Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

axes(::ReshapedDistribution) should throw, added missing ndims #1892

Closed
wants to merge 12 commits into from
1 change: 1 addition & 0 deletions src/reshaped.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ end

Base.size(d::ReshapedDistribution) = d.dims
Base.eltype(::Type{ReshapedDistribution{<:Any,<:ValueSupport,D}}) where {D} = eltype(D)
Base.ndims(d::ReshapedDistribution{N}) where {N} = N

partype(d::ReshapedDistribution) = partype(d.dist)
params(d::ReshapedDistribution) = (d.dist, d.dims)
Expand Down
6 changes: 6 additions & 0 deletions test/matrixreshaped.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,18 @@ function test_matrixreshaped(rng, d1, sizes)
for (d, s) in zip(d1s[1:end-1], sizes[1:end-1])
@test size(d) == s
end
@test size(d1s[end]) == (sizes[end][1], sizes[end][1])
end
@testset "MatrixReshaped length" begin
for d in d1s
@test length(d) == length(d1)
end
end
@testset "MatrixReshaped ndims" begin
for d in d1s
@test ndims(d) == 2
end
end
@testset "MatrixReshaped rank" begin
for (d, s) in zip(d1s, sizes)
@test rank(d) == minimum(s)
Expand Down
Loading