Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions examples/src/examples/compute/edge-detect.compute-shader.wgsl
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
// Include half-precision type aliases (resolves to f16 when supported, f32 otherwise)
#include "halfTypesCS"

@group(0) @binding(0) var inputTexture: texture_2d<f32>;
@group(0) @binding(1) var inputTexture_sampler: sampler;
@group(0) @binding(2) var outputTexture: texture_storage_2d<rgba8unorm, write>;
// Simplified-syntax declarations (no @group/@binding) - the engine reflects these into a bind
// group automatically, so the example does not provide a computeBindGroupFormat.
var inputTexture: texture_2d<f32>;
var inputTexture_sampler: sampler;
var outputTexture: texture_storage_2d<rgba8unorm, write>;

@compute @workgroup_size(8, 8, 1)
fn main(@builtin(global_invocation_id) global_id : vec3u) {
Expand Down
13 changes: 4 additions & 9 deletions examples/src/examples/compute/edge-detect.example.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -173,18 +173,13 @@ assetListLoader.load(() => {
const createComputeShader = () => {
if (!device.supportsCompute) return null;

// No computeBindGroupFormat is provided - the input texture (+ sampler) and the output
// storage texture use the simplified WGSL syntax and are reflected automatically by the
// engine from the shader source.
return new pc.Shader(device, {
name: 'EdgeDetect-Shader',
shaderLanguage: pc.SHADERLANGUAGE_WGSL,
cshader: computeShaderWgsl,

// Format of a bind group for the compute shader
computeBindGroupFormat: new pc.BindGroupFormat(device, [
// Input texture with sampler (sampler takes binding slot+1 automatically)
new pc.BindTextureFormat('inputTexture', pc.SHADERSTAGE_COMPUTE, undefined, undefined, true),
// Output storage texture
new pc.BindStorageTextureFormat('outputTexture', pc.PIXELFORMAT_RGBA8, pc.TEXTUREDIMENSION_2D)
])
cshader: computeShaderWgsl
});
};

Expand Down
26 changes: 5 additions & 21 deletions examples/src/examples/compute/particles.example.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -84,31 +84,15 @@ assetListLoader.load(() => {

const numParticles = 1024 * 1024;

// a compute shader that will simulate the particles stored in a storage buffer
// a compute shader that will simulate the particles stored in a storage buffer. No bind group
// or uniform buffer formats are provided - the loose uniforms (count, dt, sphereCount) and the
// storage buffers (particles, spheres) use the simplified WGSL syntax and are reflected
// automatically by the engine from the shader source.
const shader = device.supportsCompute ?
new pc.Shader(device, {
name: 'SimulationShader',
shaderLanguage: pc.SHADERLANGUAGE_WGSL,
cshader: shaderSharedWgsl + shaderSimulationWgsl,

// format of a uniform buffer used by the compute shader
computeUniformBufferFormats: {
ub: new pc.UniformBufferFormat(device, [
new pc.UniformFormat('count', pc.UNIFORMTYPE_UINT),
new pc.UniformFormat('dt', pc.UNIFORMTYPE_FLOAT),
new pc.UniformFormat('sphereCount', pc.UNIFORMTYPE_UINT)
])
},

// format of a bind group, providing resources for the compute shader
computeBindGroupFormat: new pc.BindGroupFormat(device, [
// a uniform buffer we provided the format for
new pc.BindUniformBufferFormat('ub', pc.SHADERSTAGE_COMPUTE),
// particle storage buffer
new pc.BindStorageBufferFormat('particles', pc.SHADERSTAGE_COMPUTE),
// read-only collision spheres
new pc.BindStorageBufferFormat('spheres', pc.SHADERSTAGE_COMPUTE, true)
])
cshader: shaderSharedWgsl + shaderSimulationWgsl
}) :
null;

Expand Down
25 changes: 12 additions & 13 deletions examples/src/examples/compute/particles.shader-simulation.wgsl
Original file line number Diff line number Diff line change
@@ -1,30 +1,29 @@
// uniform buffer for the compute shader
struct ub_compute {
count: u32, // number of particles
dt: f32, // delta time
sphereCount: u32 // number of spheres
}

// sphere struct used for the colliders
struct Sphere {
center: vec3<f32>,
radius: f32
}

@group(0) @binding(0) var<uniform> ubCompute : ub_compute;
@group(0) @binding(1) var<storage, read_write> particles: array<Particle>;
@group(0) @binding(2) var<storage, read> spheres: array<Sphere>;
// Simplified-syntax resources (no @group/@binding) - the engine reflects these into a bind group
// automatically, so the example does not provide computeBindGroupFormat / computeUniformBufferFormats.
// The loose uniforms are collapsed into a single generated uniform buffer.
uniform count: u32; // number of particles
uniform dt: f32; // delta time
uniform sphereCount: u32; // number of spheres

var<storage, read_write> particles: array<Particle>;
var<storage, read> spheres: array<Sphere>;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_invocation_id: vec3u) {

// particle index - ignore if out of bounds (as they get batched into groups of 64)
let index = global_invocation_id.x * 1024 + global_invocation_id.y;
if (index >= ubCompute.count) { return; }
if (index >= uniform.count) { return; }

// update times
var particle = particles[index];
particle.collisionTime += ubCompute.dt;
particle.collisionTime += uniform.dt;

// if particle gets too far, reset it to its original position / velocity
var distance = length(particle.position);
Expand All @@ -41,7 +40,7 @@ fn main(@builtin(global_invocation_id) global_invocation_id: vec3u) {
var next = particle.position + delta;

// handle collisions with spheres
for (var i = 0u; i < ubCompute.sphereCount; i++) {
for (var i = 0u; i < uniform.sphereCount; i++) {
var center = spheres[i].center;
var radius = spheres[i].radius;

Expand Down
26 changes: 20 additions & 6 deletions src/platform/graphics/webgpu/webgpu-compute-pipeline.js
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ class CacheEntry {
}

class WebgpuComputePipeline extends WebgpuPipeline {
lookupHashes = new Uint32Array(2);
// shader compute key + up to 2 bind group format keys (caller group 0 + reflected group)
lookupHashes = new Uint32Array(3);

/**
* The cache of compute pipelines
Expand All @@ -38,12 +39,22 @@ class WebgpuComputePipeline extends WebgpuPipeline {
*/
cache = new Map();

get(shader, bindGroupFormat) {
/**
* @param {import('../shader.js').Shader} shader - The compute shader.
* @param {import('../bind-group-format.js').BindGroupFormat[]} bindGroupFormats - The bind group
* formats, in bind group index order (dense, no gaps).
* @returns {object} - The compute pipeline (GPUComputePipeline).
*/
get(shader, bindGroupFormats) {

Debug.assert(bindGroupFormats.length <= 2);

// unique hash for the pipeline
// unique hash for the pipeline - shader key followed by each bind group format key (0 for
// an absent group). All slots are written, so no need to clear stale values from reuse.
const lookupHashes = this.lookupHashes;
lookupHashes[0] = shader.impl.computeKey;
lookupHashes[1] = bindGroupFormat.impl.key;
lookupHashes[1] = bindGroupFormats[0] ? bindGroupFormats[0].impl.key : 0;
lookupHashes[2] = bindGroupFormats[1] ? bindGroupFormats[1].impl.key : 0;
const hash = hash32Fnv1a(lookupHashes);

// Check cache
Expand All @@ -58,8 +69,11 @@ class WebgpuComputePipeline extends WebgpuPipeline {
}
}

// Cache miss - create new pipeline
const pipelineLayout = this.getPipelineLayout([bindGroupFormat.impl]);
// Cache miss - create new pipeline. Build the impl array explicitly (at most 2 groups).
const impls = [];
if (bindGroupFormats[0]) impls.push(bindGroupFormats[0].impl);
if (bindGroupFormats[1]) impls.push(bindGroupFormats[1].impl);
const pipelineLayout = this.getPipelineLayout(impls);
const cacheEntry = new CacheEntry();
cacheEntry.hashes = new Uint32Array(lookupHashes);
cacheEntry.pipeline = this.create(shader, pipelineLayout);
Expand Down
90 changes: 65 additions & 25 deletions src/platform/graphics/webgpu/webgpu-compute.js
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,14 @@ class WebgpuCompute {
/** @type {UniformBuffer[]} */
uniformBuffers = [];

/** @type {BindGroup} */
bindGroup = null;
/**
* Bind groups, indexed by bind group index. A caller-provided format occupies group 0;
* auto-reflected resources occupy their own group (0 when no caller format, otherwise 1).
* The array is dense (no gaps), as required by WebGPU pipeline layouts.
*
* @type {BindGroup[]}
*/
bindGroups = [];

constructor(compute) {
this.compute = compute;
Expand All @@ -25,27 +31,57 @@ class WebgpuCompute {

DebugGraphics.pushGpuMarker(device, `Compute:${compute.name}`);

// create bind group
const { computeBindGroupFormat, computeUniformBufferFormats } = shader.impl;
Debug.assert(computeBindGroupFormat, 'Compute shader does not have computeBindGroupFormat specified', shader);

// this.bindGroup = new BindGroup(device, computeBindGroupFormat, this.uniformBuffer);
this.bindGroup = new BindGroup(device, computeBindGroupFormat);
DebugHelper.setName(this.bindGroup, `Compute-BindGroup_${this.bindGroup.id}`);

if (computeUniformBufferFormats) {
for (const name in computeUniformBufferFormats) {
if (computeUniformBufferFormats.hasOwnProperty(name)) {
// TODO: investigate implications of using a non-persistent uniform buffer
const ub = new UniformBuffer(device, computeUniformBufferFormats[name], true);
this.uniformBuffers.push(ub);
this.bindGroup.setUniformBuffer(name, ub);
const {
computeBindGroupFormat, computeUniformBufferFormats,
computeReflectedBindGroupFormat, computeReflectedUniformBufferFormat,
computeReflectedGroupIndex
} = shader.impl;

Comment on lines +34 to +39
// caller uniform buffers are bound into the caller bind group, so the format is required
Debug.assert(!computeUniformBufferFormats || computeBindGroupFormat,
'Compute shader specifies computeUniformBufferFormats but no computeBindGroupFormat to bind them into', shader);

// ordered, gapless array of bind group formats (array index === bind group index)
const formats = [];

// group 0: caller-provided resources (if any)
if (computeBindGroupFormat) {
const bindGroup = new BindGroup(device, computeBindGroupFormat);
DebugHelper.setName(bindGroup, `Compute-BindGroup_${bindGroup.id}`);

if (computeUniformBufferFormats) {
for (const name in computeUniformBufferFormats) {
if (computeUniformBufferFormats.hasOwnProperty(name)) {
// TODO: investigate implications of using a non-persistent uniform buffer
const ub = new UniformBuffer(device, computeUniformBufferFormats[name], true);
this.uniformBuffers.push(ub);
bindGroup.setUniformBuffer(name, ub);
}
}
}

formats[0] = computeBindGroupFormat;
this.bindGroups[0] = bindGroup;
}

// auto-reflected resources, at their own bind group (0 when no caller format, otherwise 1)
if (computeReflectedBindGroupFormat) {
const reflectedBindGroup = new BindGroup(device, computeReflectedBindGroupFormat);
DebugHelper.setName(reflectedBindGroup, `Compute-ReflectedBindGroup_${reflectedBindGroup.id}`);

if (computeReflectedUniformBufferFormat) {
// matches the generated 'ub_compute' uniform buffer (see WebgpuShaderProcessorWGSL.runCompute)
const ub = new UniformBuffer(device, computeReflectedUniformBufferFormat, true);
this.uniformBuffers.push(ub);
reflectedBindGroup.setUniformBuffer('ub_compute', ub);
}

formats[computeReflectedGroupIndex] = computeReflectedBindGroupFormat;
this.bindGroups[computeReflectedGroupIndex] = reflectedBindGroup;
}

// pipeline
this.pipeline = device.computePipeline.get(shader, computeBindGroupFormat);
this.pipeline = device.computePipeline.get(shader, formats);

DebugGraphics.popGpuMarker(device);
}
Expand All @@ -55,23 +91,27 @@ class WebgpuCompute {
this.uniformBuffers.forEach(ub => ub.destroy());
this.uniformBuffers.length = 0;

this.bindGroup.destroy();
this.bindGroup = null;
this.bindGroups.forEach(bindGroup => bindGroup.destroy());
this.bindGroups.length = 0;
}

updateBindGroup() {

// bind group data
const { bindGroup } = this;
bindGroup.updateUniformBuffers();
bindGroup.update();
for (let i = 0; i < this.bindGroups.length; i++) {
const bindGroup = this.bindGroups[i];
bindGroup.updateUniformBuffers();
bindGroup.update();
}
}

dispatch(x, y, z) {

// bind group
// bind groups
const device = this.compute.device;
device.setBindGroup(0, this.bindGroup);
for (let i = 0; i < this.bindGroups.length; i++) {
device.setBindGroup(i, this.bindGroups[i]);
}

// compute pipeline
const passEncoder = device.passEncoder;
Expand Down
Loading