Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions irksome/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)
from .tableaux.pep_explicit_rk import PEPRK
from .ufl.deriv import Dt, expand_time_derivatives, check_irksome_import_order
from .ufl.lag import lag

check_irksome_import_order()

Expand Down Expand Up @@ -56,6 +57,7 @@
"expand_time_derivatives",
"GalerkinCollocationScheme",
"GaussLegendre",
"lag",
"LobattoIIIA",
"LobattoIIIC",
"MeshConstant",
Expand Down
8 changes: 7 additions & 1 deletion irksome/tools.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from .backend import get_backend
import numpy
from ufl.algorithms.analysis import extract_type
from ufl.classes import Variable
from ufl import as_tensor, replace as ufl_replace

import FIAT

from .ufl.deriv import TimeDerivative
from .ufl.lag import lag_label


def dot(A, B):
Expand Down Expand Up @@ -120,8 +122,12 @@ def getNullspace(V, Vbig, num_stages, nullspace):


def replace(e, mapping):
"""A wrapper for ufl.replace that allows numpy arrays."""
"""A wrapper for ufl.replace that allows numpy arrays and skips
substitution into sub-expressions wrapped by :func:`~irksome.lag`."""
cmapping = {k: as_tensor(v) for k, v in mapping.items()}
for var in extract_type(e, Variable):
if var.ufl_operands[1] is lag_label:
cmapping.setdefault(var, var)
return ufl_replace(e, cmapping)


Expand Down
12 changes: 12 additions & 0 deletions irksome/ufl/lag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from ufl.classes import Label, Variable


# A :class:`ufl.Label` to mark nodes that are only evaluated at the start of
# the timestep.
lag_label = Label()


def lag(expr):
"""Mark a sub-expression to be evaluated only at the start of the
timestep during the implicit solve."""
return Variable(expr, lag_label)
56 changes: 56 additions & 0 deletions tests/test_lag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import pytest
import firedrake
from firedrake import Constant, inner, grad, dx, conditional, assemble
import irksome
from irksome import Dt, lag


def test_stefan_implicit():
"""Test lagging the conductivity on the Stefan problem"""
nx = 32
mesh = firedrake.UnitIntervalMesh(nx)
V = firedrake.FunctionSpace(mesh, "CG", 1)
x, = firedrake.SpatialCoordinate(mesh)

u = firedrake.Function(V)
u.interpolate(1 - 2 * x)
u_0 = u.copy(deepcopy=True)

T_1 = Constant(1.0)
T_2 = Constant(-1.0)
bcs = [firedrake.DirichletBC(V, T_1, 1), firedrake.DirichletBC(V, T_2, 2)]

v = firedrake.TestFunction(V)
t = Constant(0.0)
dt = Constant(0.1)
final_time = 10.0
num_steps = int(final_time / float(dt))
solver_params = {"snes_type": "newtonls", "snes_converged_reason": None}
params = {"bcs": bcs, "solver_parameters": solver_params}
method = irksome.BackwardEuler()

k_solid = Constant(2.0)
k_liquid = Constant(1.0)

# Check that fully implicit form fails
k = conditional(u < 0, k_solid, k_liquid)
F = (Dt(u) * v + k * inner(grad(u), grad(v))) * dx
stepper = irksome.TimeStepper(F, method, t, dt, u, **params)
with pytest.raises(firedrake.ConvergenceError):
for step in range(num_steps):
stepper.advance()

# Check that the lagged form works
u.assign(u_0)
F = (Dt(u) * v + lag(k) * inner(grad(u), grad(v))) * dx
stepper = irksome.TimeStepper(F, method, t, dt, u, **params)
energy = 0.5 * u**2 * dx
initial_energy = assemble(energy)

Comment thread
pbrubeck marked this conversation as resolved.
for step in range(num_steps):
stepper.advance()
Comment thread
pbrubeck marked this conversation as resolved.

final_energy = assemble(energy)
print(f"Initial energy: {initial_energy}")
print(f"Final: {final_energy}")
assert final_energy <= initial_energy
Loading