# ChainRulesTestUtils

ChainRulesTestUtils.jl helps you test `ChainRulesCore.frule` and `ChainRulesCore.rrule` methods when adding rules for your functions in your own packages. For information about ChainRules, including how to write rules, refer to the general ChainRules documentation.

## Canonical example

Let's suppose a custom transformation has been defined

```
function two2three(x1::Float64, x2::Float64)
    return 1.0, 2.0*x1, 3.0*x2
end
```

along with the `frule`

```
using ChainRulesCore

function ChainRulesCore.frule((Δf, Δx1, Δx2), ::typeof(two2three), x1, x2)
    y = two2three(x1, x2)
    # The first output is constant, so its differential is Zero().
    ∂y = Composite{Tuple{Float64, Float64, Float64}}(Zero(), 2.0*Δx1, 3.0*Δx2)
    return y, ∂y
end
```

and `rrule`

```
function ChainRulesCore.rrule(::typeof(two2three), x1, x2)
    y = two2three(x1, x2)
    function two2three_pullback(Ȳ)
        # Ȳ[1] does not contribute because the first output is constant.
        return (NO_FIELDS, 2.0*Ȳ[2], 3.0*Ȳ[3])
    end
    return y, two2three_pullback
end
```

The `frule_test`/`rrule_test` helper functions compare the `frule`/`rrule` outputs to the gradients obtained by finite differencing. They can be used for any type and number of inputs and outputs.

### Testing the `frule`

`frule_test` takes in the function `f` and a tuple `(x, ẋ)` for each function argument `x`. The call will test the `frule` for function `f` at the point `x` in the domain. Keep this in mind when testing discontinuous rules for functions like ReLU, which should ideally be tested with `x` both above and below zero. Additionally, choosing `ẋ` in an unfortunate way (e.g. as zeros) could hide underlying problems with the defined `frule`.

```
using ChainRulesTestUtils
x1, x2 = (3.33, -7.77)
ẋ1, ẋ2 = (rand(), rand())
frule_test(two2three, (x1, ẋ1), (x2, ẋ2))
```
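To make the idea of comparing a forward rule against finite differencing concrete, here is a minimal sketch that uses only Base Julia. It is not how `frule_test` is implemented: the helper names `two2three_jvp` and `central_difference_jvp`, the step size `h`, and the fixed tangent values are all assumptions made for illustration, and the package's default method is `central_fdm(5, 1)` rather than this two-point formula.

```julia
# Same function as in the canonical example.
two2three(x1::Float64, x2::Float64) = (1.0, 2.0 * x1, 3.0 * x2)

# Hypothetical hand-written forward rule, mirroring the frule above but
# returning a plain tuple so that the sketch needs no packages.
two2three_jvp(x1, x2, ẋ1, ẋ2) = (0.0, 2.0 * ẋ1, 3.0 * ẋ2)

# Two-point central difference estimate of the directional derivative
# of f at (x1, x2) along the direction (ẋ1, ẋ2).
function central_difference_jvp(f, x1, x2, ẋ1, ẋ2; h=1e-6)
    y_plus = f(x1 + h * ẋ1, x2 + h * ẋ2)
    y_minus = f(x1 - h * ẋ1, x2 - h * ẋ2)
    return (y_plus .- y_minus) ./ (2h)
end

x1, x2 = 3.33, -7.77
ẋ1, ẋ2 = 0.4, 0.9   # fixed, nonzero tangents (assumed values)

analytic = two2three_jvp(x1, x2, ẋ1, ẋ2)
numeric = central_difference_jvp(two2three, x1, x2, ẋ1, ẋ2)

# The rule is consistent if every output component agrees.
jvp_matches = all(isapprox.(analytic, numeric; atol=1e-6))
```

With `ẋ1 = ẋ2 = 0.0` both sides would be zero for any rule, correct or not, which is why zero tangents can hide a wrong `frule`.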

### Testing the `rrule`

`rrule_test` takes in the function `f`, the sensitivities of the function outputs `ȳ`, and a tuple `(x, x̄)` for each function argument `x`. `x̄` is the accumulated adjoint, which can be set arbitrarily. The call will test the `rrule` for function `f` at the point `x`, and, as with the `frule`, some rules should be tested at multiple points in the domain. Choosing `ȳ` in an unfortunate way (e.g. as zeros) could hide underlying problems with the `rrule`.

```
x1, x2 = (3.33, -7.77)
x̄1, x̄2 = (rand(), rand())
ȳs = (rand(), rand(), rand())
rrule_test(two2three, ȳs, (x1, x̄1), (x2, x̄2))
```
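One way to see what a reverse-rule check has to establish is the adjoint identity `dot(ȳ, J * ẋ) == dot(Jᵀ * ȳ, ẋ)`, where `J` is the Jacobian of `f` at `x`. The sketch below checks it with Base Julia only. It illustrates the principle and is not necessarily the procedure `rrule_test` follows; `two2three_vjp`, `central_difference_jvp`, and the fixed test values are hypothetical names and numbers chosen for this example.

```julia
# Same function as in the canonical example.
two2three(x1::Float64, x2::Float64) = (1.0, 2.0 * x1, 3.0 * x2)

# Hypothetical hand-written pullback mirroring the rrule above: maps output
# sensitivities ȳ to input sensitivities (x̄1, x̄2).
two2three_vjp(ȳ) = (2.0 * ȳ[2], 3.0 * ȳ[3])

# Two-point central difference estimate of J * ẋ at the point x.
function central_difference_jvp(f, x, ẋ; h=1e-6)
    return (f((x .+ h .* ẋ)...) .- f((x .- h .* ẋ)...)) ./ (2h)
end

x = (3.33, -7.77)
ẋ = (0.4, 0.9)        # probe direction in the input space (assumed values)
ȳ = (0.3, 0.7, 0.2)   # output sensitivities (assumed values)

# Left side uses finite differences, right side uses the pullback.
lhs = sum(ȳ .* central_difference_jvp(two2three, x, ẋ))
rhs = sum(two2three_vjp(ȳ) .* ẋ)
vjp_matches = isapprox(lhs, rhs; atol=1e-6)
```

If `ȳ` were all zeros, both sides would vanish regardless of what the pullback returns, which is the failure mode the paragraph above warns about.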

## Scalar example

For functions with a single argument and a single output, such as ReLU,

```
function relu(x::Real)
    return max(0, x)
end
```

with the `frule` and `rrule` defined with the help of the `@scalar_rule` macro

`@scalar_rule relu(x::Real) x <= 0 ? zero(x) : one(x)`

the `test_scalar` function is provided to test both the `frule` and the `rrule` with a single call.

```
test_scalar(relu, 0.5)
test_scalar(relu, -0.5)
```
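The reason for testing on both sides of zero can be seen with a small Base-only sketch. The derivative expression is the one from the `@scalar_rule` above; `relu_derivative`, `central_difference`, and the step size are hypothetical helpers for this illustration, not part of ChainRulesTestUtils.

```julia
relu(x::Real) = max(0, x)

# Derivative expression copied from the @scalar_rule definition above.
relu_derivative(x::Real) = x <= 0 ? zero(x) : one(x)

# Two-point central difference estimate of f'(x).
central_difference(f, x; h=1e-6) = (f(x + h) - f(x - h)) / (2h)

# One point on each side of the kink: each exercises a different branch.
relu_checks = [
    isapprox(relu_derivative(x), central_difference(relu, x); atol=1e-6)
    for x in (0.5, -0.5)
]

# Exactly at the kink the symmetric estimate averages the two slopes,
# so it cannot agree with either branch of the rule.
kink_estimate = central_difference(relu, 0.0)   # 0.5
```

Testing at a single point would leave one branch of the conditional unchecked, and testing at `0.0` itself would compare the rule against an estimate that straddles the discontinuity.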

# API Documentation

`ChainRulesTestUtils.TestIterator` - Type

`TestIterator{T,IS<:Base.IteratorSize,IE<:Base.IteratorEltype}`

A configurable iterator for testing purposes.

```
TestIterator(data, itersize, itereltype)
TestIterator(data)
```

The iterator wraps another iterator `data`, such as an array, that must have at least as many features implemented as the test iterator and have a `FiniteDifferences.to_vec` overload. By default, the iterator has the same features as `data`.

The optional methods `eltype`, `length`, and `size` are automatically defined and forwarded to `data` if the type arguments indicate that they should be defined.
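The trait-forwarding behaviour described above can be sketched with a small stand-alone wrapper. `SketchIterator` is a hypothetical type written for this illustration, not the package's `TestIterator`; it omits the `FiniteDifferences.to_vec` overload and only shows how iterator traits carried as type parameters decide which optional methods exist.

```julia
# Hypothetical wrapper: the iterator traits are type parameters, so they can
# be chosen independently of the wrapped `data`.
struct SketchIterator{T,IS<:Base.IteratorSize,IE<:Base.IteratorEltype}
    data::T
end

SketchIterator(data, is::Base.IteratorSize, ie::Base.IteratorEltype) =
    SketchIterator{typeof(data),typeof(is),typeof(ie)}(data)

# By default, report the same traits as `data`.
SketchIterator(data) =
    SketchIterator(data, Base.IteratorSize(data), Base.IteratorEltype(data))

Base.iterate(it::SketchIterator, state...) = iterate(it.data, state...)
Base.IteratorSize(::Type{<:SketchIterator{T,IS}}) where {T,IS} = IS()
Base.IteratorEltype(::Type{<:SketchIterator{T,IS,IE}}) where {T,IS,IE} = IE()

# Forward the optional methods only when the traits say they are defined.
Base.eltype(::Type{<:SketchIterator{T,IS,Base.HasEltype}}) where {T,IS} = eltype(T)
Base.length(it::SketchIterator{T,<:Union{Base.HasLength,Base.HasShape}}) where {T} =
    length(it.data)
Base.size(it::SketchIterator{T,<:Base.HasShape}) where {T} = size(it.data)

full = SketchIterator([1.0, 2.0, 3.0])
bare = SketchIterator([1.0, 2.0, 3.0], Base.SizeUnknown(), Base.EltypeUnknown())
```

Here `length(full)` is forwarded to the array, while `length(bare)` has no method, so code under test is forced down the path that handles iterators of unknown size.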

`ChainRulesTestUtils.check_equal` - Method

`check_equal(actual, expected; kwargs...)`

`@test`s that `actual ≈ expected`, but breaks up the data so that human-readable results are shown on failure. It understands things like `unthunk`ing `ChainRulesCore.Thunk`s. All keyword arguments are passed to `isapprox`.

`ChainRulesTestUtils.frule_test` - Method

`frule_test(f, (x, ẋ)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), fkwargs=NamedTuple(), check_inferred=true, kwargs...)`

**Arguments**

- `f`: Function for which the `frule` should be tested.
- `x`: Input at which to evaluate `f` (should generally be set to an arbitrary point in the domain).
- `ẋ`: Differential w.r.t. `x` (should generally be set randomly).

`fkwargs` are passed to `f` as keyword arguments. If `check_inferred=true`, then the inferrability of the `frule` is checked, as long as `f` is itself inferrable. All remaining keyword arguments are passed to `isapprox`.
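Since the tolerances end up in `isapprox`, it can help to recall how Base's `isapprox` combines them for numbers: the comparison passes when `abs(a - b) <= max(atol, rtol * max(abs(a), abs(b)))`. The sketch below uses made-up values standing in for a rule's output and a finite-difference estimate; it assumes the package forwards `rtol` and `atol` unchanged, as the signature above suggests.

```julia
analytic = 2.0
numeric = 2.0 + 1e-8   # stand-in for a finite-difference estimate (assumed value)

# With the documented defaults the allowed error is about 2e-9, so a
# discrepancy of 1e-8 is rejected.
passes_default = isapprox(analytic, numeric; rtol=1e-9, atol=1e-9)

# Loosening rtol to 1e-6 raises the allowed error to about 2e-6.
passes_loose = isapprox(analytic, numeric; rtol=1e-6, atol=1e-9)
```

Loosening a tolerance is reasonable when the function is badly conditioned for finite differencing, but it also makes it easier for a slightly wrong rule to pass.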

`ChainRulesTestUtils.rrule_test` - Method

`rrule_test(f, ȳ, (x, x̄)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), fkwargs=NamedTuple(), check_inferred=true, kwargs...)`

**Arguments**

- `f`: Function to which the rule should be applied.
- `ȳ`: Adjoint w.r.t. the output of `f` (should generally be set randomly). It should have the same structure as `f(x)` (so a tuple if `f` returns multiple values).
- `x`: Input at which to evaluate `f` (should generally be set to an arbitrary point in the domain).
- `x̄`: Currently accumulated adjoint (should generally be set randomly).

`fkwargs` are passed to `f` as keyword arguments. If `check_inferred=true`, then the inferrability of the `rrule` is checked (as long as `f` is itself inferrable), along with the inferrability of the pullback it returns. All remaining keyword arguments are passed to `isapprox`.

`ChainRulesTestUtils.test_scalar` - Method

`test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), fkwargs=NamedTuple(), check_inferred=true, kwargs...)`

Given a function `f` with scalar input and scalar output, perform finite differencing checks at the input point `z` to confirm that correct `frule` and `rrule` methods are provided.

**Arguments**

- `f`: Function for which the `frule` and `rrule` should be tested.
- `z`: Input at which to evaluate `f` (should generally be set to an arbitrary point in the domain).

`fkwargs` are passed to `f` as keyword arguments. If `check_inferred=true`, then the type-stability of the `frule` and `rrule` is checked. All remaining keyword arguments are passed to `isapprox`.