The Leapfrog Method
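P. B. Visscher developed an interesting variation on the finite difference approach. This staggered time, or leapfrog, approach computes the real and imaginary parts of the wave function at slightly different times: the real part $\Psi_r$ at $t = 0, \Delta t, 2\Delta t, \ldots$ and the imaginary part $\Psi_i$ at $t = \tfrac{\Delta t}{2}, \tfrac{3\Delta t}{2}, \ldots$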
This maps naturally onto our wave function representation as a WebGPU Shading Language (WGSL) vec2f, with the real part in the x component and the imaginary part in the y component.
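On the JavaScript side, the same array<vec2f> is just a Float32Array. Each vec2f holds two f32 values (8 bytes), so the real and imaginary parts are interleaved. A minimal sketch, assuming the wave function is uploaded from a Float32Array; packWaveFunction, realAt and imagAt are hypothetical helpers, not part of the original code:

```javascript
// Host-side view of a WGSL array<vec2f>: each vec2f is two f32 values (8 bytes),
// so the data is interleaved as [re0, im0, re1, im1, ...].
// packWaveFunction, realAt and imagAt are hypothetical helpers for illustration.

/** Interleave separate real and imaginary parts into one Float32Array. */
function packWaveFunction(re, im) {
  if (re.length !== im.length) {
    throw new Error("real and imaginary parts must have the same length");
  }
  const packed = new Float32Array(2 * re.length);
  for (let i = 0; i < re.length; i++) {
    packed[2 * i] = re[i];     // .x component: real part
    packed[2 * i + 1] = im[i]; // .y component: imaginary part
  }
  return packed;
}

/** Real (.x) component of grid point i. */
const realAt = (packed, i) => packed[2 * i];
/** Imaginary (.y) component of grid point i. */
const imagAt = (packed, i) => packed[2 * i + 1];

const psi = packWaveFunction([1, 2], [3, 4]);
console.log(Array.from(psi)); // [1, 3, 2, 4]
```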
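The update equations are adapted to work with these staggered times. With $\hbar = m = 1$, as in the shader code below, they are

$$\Psi_r(x, t+\Delta t) = \Psi_r(x, t) - \Delta t\left[\frac{\Psi_i(x+\Delta x, t+\frac{\Delta t}{2}) - 2\Psi_i(x, t+\frac{\Delta t}{2}) + \Psi_i(x-\Delta x, t+\frac{\Delta t}{2})}{2\Delta x^2} - V(x)\,\Psi_i(x, t+\tfrac{\Delta t}{2})\right]$$

and

$$\Psi_i(x, t+\tfrac{3\Delta t}{2}) = \Psi_i(x, t+\tfrac{\Delta t}{2}) + \Delta t\left[\frac{\Psi_r(x+\Delta x, t+\Delta t) - 2\Psi_r(x, t+\Delta t) + \Psi_r(x-\Delta x, t+\Delta t)}{2\Delta x^2} - V(x)\,\Psi_r(x, t+\Delta t)\right]$$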
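To make the two equations concrete, here is a CPU reference written in JavaScript rather than WGSL so it can run outside the browser. It is a sketch, assuming $\hbar = m = 1$ and the same clamped neighbour indices the shaders use at the ends of the grid; the function names are hypothetical:

```javascript
// CPU reference sketch of the staggered (leapfrog) update equations.
// Assumes hbar = m = 1 and clamps neighbour indices at the grid ends like the shaders.
// realPartStep, imaginaryPartStep and leapfrogStep are hypothetical names.

/** Second difference f(x+dx) - 2f(x) + f(x-dx), with clamped neighbour indices. */
function secondDifference(f, i) {
  const n = f.length;
  const minus = f[Math.max(i - 1, 0)];
  const plus = f[Math.min(i + 1, n - 1)];
  return plus - 2 * f[i] + minus;
}

/** Psi_r(t) -> Psi_r(t + dt), using Psi_i at t + dt/2. Overwrites re in place. */
function realPartStep(re, im, V, dx, dt) {
  const dx22 = 2 * dx * dx;
  for (let i = 0; i < re.length; i++) {
    re[i] -= (secondDifference(im, i) / dx22 - V[i] * im[i]) * dt;
  }
}

/** Psi_i(t + dt/2) -> Psi_i(t + 3dt/2), using the already advanced Psi_r(t + dt). */
function imaginaryPartStep(re, im, V, dx, dt) {
  const dx22 = 2 * dx * dx;
  for (let i = 0; i < im.length; i++) {
    im[i] += (secondDifference(re, i) / dx22 - V[i] * re[i]) * dt;
  }
}

/** One full iteration: all of the real part first, then all of the imaginary part. */
function leapfrogStep(re, im, V, dx, dt) {
  realPartStep(re, im, V, dx, dt);
  imaginaryPartStep(re, im, V, dx, dt);
}
```

Running leapfrogStep on Float64Arrays gives a double-precision reference that GPU results could be compared against.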
What is really interesting is the way these update equations are applied. Starting with $\Psi_r$ at $t$ and $\Psi_i$ at $t+\Delta t/2$, we apply the first update equation to update all of $\Psi_r$ to $t+\Delta t$, then apply the second update equation to update all of $\Psi_i$ to $t+3\Delta t/2$.
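A full iteration (real sweep, then imaginary sweep) advances both halves by $\Delta t$ while the imaginary part stays half a step ahead. A small bookkeeping sketch, assuming the real part starts at $t_0$ and the imaginary part at $t_0 + \Delta t/2$; stateTimes is a hypothetical helper:

```javascript
// Time attached to each half of the wave function after n full iterations,
// assuming Psi_r starts at t0 and Psi_i at t0 + dt/2. stateTimes is hypothetical.
function stateTimes(t0, dt, n) {
  return {
    real: t0 + n * dt,              // Psi_r advances by dt per iteration
    imaginary: t0 + (n + 0.5) * dt, // Psi_i remains dt/2 ahead of Psi_r
  };
}

console.log(stateTimes(0, 0.5, 21)); // { real: 10.5, imaginary: 10.75 }
```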
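A careful examination shows that these updates can be performed in place without allocating any additional memory. Updating each $\Psi_r(x)$ depends only on that same $\Psi_r(x)$ and on the $\Psi_i(x-\Delta x)$, $\Psi_i(x)$, and $\Psi_i(x+\Delta x)$ elements of $\Psi_i$. Similarly, updating each $\Psi_i(x)$ depends only on that same $\Psi_i(x)$ and on the $\Psi_r(x-\Delta x)$, $\Psi_r(x)$, and $\Psi_r(x+\Delta x)$ elements of $\Psi_r$. Because no $\Psi_r$ update reads any other $\Psi_r$ value, overwriting $\Psi_r$ cannot corrupt an input that another point still needs, and the same holds for $\Psi_i$.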
We see this method in action for the same case that caused trouble with the original forward difference approach. The staggered time approach appears to give the same results as the central difference version. We compare these a little later.
The Code
Initialization
As always with FDTD codes, we need to set up the initial conditions. Our central difference initial-value setup computed the wave function at two different times and loaded these values into separate arrays, each holding the wave function at a specific time. For the staggered time approach, we also compute the wave function at two different times, $t$ and $t+\Delta t/2$. We then assign the real part of $\Psi(x, t)$ to the real (.x) components of the wave function array, and the imaginary part of $\Psi(x, t+\Delta t/2)$ to the imaginary (.y) components. Specifically, this code is invoked with time=0.
let halfDt = parameters.dt / 2.0;
// Real part of the wave function at t.
timeWaveFunction[index].x = computePsi(index, time).x;
// Imaginary part of the wave function at t+dt/2, hence staggered time
timeWaveFunction[index].y = computePsi(index, time+halfDt).y;
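The body of computePsi is not shown here, so the following JavaScript sketch substitutes a hypothetical stand-in: a free-particle plane wave $\Psi(x, t) = e^{i(kx - \omega t)}$ with $\omega = k^2/2$ ($\hbar = m = 1$). It only illustrates sampling the two components at staggered times into the interleaved layout; planeWave, initializeStaggered, and the grid starting at $x = 0$ are all assumptions:

```javascript
// Hypothetical stand-in for computePsi: a free-particle plane wave
// Psi(x, t) = exp(i(kx - omega t)), omega = k^2 / 2 (hbar = m = 1).
function planeWave(x, t, k) {
  const phase = k * x - 0.5 * k * k * t;
  return { re: Math.cos(phase), im: Math.sin(phase) };
}

/**
 * Fill an interleaved [re, im, ...] array with the real part at `time` and the
 * imaginary part at `time + dt/2`. Assumes the grid starts at x = 0 (hypothetical).
 */
function initializeStaggered(xResolution, length, k, time, dt) {
  const dx = length / (xResolution - 1); // same grid spacing as the shaders
  const halfDt = dt / 2;
  const psi = new Float32Array(2 * xResolution);
  for (let i = 0; i < xResolution; i++) {
    const x = i * dx;
    psi[2 * i] = planeWave(x, time, k).re;              // Psi_r at t
    psi[2 * i + 1] = planeWave(x, time + halfDt, k).im; // Psi_i at t + dt/2
  }
  return psi;
}

const initial = initializeStaggered(1024, 10, 2, 0, 0.001);
```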
Timesteps
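Adapting the timesteps to the staggered time requires a bit more thought. We must update the entirety of $\Psi_r$ and only then go on to update $\Psi_i$. The invocations of a single compute dispatch run concurrently, with no ordering guarantee across workgroups, so one entry point could not ensure that every $\Psi_r$ value has been advanced before some $\Psi_i$ update reads it. This means having two entry points in the shader, one to update the real part and the other to update the imaginary part of the wave function, dispatched one after the other.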
/**
* Timestep the real component of the wave function. For consistency and correctness we must
* generate timesteps for the real part, then for the imaginary part of the wave function.
*/
@compute @workgroup_size(${WORKGROUP_SIZE})
fn realPartTimeStep(@builtin(global_invocation_id) global_id : vec3u)
{
let index = global_id.x;
// Skip invocations when work groups exceed the actual problem size
if (index >= parameters.xResolution) {
return;
}
// The potential and the wave function arrays have the same size.
let dx = parameters.length / f32(parameters.xResolution-1);
let dx22 = 2.0*dx*dx;
let V = parameters.potential[index];
let waveFunctionAtX = waveFunction[index];
let waveFunctionAtXPlusDx = waveFunction[min(index+1, parameters.xResolution-1)];
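// index is a u32, so index-1 wraps around to 0xFFFFFFFF at index 0 and max(..., 0)
// cannot undo that; max(index, 1u) - 1u clamps the left neighbour to index 0 instead.
let waveFunctionAtXMinusDx = waveFunction[max(index, 1u) - 1u];
// Psi_r(t+dt) = Psi_r(t) - dt*[(Psi_i(x+dx) - 2*Psi_i(x) + Psi_i(x-dx))/(2*dx^2) - V*Psi_i(x)],
// where the imaginary parts are the current values at t+dt/2.
waveFunction[index].x = waveFunctionAtX.x
- ((waveFunctionAtXPlusDx.y - 2.0*waveFunctionAtX.y + waveFunctionAtXMinusDx.y)
/ dx22 - V*waveFunctionAtX.y) * parameters.dt;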
}
/**
* Timestep the imaginary component of the wave function. For consistency and correctness we must
* generate timesteps for the real part, then for the imaginary part of the wave function.
*/
@compute @workgroup_size(${WORKGROUP_SIZE})
fn imaginaryPartTimeStep(@builtin(global_invocation_id) global_id : vec3u)
{
let index = global_id.x;
// Skip invocations when work groups exceed the actual problem size
if (index >= parameters.xResolution) {
return;
}
// The potential and the wave function arrays have the same size.
let dx = parameters.length / f32(parameters.xResolution-1);
let dx22 = 2.0*dx*dx;
let V = parameters.potential[index];
let waveFunctionAtX = waveFunction[index];
let waveFunctionAtXPlusDx = waveFunction[min(index+1, parameters.xResolution-1)];
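// index is a u32: clamp the left neighbour with max(index, 1u) - 1u, because
// max(index-1, 0) wraps around to an out-of-bounds index at index 0.
let waveFunctionAtXMinusDx = waveFunction[max(index, 1u) - 1u];
// Psi_i(t+3dt/2) = Psi_i(t+dt/2) + dt*[(Psi_r(x+dx) - 2*Psi_r(x) + Psi_r(x-dx))/(2*dx^2) - V*Psi_r(x)],
// where the real parts were already advanced to t+dt by realPartTimeStep.
waveFunction[index].y = waveFunctionAtX.y
+ ((waveFunctionAtXPlusDx.x - 2.0*waveFunctionAtX.x + waveFunctionAtXMinusDx.x)
/ dx22 - V*waveFunctionAtX.x) * parameters.dt;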
}
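A CPU analogy shows why the two sweeps must stay separate. This JavaScript sketch (hypothetical names, $V = 0$ for brevity) compares two full sweeps, mirroring the two dispatches, with a single fused sweep that updates $\Psi_r(x)$ and then $\Psi_i(x)$ point by point. A GPU dispatch would not even run in a fixed order, but the fused sweep already goes wrong: $\Psi_i(x)$ sees an updated $\Psi_r(x-\Delta x)$ but a stale $\Psi_r(x+\Delta x)$, and $\Psi_r(x)$ can read an already advanced $\Psi_i(x-\Delta x)$:

```javascript
// CPU analogy (not GPU code) with V = 0 and the shaders' clamped boundaries.
const lap = (f, i) => f[Math.min(i + 1, f.length - 1)] - 2 * f[i] + f[Math.max(i - 1, 0)];

function updateReal(re, im, i, dx, dt) { re[i] -= (lap(im, i) / (2 * dx * dx)) * dt; }
function updateImag(re, im, i, dx, dt) { im[i] += (lap(re, i) / (2 * dx * dx)) * dt; }

/** Mirrors realPartTimeStep followed by imaginaryPartTimeStep. */
function twoSweepStep(re, im, dx, dt) {
  for (let i = 0; i < re.length; i++) updateReal(re, im, i, dx, dt);
  for (let i = 0; i < im.length; i++) updateImag(re, im, i, dx, dt);
}

/** Incorrect: mixes old and new values of both parts within one sweep. */
function fusedStep(re, im, dx, dt) {
  for (let i = 0; i < re.length; i++) {
    updateReal(re, im, i, dx, dt);
    updateImag(re, im, i, dx, dt);
  }
}

const a = { re: Float64Array.from([0, 1, 0, 0]), im: new Float64Array(4) };
const b = { re: Float64Array.from([0, 1, 0, 0]), im: new Float64Array(4) };
twoSweepStep(a.re, a.im, 1, 0.1);
fusedStep(b.re, b.im, 1, 0.1);
console.log(a.re.every((v, i) => v === b.re[i])); // false
```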
Both entry points take the waveFunction array as their input and write to that same waveFunction array. This is what we meant earlier when we said the updates can be made in place. Unlike the central difference version, which kept separate arrays for different times, we now need only the single waveFunction array.
// Group 1, Current wave function with Ψ_r at t and Ψ_i at t+Δt/2.
@group(1) @binding(0) var<storage, read_write> waveFunction : array<vec2f>;
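With a single array, sizing the storage buffer behind this binding is simple: a vec2f occupies 8 bytes and array<vec2f> has an 8-byte stride, so the buffer needs 8 bytes per grid point. A sketch, assuming the buffer is created from JavaScript; VEC2F_BYTES and waveFunctionBufferDescriptor are hypothetical names, and the usage flags are passed in so the sketch does not depend on the browser-only GPUBufferUsage global:

```javascript
// Bytes per grid point in array<vec2f>: two 32-bit floats.
const VEC2F_BYTES = 2 * Float32Array.BYTES_PER_ELEMENT; // 8

/** Descriptor for the single wave function storage buffer (hypothetical helper). */
function waveFunctionBufferDescriptor(xResolution, usage) {
  return {
    label: "wave function buffer",
    size: xResolution * VEC2F_BYTES,
    usage,
  };
}

// In a browser this might be used as:
// device.createBuffer(waveFunctionBufferDescriptor(xResolution,
//     GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC));
console.log(waveFunctionBufferDescriptor(1024, 0).size); // 8192
```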
To run these two entry points we need two compute pipelines, one to run the real part timestep.
realPartTimeStep = device.createComputePipeline({
label: "update real part pipeline",
layout: device.createPipelineLayout({
bindGroupLayouts: [parametersBindGroupLayout, waveFunctionBindGroupLayout]
}),
compute: {
module: timeStepShaderModule,
entryPoint: "realPartTimeStep"
}
});
And the second to run the imaginary part timestep.
imaginaryPartTimeStep = device.createComputePipeline({
label: "update imaginary part pipeline",
layout: device.createPipelineLayout({
bindGroupLayouts: [parametersBindGroupLayout, waveFunctionBindGroupLayout]
}),
compute: {
module: timeStepShaderModule,
entryPoint: "imaginaryPartTimeStep"
}
});
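Because both entry points use the same two bind group layouts, one possible refactoring is to build both pipelines from a single shared pipeline layout. This is a sketch, not the original code; createTimeStepPipelines is a hypothetical helper, while createPipelineLayout and createComputePipeline are the standard GPUDevice methods used above:

```javascript
/**
 * Build both timestep pipelines from one shader module, sharing one pipeline layout.
 * Hypothetical helper; device, module and bindGroupLayouts come from the caller.
 */
function createTimeStepPipelines(device, module, bindGroupLayouts) {
  const layout = device.createPipelineLayout({ bindGroupLayouts });
  const make = (entryPoint, label) =>
    device.createComputePipeline({ label, layout, compute: { module, entryPoint } });
  return {
    realPartTimeStep: make("realPartTimeStep", "update real part pipeline"),
    imaginaryPartTimeStep: make("imaginaryPartTimeStep", "update imaginary part pipeline"),
  };
}

// Possible usage with the names from the code above:
// ({ realPartTimeStep, imaginaryPartTimeStep } = createTimeStepPipelines(
//     device, timeStepShaderModule, [parametersBindGroupLayout, waveFunctionBindGroupLayout]));
```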
Each timestep of the wave function consists of a timestep of the real part followed by a timestep of the imaginary part. Thus, we move from a state with $\Psi_r$ at $t$ and $\Psi_i$ at $t+\Delta t/2$ to a state with $\Psi_r$ at $t+\Delta t$ and $\Psi_i$ at $t+3\Delta t/2$.
/**
* Execute count iterations of the simulation. Each iteration consists of one update to the
* real part of the wave function, and one update to the imaginary part of the wave function.
*
* @param {number} count The number of iterations to carry out (an integer).
*/
step(count=21) {
running = true;
// Create a new command encoder on every call, because it cannot be reused after finish() is invoked.
const commandEncoder = device.createCommandEncoder();
const workgroupCountX = Math.ceil(xResolution / WORKGROUP_SIZE);
for (let i=0; i<count && running; i++)
{
const realPassEncoder = commandEncoder.beginComputePass();
realPassEncoder.setPipeline(realPartTimeStep);
realPassEncoder.setBindGroup(0, parametersBindGroup);
realPassEncoder.setBindGroup(1, waveFunctionBindGroup);
realPassEncoder.dispatchWorkgroups(workgroupCountX);
realPassEncoder.end();
const imaginaryPassEncoder = commandEncoder.beginComputePass();
imaginaryPassEncoder.setPipeline(imaginaryPartTimeStep);
imaginaryPassEncoder.setBindGroup(0, parametersBindGroup);
imaginaryPassEncoder.setBindGroup(1, waveFunctionBindGroup);
imaginaryPassEncoder.dispatchWorkgroups(workgroupCountX);
imaginaryPassEncoder.end();
}
// Submit GPU commands.
const gpuCommands = commandEncoder.finish();
device.queue.submit([gpuCommands]);
}
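The dispatch size in step() is rounded up with Math.ceil, so the last workgroup can run past the end of the grid; that is why both shaders return early when index >= xResolution. A worked example with illustrative values (dispatchSummary is a hypothetical helper):

```javascript
/** Dispatch arithmetic for one call to step(count): two dispatches per iteration. */
function dispatchSummary(xResolution, workgroupSize, count) {
  const workgroupCountX = Math.ceil(xResolution / workgroupSize);
  const invocations = workgroupCountX * workgroupSize; // per dispatch
  return {
    workgroupCountX,
    invocations,
    idle: invocations - xResolution, // invocations that hit the early return
    dispatches: 2 * count,           // one real and one imaginary dispatch per iteration
  };
}

console.log(dispatchSummary(1000, 64, 21));
// { workgroupCountX: 16, invocations: 1024, idle: 24, dispatches: 42 }
```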