ChainRulesTestUtils

ChainRulesTestUtils.jl helps you test ChainRulesCore.frule and ChainRulesCore.rrule methods when adding rules for functions in your own packages. For information about ChainRules, including how to write rules, refer to the general ChainRules documentation.

Canonical example

Let's suppose a custom transformation has been defined

function two2three(x1::Float64, x2::Float64)
    return 1.0, 2.0*x1, 3.0*x2
end

along with the frule

using ChainRulesCore

function ChainRulesCore.frule((Δf, Δx1, Δx2), ::typeof(two2three), x1, x2)
    y = two2three(x1, x2)
    ∂y = Composite{Tuple{Float64, Float64, Float64}}(Zero(), 2.0*Δx1, 3.0*Δx2)
    return y, ∂y
end

and rrule

function ChainRulesCore.rrule(::typeof(two2three), x1, x2)
    y = two2three(x1, x2)
    function two2three_pullback(Ȳ)
        return (NO_FIELDS, 2.0*Ȳ[2], 3.0*Ȳ[3])
    end
    return y, two2three_pullback
end

The frule_test and rrule_test helper functions compare the frule/rrule outputs to the gradients obtained by finite differencing. They can be used for any type and number of inputs and outputs.

Testing the frule

frule_test takes in the function f and tuples (x, ẋ) for each function argument x. The call tests the frule for function f at the point x in the domain. Keep this in mind when testing rules for functions with a discontinuous derivative, like ReLU, which should ideally be tested with x both above and below zero. Additionally, choosing ẋ in an unfortunate way (e.g. as zeros) could hide underlying problems with the defined frule.

using ChainRulesTestUtils

x1, x2 = (3.33, -7.77)
ẋ1, ẋ2 = (rand(), rand())

frule_test(two2three, (x1, ẋ1), (x2, ẋ2))
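
To make "compared to finite differencing" concrete, here is a minimal hand-rolled sketch of the kind of check involved. It is not how ChainRulesTestUtils implements frule_test (the default there is fdm=central_fdm(5, 1) from FiniteDifferences); the helper central_directional, the step size h, and the fixed tangent values are assumptions made for this illustration only.

```julia
# Function from the canonical example.
two2three(x1::Float64, x2::Float64) = (1.0, 2.0 * x1, 3.0 * x2)

# Hypothetical helper (not part of ChainRulesTestUtils): second-order central
# difference of f along the direction (ẋ1, ẋ2), applied to each output.
function central_directional(f, (x1, x2), (ẋ1, ẋ2); h=1e-6)
    y_plus = f(x1 + h * ẋ1, x2 + h * ẋ2)
    y_minus = f(x1 - h * ẋ1, x2 - h * ẋ2)
    return map((p, m) -> (p - m) / (2h), y_plus, y_minus)
end

x1, x2 = 3.33, -7.77
ẋ1, ẋ2 = 0.3, 0.8  # fixed here for reproducibility; the text above uses rand()

# Tangent that the frule propagates (the constant first output has zero tangent).
∂y_analytic = (0.0, 2.0 * ẋ1, 3.0 * ẋ2)
∂y_numeric = central_directional(two2three, (x1, x2), (ẋ1, ẋ2))

# true when the rule and the finite-difference estimate agree within tolerance
agree = all(isapprox.(∂y_analytic, ∂y_numeric; rtol=1e-6, atol=1e-6))
```

If ẋ1 and ẋ2 were both zero, both tuples would be all zeros regardless of the factors 2.0 and 3.0, which is why a zero tangent can hide a wrong rule.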

Testing the rrule

rrule_test takes in the function f, the sensitivities of the function outputs ȳ, and tuples (x, x̄) for each function argument x. x̄ is the accumulated adjoint, which can be set arbitrarily. The call tests the rrule for function f at the point x and, as with the frule, some rules should be tested at multiple points in the domain. Choosing ȳ in an unfortunate way (e.g. as zeros) could hide underlying problems with the rrule.

x1, x2 = (3.33, -7.77)
x̄1, x̄2 = (rand(), rand())
ȳs = (rand(), rand(), rand())

rrule_test(two2three, ȳs, (x1, x̄1), (x2, x̄2))
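
One way to sanity-check a pullback by hand is the adjoint identity ⟨ȳ, J ẋ⟩ = ⟨Jᵀ ȳ, ẋ⟩, where J ẋ is estimated by finite differences and Jᵀ ȳ comes from the pullback. The sketch below illustrates that idea in plain Julia; it is not the implementation of rrule_test, and the probe direction ẋ, the step size h, and the fixed values of ȳ are assumptions made for this example.

```julia
two2three(x1::Float64, x2::Float64) = (1.0, 2.0 * x1, 3.0 * x2)

# Pullback from the rrule above, without the NO_FIELDS entry for the function itself.
two2three_pullback(ȳ) = (2.0 * ȳ[2], 3.0 * ȳ[3])

x1, x2 = 3.33, -7.77
ȳ = (0.4, 0.7, 0.2)   # output sensitivities, fixed here for reproducibility
ẋ = (0.3, 0.8)        # arbitrary probe direction (an assumption of this sketch)

# Central-difference estimate of J * ẋ.
h = 1e-6
y_plus = two2three(x1 + h * ẋ[1], x2 + h * ẋ[2])
y_minus = two2three(x1 - h * ẋ[1], x2 - h * ẋ[2])
Jẋ = (y_plus .- y_minus) ./ (2h)

# Adjoint identity: both inner products should agree up to finite-difference error.
lhs = sum(ȳ .* Jẋ)
rhs = sum(two2three_pullback(ȳ) .* ẋ)
```

With ȳ set to zeros, both sides would be zero for any pullback, so the check would pass even for an incorrect rule.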

Scalar example

For functions with a single argument and a single output, such as ReLU,

function relu(x::Real)
    return max(0, x)
end

with the frule and rrule defined with the help of the @scalar_rule macro

@scalar_rule relu(x::Real) x <= 0 ? zero(x) : one(x)

the test_scalar function is provided to test both the frule and the rrule with a single call.

test_scalar(relu, 0.5)
test_scalar(relu, -0.5)
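
The two test points sit on either side of the kink at zero. The plain-Julia sketch below suggests why testing exactly at the kink is unhelpful: a central difference that straddles zero averages the two one-sided slopes. The helpers relu_rule and central_diff are hypothetical names used only for this illustration and are not part of ChainRulesTestUtils.

```julia
relu(x::Real) = max(0, x)

# Derivative expression used in the @scalar_rule above.
relu_rule(x) = x <= 0 ? zero(x) : one(x)

# Hypothetical helper: central-difference estimate of the derivative at x.
central_diff(f, x; h=1e-6) = (f(x + h) - f(x - h)) / (2h)

above = central_diff(relu, 0.5)     # agrees with relu_rule(0.5)
below = central_diff(relu, -0.5)    # agrees with relu_rule(-0.5)
at_kink = central_diff(relu, 0.0)   # straddles the kink, matching neither one-sided slope
```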

API Documentation

ChainRulesTestUtils.TestIterator (Type)
TestIterator{T,IS<:Base.IteratorSize,IE<:Base.IteratorEltype}

A configurable iterator for testing purposes.

TestIterator(data, itersize, itereltype)
TestIterator(data)

The iterator wraps another iterator data, such as an array, that must have at least as many features implemented as the test iterator and must have a FiniteDifferences.to_vec overload. By default, the iterator has the same features as data.

The optional methods eltype, length, and size are automatically defined and forwarded to data if the type arguments indicate that they should be defined.
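
As a rough usage sketch based on the two constructors listed above, the following wraps an array once with its default traits and once with the length and element type hidden. It assumes that the trait arguments are instances of the Base.IteratorSize and Base.IteratorEltype trait types, as the type parameters suggest; the variable names are illustrative.

```julia
using ChainRulesTestUtils

data = [1.0, 2.0, 3.0]

# Default: the wrapper reports the same iterator traits as `data`.
iter_default = ChainRulesTestUtils.TestIterator(data)

# Assumed usage: hide both the size and the element type, so that code under
# test cannot rely on `length`, `size`, or `eltype` being available.
iter_opaque = ChainRulesTestUtils.TestIterator(
    data, Base.SizeUnknown(), Base.EltypeUnknown()
)

# Iteration itself still yields the wrapped elements.
collected = collect(iter_opaque)
```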

ChainRulesTestUtils.check_equal (Method)
check_equal(actual, expected; kwargs...)

Tests (via @test) that actual ≈ expected, but breaks up the data so that human-readable results are shown on failure. Understands things like unthunking ChainRulesCore.Thunks. All keyword arguments are passed to isapprox.

ChainRulesTestUtils.frule_test (Method)
frule_test(f, (x, ẋ)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), fkwargs=NamedTuple(), check_inferred=true, kwargs...)

Arguments

  • f: Function for which the frule should be tested.
  • x: input at which to evaluate f (should generally be set to an arbitrary point in the domain).
  • ẋ: differential w.r.t. x (should generally be set randomly).

fkwargs are passed to f as keyword arguments. If check_inferred=true, then the inferrability of the frule is checked, as long as f is itself inferrable. All remaining keyword arguments are passed to isapprox.

ChainRulesTestUtils.rrule_test (Method)
rrule_test(f, ȳ, (x, x̄)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), fkwargs=NamedTuple(), check_inferred=true, kwargs...)

Arguments

  • f: Function for which the rrule should be tested.
  • ȳ: adjoint w.r.t. the output of f (should generally be set randomly). Should have the same structure as f(x), so if f returns multiple values it should be a tuple.
  • x: input at which to evaluate f (should generally be set to an arbitrary point in the domain).
  • x̄: currently accumulated adjoint (should generally be set randomly).

fkwargs are passed to f as keyword arguments. If check_inferred=true, then the inferrability of the rrule is checked (as long as f is itself inferrable), along with the inferrability of the pullback it returns. All remaining keyword arguments are passed to isapprox.

ChainRulesTestUtils.test_scalar (Method)
test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), fkwargs=NamedTuple(), check_inferred=true, kwargs...)

Given a function f with scalar input and scalar output, perform finite differencing checks at the input point z to confirm that a correct frule and rrule are provided.

Arguments

  • f: Function for which the frule and rrule should be tested.
  • z: input at which to evaluate f (should generally be set to an arbitrary point in the domain).

fkwargs are passed to f as keyword arguments. If check_inferred=true, then the type-stability of the frule and rrule are checked. All remaining keyword arguments are passed to isapprox.