Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 56 additions & 12 deletions firedrake/preconditioners/bddc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from firedrake.parloops import par_loop, INC, READ
from firedrake.bcs import DirichletBC
from firedrake.mesh import Submesh
from ufl import Form, H1, H2, JacobianDeterminant, dx, inner, replace
from ufl import Form, H1, H2, JacobianDeterminant, div, dx, inner, replace
from finat.ufl import BrokenElement
from pyop2.mpi import COMM_SELF
from pyop2.utils import as_tuple
Expand Down Expand Up @@ -118,6 +118,10 @@ def initialize(self, pc):
bddcpc.setBDDCNeumannBoundaries(neu_bndr)

appctx = self.get_appctx(pc)
entity_dofs = V.finat_element.entity_dofs()
vdofs = entity_dofs[min(entity_dofs)]
if any(len(vdofs[v]) > 0 for v in vdofs):
bddcpc.setCoordinates(get_entity_coordinates(V))

# Set coordinates only if corner selection is requested
# There's no API to query from PC
Expand Down Expand Up @@ -189,7 +193,7 @@ def nodes(self):
return numpy.flatnonzero(u.dat.data)


def create_matis(Amat, local_mat_type, cellwise=False):
def create_matis(a, local_mat_type, cellwise=False, bcs=()):
from firedrake.assemble import get_assembler

def local_mesh(mesh):
Expand Down Expand Up @@ -244,10 +248,18 @@ def local_to_global_map(V, cellwise):
indices = usub.dat.data_ro.astype(PETSc.IntType)
return PETSc.LGMap().create(indices, comm=V.comm)

assert Amat.type == "python"
ctx = Amat.getPythonContext()
form = ctx.a
bcs = ctx.bcs
if isinstance(a, Form):
form = a
args = a.arguments()
comm = args[0].function_space().comm
sizes = tuple(arg.function_space().dof_dset.layout_vec.getSizes() for arg in args)
elif isinstance(a, PETSc.Mat):
assert a.type == "python"
ctx = a.getPythonContext()
form = ctx.a
bcs = ctx.bcs
comm = a.comm
sizes = a.getSizes()

local_form = replace(form, {arg: local_argument(arg, cellwise) for arg in form.arguments()})
local_form = Form(list(map(local_integral, local_form.integrals())))
Expand All @@ -259,7 +271,7 @@ def local_to_global_map(V, cellwise):
rmap = local_to_global_map(form.arguments()[0].function_space(), cellwise)
cmap = local_to_global_map(form.arguments()[1].function_space(), cellwise)

Amatis = PETSc.Mat().createIS(Amat.getSizes(), comm=Amat.getComm())
Amatis = PETSc.Mat().createIS(sizes, comm=comm)
Amatis.setISAllowRepeated(cellwise)
Amatis.setLGMap(rmap, cmap)
Amatis.setISLocalMat(tensor.petscmat)
Expand All @@ -283,12 +295,17 @@ def get_divergence_mat(V, mat_type="is", allow_repeated=False):
from firedrake import assemble
degree = max(as_tuple(V.ufl_element().degree()))
Q = TensorFunctionSpace(V.mesh(), "DG", 0, variant=f"integral({degree-1})", shape=V.value_shape[:-1])
B = tabulate_exterior_derivative(V, Q, mat_type=mat_type, allow_repeated=allow_repeated)

Jdet = JacobianDeterminant(V.mesh())
s = assemble(inner(TrialFunction(Q)*(1/Jdet), TestFunction(Q))*dx(degree=0), diagonal=True)
with s.dat.vec as svec:
B.diagonalScale(svec, None)
if mat_type == "is" and (V.ufl_element().sobolev_space == H1 or V.finat_element.complex.is_macrocell()):
form = inner(div(TrialFunction(V)), TestFunction(Q)) * dx
B, _ = create_matis(form, "aij", allow_repeated)
else:
B = tabulate_exterior_derivative(V, Q, mat_type=mat_type, allow_repeated=allow_repeated)
Jdet = JacobianDeterminant(V.mesh())
s = assemble(inner(TrialFunction(Q)*(1/Jdet), TestFunction(Q))*dx(degree=0), diagonal=True)
with s.dat.vec as svec:
B.diagonalScale(svec, None)

return (B,), {}


Expand Down Expand Up @@ -338,3 +355,30 @@ def get_primal_indices(V, primal_markers):
else:
primal_indices = numpy.asarray(primal_markers, dtype=PETSc.IntType)
return primal_indices


def get_entity_coordinates(V):
mesh = V.mesh()
plex = mesh.topology_dm
num_dofs = V.dof_dset.layout_vec.getSizes()[0]
section = V.dm.getLocalSection()
vstart, vend = plex.getDepthStratum(0)
coords = numpy.empty((num_dofs, mesh.geometric_dimension))

if mesh.extruded:
P1 = VectorFunctionSpace(mesh, "Lagrange", 1)
plex_coords = Function(P1).interpolate(mesh.coordinates).dat.data_ro_with_halos
else:
plex_coords = plex.getCoordinatesLocal().getArray().reshape(-1, mesh.geometric_dimension)
for p in range(*plex.getChart()):
dof = section.getDof(p)
if dof <= 0:
continue
off = section.getOffset(p)
V_slice = slice(off, off + dof)

closure, _ = plex.getTransitiveClosure(p, useCone=True)
pverts = [q - vstart for q in closure if vstart <= q < vend]
coords[V_slice] = numpy.average(plex_coords[pverts], axis=0)

return coords
Loading