The Schrödinger Equation on a GPU
This section, indeed this entire set of pages, is a work in progress. While the pages are conceptually reasonable, they certainly need adjustments to improve clarity and to fix a few small issues of correctness.
Both quantum mechanics and GPGPU computing are complex and difficult topics. Appropriately, we will use each to shed light on the other. The Schrödinger equation will provide a specific and valuable application of GPGPU computing, and GPGPU computing will in turn provide interactive simulations and visualizations where you can experiment with and explore the sometimes highly counterintuitive world of quantum mechanics.
Luckily, one of the most popular and straightforward numerical methods, the finite difference method, is particularly well suited to implementations on the GPU.
The finite difference time domain (FDTD) method is a wonderfully simple method that can be taught at the undergraduate or early graduate level. Yet it is capable of solving extremely sophisticated engineering problems.
We will work through the process of developing a compute shader based FDTD treatment with enough detail to provide a good introduction, and hopefully enough insight to apply these methods to other topics such as thermodynamics, electrodynamics, and fluid dynamics.
The Representation
Let's start with the one dimensional time dependent Schrödinger equation.
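For reference, the equation can be written as below. This is a reconstruction that assumes a free particle with no potential term, which matches the shader developed in this section; ħ is the reduced Planck constant and m is the particle mass.

```latex
i\hbar \frac{\partial \psi(x,t)}{\partial t}
  = -\frac{\hbar^2}{2m} \frac{\partial^2 \psi(x,t)}{\partial x^2}
```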
The Schrödinger equation follows a pattern that we see over and over in physics and engineering. On one side we have a time derivative, and on the other we have spatial derivatives. This means we write our evolution in time as a function of differences over space. We have two arrays, with each array containing the values of our function over space at a specific time. Then each computation step from one array to the other is a step forward in time.
What exactly do we mean when we say that each array contains "the values of our function over space at a specific time"?
We are looking at a 1D Schrödinger equation, so we try to model it with discrete values on a 1D grid, an array. In terms of variables in our code, the grid has xResolution points spanning a physical length of length, in increments of Δx = length / xResolution between array elements.
We have to be careful putting the Schrödinger equation on the grid. The i in the Schrödinger equation indicates that ψ is a complex function, that it has both real and imaginary parts. Rather than make each array an array of floats, we make each array an array of vec2<f32>, which WGSL also lets us write as vec2f.
We will put the real part of the wave function in the 0 component (.x) of the vec2f, and the imaginary part in the 1 component (.y).
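On the JavaScript side, an array of vec2f corresponds to a Float32Array with interleaved real and imaginary values. The sketch below fills such an array with a Gaussian wave packet purely as an illustration. The function name gaussianWavePacket and the parameters x0, sigma and k are hypothetical and do not come from the original code, and the result is not normalized.

```javascript
// Sketch: pack a complex wave function into interleaved (real, imaginary)
// pairs, matching an array<vec2f> where .x is the real part and .y is the
// imaginary part. gaussianWavePacket, x0, sigma and k are hypothetical names.
function gaussianWavePacket(xResolution, length, x0, sigma, k) {
  const dx = length / xResolution; // spacing between grid points
  const psi = new Float32Array(2 * xResolution);
  for (let i = 0; i < xResolution; i++) {
    const x = i * dx;
    const envelope = Math.exp(-((x - x0) ** 2) / (2 * sigma * sigma));
    psi[2 * i] = envelope * Math.cos(k * x);     // real part      -> vec2f .x
    psi[2 * i + 1] = envelope * Math.sin(k * x); // imaginary part -> vec2f .y
  }
  return psi;
}
```

Element i of the shader array then occupies indices 2*i and 2*i+1 of the typed array, so the whole array holds 2*xResolution floats.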
The Approximation
We approximate the continuous function with values at discrete points at a fixed time on our grid. For the derivatives we use differences between points on the grid, hence the name of the technique we are using, finite differences.
Now, we can start toward a finite difference approximation to the Schrödinger equation. Start by writing out an approximation to the time derivative.
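A reconstruction of that approximation, using Δt for the time between successive arrays, is the forward difference:

```latex
\frac{\partial \psi(x,t)}{\partial t}
  \approx \frac{\psi(x, t+\Delta t) - \psi(x,t)}{\Delta t}
```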
This approximation, a forward difference in time, is almost the familiar definition of the first derivative, but without the expected limit. Without the limit it is clearly an approximation. The size of the error can be found from the Taylor series expansion for ψ(x, t+Δt), and is readily shown to be O(Δt). As we expect, we can improve the quality of the approximation by choosing a smaller Δt, making successive arrays closer together in time.
Putting this into the Schrödinger equation and isolating the ψ(x, t+Δt) term on the left, while setting ħ and m to 1 for simplicity, gives an update rule.
It shows explicitly that the new value for an array element is the current value for that element, modified according to the values of the surrounding array elements.
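Written out, as a reconstruction that assumes the free particle equation (no potential term) with ħ = m = 1, the update rule is:

```latex
\psi(x, t+\Delta t)
  = \psi(x,t) + \frac{i\,\Delta t}{2}\,
    \frac{\partial^2 \psi(x,t)}{\partial x^2}
```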
Now that we have a grid with explicit real and imaginary parts, we break up the wave function and then the wave equation into real and imaginary parts.
We put this into the Schrödinger equation, once again setting ħ and m to 1, then gather the real and imaginary parts into their own equations.
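A sketch of that split follows. The subscripted names ψ_R and ψ_I for the real and imaginary parts are notation assumed here, and the signs are chosen to agree with the shader code later in this section.

```latex
\psi(x,t) = \psi_R(x,t) + i\,\psi_I(x,t)

\psi_R(x, t+\Delta t) = \psi_R(x,t)
  - \frac{\Delta t}{2}\,\frac{\partial^2 \psi_I(x,t)}{\partial x^2}

\psi_I(x, t+\Delta t) = \psi_I(x,t)
  + \frac{\Delta t}{2}\,\frac{\partial^2 \psi_R(x,t)}{\partial x^2}
```

The real part is driven by the curvature of the imaginary part and the imaginary part by the curvature of the real part, which is why each shader assignment reads the opposite component.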
To understand how to manifest these equations on a grid, focus on the ∂²/∂x² operator. As it turns out, it's actually pretty straightforward: the second derivative becomes a difference built from a grid value and its two neighbors.
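The standard central difference for the second derivative, which is the form the shader code in this section uses, reads:

```latex
\frac{\partial^2 \psi(x,t)}{\partial x^2}
  \approx \frac{\psi(x+\Delta x, t) - 2\,\psi(x,t) + \psi(x-\Delta x, t)}
               {\Delta x^2}
```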
Now we have approximations of both the time and space derivatives in terms of values on our grid.
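Combining the two approximations gives fully discrete update equations. This is a sketch assuming a free particle with ħ = m = 1 and the notation ψ_R, ψ_I for the real and imaginary parts; the 2Δx² denominator corresponds to the dx22 constant in the shader.

```latex
\psi_R(x, t+\Delta t) = \psi_R(x,t)
  - \frac{\Delta t}{2\,\Delta x^2}
    \left[\psi_I(x+\Delta x,t) - 2\,\psi_I(x,t) + \psi_I(x-\Delta x,t)\right]

\psi_I(x, t+\Delta t) = \psi_I(x,t)
  + \frac{\Delta t}{2\,\Delta x^2}
    \left[\psi_R(x+\Delta x,t) - 2\,\psi_R(x,t) + \psi_R(x-\Delta x,t)\right]
```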
The Code
Now we start writing some code. We need two arrays representing values for our wave function. One array for the values at the current time t, and another array for the values at t+Δt. The compute shader uses the values at t to compute the wave function at t+Δt.
//group 1, changes on each iteration
// Initial wave function at t.
@group(1) @binding(0) var<storage, read> waveFunction : array<vec2f>;
// The updated wave function at t+Δt.
@group(1) @binding(1) var<storage, read_write> updatedWaveFunction : array<vec2f>;
Now, when our mathematics makes an assignment to the real part ψ_R(x, t+Δt), the code will make an assignment to updatedWaveFunction[i].x, and assignments to the imaginary part ψ_I(x, t+Δt) map to updatedWaveFunction[i].y.
Every reference to ψ_R(x, t) is an access to waveFunction[i].x, and every reference to ψ_I(x, t) is an access to waveFunction[i].y.
The ψ(x ± Δx, t) terms are also easily related to the wave function arrays. Δx is the spacing between elements, so x ± Δx is an offset of the array index by ±1.
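The index mapping can be checked on the CPU before any shader is involved. The following is a minimal reference sketch, assuming interleaved (real, imaginary) floats and the same clamped neighbor indexing at the ends of the grid that the shader in this section uses. The name referenceTimeStep is hypothetical.

```javascript
// CPU reference for one time step, assuming hbar = m = 1 and no potential.
// `current` holds interleaved (real, imaginary) pairs; a new array is returned.
// referenceTimeStep is a hypothetical helper, not part of the original code.
function referenceTimeStep(current, dt, dx) {
  const n = current.length / 2;       // number of grid points
  const updated = new Float32Array(current.length);
  const dx22 = 2 * dx * dx;
  for (let i = 0; i < n; i++) {
    const plus = Math.min(i + 1, n - 1); // x + dx, clamped at the right edge
    const minus = Math.max(i - 1, 0);    // x - dx, clamped at the left edge
    const re = current[2 * i];
    const im = current[2 * i + 1];
    const d2re = current[2 * plus] - 2 * re + current[2 * minus];
    const d2im = current[2 * plus + 1] - 2 * im + current[2 * minus + 1];
    updated[2 * i] = re - (d2im / dx22) * dt;     // real part update
    updated[2 * i + 1] = im + (d2re / dx22) * dt; // imaginary part update
  }
  return updated;
}
```

Comparing the output of such a function against the GPU result for a few small inputs is one way to catch sign or indexing mistakes in the shader.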
We also need some physical parameters for the simulation. Specifically, we need Δt, the step size between computed wave functions, and Δx, the step size between values of the wave function at a given time. We compute Δx as length / xResolution, and also use xResolution for the size of our arrays.
struct Parameters {
dt: f32, // The time step, Δt.
xResolution: u32, // The number of points along the x-axis, the number of elements in the array.
length: f32 // The physical length for our simulation.
}
// group 0, things that never change within a simulation.
// The parameters for the simulation
@group(0) @binding(0) var<storage, read> parameters: Parameters;
You may notice that the wave function arrays and the simulation parameters are in different groups. It is a common practice to group bindings according to how frequently they are updated.
Entry Point
The @compute attribute marks the entry point for the compute shader. We also provide a @workgroup_size of 64. Remember that when we had a 2D matrix, we used a workgroup size of 8x8. Both of these workgroups have 64 invocations, which you might remember is the size we chose to make efficient use of the GPU.
@compute @workgroup_size(64)
fn timeStep(@builtin(global_invocation_id) global_id : vec3u)
The Code
The body of the timeStep function is a straightforward implementation of the discretized Schrödinger equation. We do precompute some common expressions, and it is important to remember that the shader is invoked once for each element of updatedWaveFunction.
let index = global_id.x;
// waveFunction, and updatedWaveFunction have the same size.
let dx = parameters.length / f32(parameters.xResolution);
let dx22 = dx*dx*2.0;
let waveFunctionAtX = waveFunction[index];
let waveFunctionAtXPlusDx = waveFunction[min(index+1, parameters.xResolution-1)];
// index is a u32, so index-1 would wrap around at index 0. Clamp before subtracting.
let waveFunctionAtXMinusDx = waveFunction[max(index, 1u) - 1u];
updatedWaveFunction[index].x = waveFunctionAtX.x
- ((waveFunctionAtXPlusDx.y - 2.0*waveFunctionAtX.y + waveFunctionAtXMinusDx.y)
/ dx22) * parameters.dt;
updatedWaveFunction[index].y = waveFunctionAtX.y
+ ((waveFunctionAtXPlusDx.x - 2.0*waveFunctionAtX.x + waveFunctionAtXMinusDx.x)
/ dx22) * parameters.dt;
Our entire compute shader is then
struct Parameters {
dt: f32, // The time step, Δt.
xResolution: u32, // The number of points along the x-axis, the number of elements in the array.
length: f32 // The physical length for our simulation.
}
// group 0, things that never change within a simulation.
// The parameters for the simulation
@group(0) @binding(0) var<storage, read> parameters: Parameters;
//group 1, changes on each iteration
// Initial wave function at t.
@group(1) @binding(0) var<storage, read> waveFunction : array<vec2f>;
// The updated wave function at t+Δt.
@group(1) @binding(1) var<storage, read_write> updatedWaveFunction : array<vec2f>;
@compute @workgroup_size(64)
fn timeStep(@builtin(global_invocation_id) global_id : vec3u)
{
let index = global_id.x;
// Skip invocations when work groups exceed the actual problem size
if (index >= parameters.xResolution) {
return;
}
// waveFunction, and updatedWaveFunction have the same size.
let dx = parameters.length / f32(parameters.xResolution);
let dx22 = dx*dx*2.0;
let waveFunctionAtX = waveFunction[index];
let waveFunctionAtXPlusDx = waveFunction[min(index+1, parameters.xResolution-1)];
// index is a u32, so index-1 would wrap around at index 0. Clamp before subtracting.
let waveFunctionAtXMinusDx = waveFunction[max(index, 1u) - 1u];
updatedWaveFunction[index].x = waveFunctionAtX.x
- ((waveFunctionAtXPlusDx.y - 2.0*waveFunctionAtX.y + waveFunctionAtXMinusDx.y)
/ dx22) * parameters.dt;
updatedWaveFunction[index].y = waveFunctionAtX.y
+ ((waveFunctionAtXPlusDx.x - 2.0*waveFunctionAtX.x + waveFunctionAtXMinusDx.x)
/ dx22) * parameters.dt;
}
At the start of each step, the input array contains values for ψ(x, t). The shader uses this as input to calculate ψ(x, t+Δt), so that at the end of a step the updatedWaveFunction array is populated with ψ(x, t+Δt).
At this point we no longer need the original array. We swap arrays and use the ψ(x, t+Δt) array as input, and write the ψ(x, t+2Δt) values into the array that previously held the ψ(x, t) values.
After every step, we recycle the array holding the previous timestep as the target for the next timestep. This technique of switching arrays back and forth as source and target for our calculations is well known as ping-pong buffering.
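The swapping pattern itself is independent of the GPU. Below is a small CPU sketch of ping-pong buffering; runPingPong and the step callback are hypothetical names, and step(source, target) is assumed to read from source and fill target.

```javascript
// Sketch of ping-pong buffering on the CPU. `step(source, target)` is a
// hypothetical callback that reads `source` and writes the next state to `target`.
function runPingPong(bufferA, bufferB, steps, step) {
  const buffers = [bufferA, bufferB];
  for (let i = 0; i < steps; i++) {
    const source = buffers[i % 2];       // holds the state at the current time
    const target = buffers[(i + 1) % 2]; // receives the state one step later
    step(source, target);
  }
  // After `steps` iterations the newest state lives in buffers[steps % 2].
  return buffers[steps % 2];
}
```

No data is copied between steps; only the roles of the two arrays change, which is what the two bind groups achieve on the GPU.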
As always with WebGPU we load the shader into a GPUShaderModule.
timeStepShaderModule = device.createShaderModule({
label: 'Schrodinger time step shader',
code: timeStepShader
});
The Buffers
We need to feed data to this simulation through WebGPU buffers, which of course reflect the data structures within the shader.
struct Parameters {
dt: f32, // The time step, Δt.
xResolution: u32, // The number of points along the x-axis, the number of elements in the array.
length: f32 // The physical length for our simulation.
}
// group 0, things that never change within a simulation.
// The parameters for the simulation
@group(0) @binding(0) var<storage, read> parameters: Parameters;
The bind group layout identifies the usage of a group of resources. Here, binding 0 is a read only storage object available to the fragment and compute shaders. This is forward looking in that we make the data available to the fragment shader, which we will use to render the wave function, and we use storage rather than a uniform because we will later include a potential in the parameters.
parametersBindGroupLayout = device.createBindGroupLayout({
label: "Simulation parameters",
entries: [
{
binding: 0,
visibility: GPUShaderStage.COMPUTE | GPUShaderStage.FRAGMENT,
buffer: {
type: "read-only-storage"
}
}
]
});
The next thing is to create the buffer and load it with data. For clarity and correctness, we explicitly account for the size of each field. As before, the buffer is mapped at creation, allowing us to load it with data. If the debug option is set, we also provide for the buffer to be copied so that we can later examine its contents.
parametersBuffer = device.createBuffer({
label: "Parameters buffer",
mappedAtCreation: true,
size: Float32Array.BYTES_PER_ELEMENT // dt
+ Uint32Array.BYTES_PER_ELEMENT // xResolution
+ Float32Array.BYTES_PER_ELEMENT, // length
// How we use this buffer, in the debug case we copy it to another buffer for reading
usage: debug ? GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC : GPUBufferUsage.STORAGE
});
The buffer is mapped at creation, so we can immediately get the CPU side memory and populate it with our simulation parameters.
// Get the raw array buffer for the mapped GPU buffer
const parametersArrayBuffer = parametersBuffer.getMappedRange();
Remember that we create the Float32Array and Uint32Array views backed by the GPU buffer's array buffer. The typed array constructor takes the underlying ArrayBuffer, the byte offset where the data begins, and the array length. Once we have a typed array associated with a section of our array buffer, we load data into it. We must use typed arrays here, even though we only have a single element in each array. This writes native 32-bit values rather than JavaScript's 64-bit floating point numbers.
let bytesSoFar = 0;
new Float32Array(parametersArrayBuffer, bytesSoFar, 1).set([dt]);
bytesSoFar += Float32Array.BYTES_PER_ELEMENT;
new Uint32Array(parametersArrayBuffer, bytesSoFar, 1).set([xResolution]);
bytesSoFar += Uint32Array.BYTES_PER_ELEMENT;
new Float32Array(parametersArrayBuffer, bytesSoFar, 1).set([length]);
With dt, xResolution and length written, the parametersBuffer is populated with our data, and we return it to the GPU.
// Unmap the buffer returning ownership to the GPU.
parametersBuffer.unmap();
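The packing logic can be exercised without a GPU by substituting a plain ArrayBuffer for the mapped range. The sketch below does this; packParameters and standIn are hypothetical names, and the sample values 0.1, 512 and 20.0 are arbitrary.

```javascript
// Sketch: the same typed-array packing, applied to a plain ArrayBuffer that
// stands in for parametersBuffer.getMappedRange(). Names are hypothetical.
function packParameters(arrayBuffer, dt, xResolution, length) {
  let bytesSoFar = 0;
  new Float32Array(arrayBuffer, bytesSoFar, 1).set([dt]);
  bytesSoFar += Float32Array.BYTES_PER_ELEMENT;
  new Uint32Array(arrayBuffer, bytesSoFar, 1).set([xResolution]);
  bytesSoFar += Uint32Array.BYTES_PER_ELEMENT;
  new Float32Array(arrayBuffer, bytesSoFar, 1).set([length]);
  bytesSoFar += Float32Array.BYTES_PER_ELEMENT;
  return bytesSoFar; // total bytes written, should equal the buffer size
}

const standIn = new ArrayBuffer(12); // three 4-byte fields
const bytesWritten = packParameters(standIn, 0.1, 512, 20.0);
// Reading dt back shows the 32-bit rounding applied by the typed array.
const storedDt = new Float32Array(standIn, 0, 1)[0];
```

Here storedDt equals Math.fround(0.1) rather than the 64-bit value 0.1, which is the precision the shader will actually see.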
The last step in setting up the parameters buffer is the bind group. The bind group names the specific resource and ties it to the bind group layout, which identifies how it will be used.
parametersBindGroup = device.createBindGroup({
layout: parametersBindGroupLayout,
entries: [
{
binding: 0,
resource: {
buffer: parametersBuffer
}
}
]
});
Now we move on to the wave function buffers. Remember that we have one for the wave function at t and the other for the wave function at t+Δt.
// Wave function representations
waveFunctionBuffer0 = device.createBuffer({
size: 2*xResolution*Float32Array.BYTES_PER_ELEMENT,
usage: debug ? GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC : GPUBufferUsage.STORAGE
});
waveFunctionBuffer1 = device.createBuffer({
size: 2*xResolution*Float32Array.BYTES_PER_ELEMENT,
usage: debug ? GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC : GPUBufferUsage.STORAGE
});
We can write a quick description of how we will use these buffers. The first wave function buffer is the current wave function, which we only read. The second wave function buffer is the updated wave function, which we compute and write in the compute shader.
waveFunctionBindGroupLayout = device.createBindGroupLayout({
label: "Wave function data.",
entries: [{
binding: 0,
visibility: GPUShaderStage.COMPUTE,
buffer: {
type: "read-only-storage"
}
},
{
binding: 1,
visibility: GPUShaderStage.COMPUTE,
buffer: {
type: "storage"
}
}]
});
We swap back and forth which buffer we use for the current wave function and which is used for the updated wave function, so we create two bind groups, and package them in an array.
waveFunctionBindGroup = new Array(2);
The first bind group uses waveFunctionBuffer0 for the current wave function, and waveFunctionBuffer1 for the updated wave function.
waveFunctionBindGroup[0] = device.createBindGroup({
layout: waveFunctionBindGroupLayout,
entries: [{
binding: 0,
resource: {
buffer: waveFunctionBuffer0
}
},
{
binding: 1,
resource: {
buffer: waveFunctionBuffer1
}
}
]});
The second bind group swaps which buffer is used for the current wave function, and which is used for the updated wave function. Now waveFunctionBuffer1 is bound to binding 0, and waveFunctionBuffer0 is bound to binding 1. Successive calls to the shader will swap back and forth between waveFunctionBindGroup[0] and waveFunctionBindGroup[1].
waveFunctionBindGroup[1] = device.createBindGroup({
layout: waveFunctionBindGroupLayout,
entries: [{
binding: 0,
resource: {
buffer: waveFunctionBuffer1
}
},
{
binding: 1,
resource: {
buffer: waveFunctionBuffer0
}
}
]});
The last resource we need is the compute pipeline. Using a compute pipeline specifies that we are doing computations rather than rendering. The pipeline specifies the shader and the types of resources we will employ. Because the pipeline only identifies the types of resources, we can swap the actual resources without reconstructing the pipeline.
computePipeline = device.createComputePipeline({
layout: device.createPipelineLayout({
bindGroupLayouts: [parametersBindGroupLayout, waveFunctionBindGroupLayout]
}),
compute: {
module: timeStepShaderModule,
entryPoint: "timeStep"
}
});
We have defined these resources on the GPU through the device. We now issue commands to the GPU to execute the shader using our buffers.
Start formulating the commands by computing a quick constant, the number of workgroups we dispatch. Since the workgroup size is 64, we choose a multiple of this, 512, for the size of the wave function buffers. We need one invocation for each wave function buffer element, so we need 512 / 64 = 8 workgroups.
const workgroupCountX = Math.ceil(xResolution / 64);
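The rounding up matters when the resolution is not a multiple of the workgroup size. A small sketch of the calculation follows; the names WORKGROUP_SIZE and workgroupCount are hypothetical.

```javascript
// Sketch: number of workgroups needed to cover every array element.
// WORKGROUP_SIZE must match the @workgroup_size in the shader.
const WORKGROUP_SIZE = 64;

function workgroupCount(elementCount, workgroupSize = WORKGROUP_SIZE) {
  // Round up so a partially filled final workgroup is still dispatched.
  return Math.ceil(elementCount / workgroupSize);
}
```

With 512 elements this gives 8 workgroups, and with 513 it gives 9. In the second case the last workgroup has 63 invocations with nothing to do, which is why the shader returns early when index >= parameters.xResolution.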
Wrap the shader dispatch in a loop because each dispatch advances the simulation by only a very small step in time. To produce a noticeable progression, we execute several steps.
for (let i=0; i<count && running; i++)
The command encoder captures commands to be sent to the GPU. A new one is created within the loop because it cannot be reused once finish() has been called. I would love to recycle the command encoder, but the spec disallows it.
// Created in the loop because it can not be reused after finish is invoked.
const commandEncoder = device.createCommandEncoder();
We are only doing computations, so create a compute pass built on our compute pipeline.
const passEncoder = commandEncoder.beginComputePass();
passEncoder.setPipeline(computePipeline);
Remember that the parameters contain general parameters for the simulation, and do not change over the course of the simulation.
passEncoder.setBindGroup(0, parametersBindGroup);
This is the ping pong we mentioned earlier. On every iteration through the loop we swap the bind group in use, and hence swap the current and updated wave function buffers.
passEncoder.setBindGroup(1, waveFunctionBindGroup[i%2]);
The workgroup count was computed just above the loop so that we have enough workgroups to cover the wave function buffers.
passEncoder.dispatchWorkgroups(workgroupCountX);
Now we finish the compute pass, and submit the generated commands to the GPU. As with the command encoder, I would like to reuse the GPUCommandBuffer, but this is forbidden by the spec.
passEncoder.end();
// Submit GPU commands.
device.queue.submit([commandEncoder.finish()]);
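One detail deserves care when the loop is run repeatedly. Because i restarts at 0 on every run, a run that executes an odd number of steps leaves the newest data in the buffer that the next run would overwrite. The sketch below gathers the loop into a function that carries a step counter across runs. The name runSteps and its parameters are hypothetical, the running flag from the loop above is omitted for brevity, and the WebGPU calls are the same ones used in the walkthrough.

```javascript
// Sketch: encode and submit `count` time steps, continuing from `firstStep`
// so the ping-pong order survives across calls. runSteps is a hypothetical name.
function runSteps(device, computePipeline, parametersBindGroup,
                  waveFunctionBindGroup, workgroupCountX, count, firstStep) {
  for (let i = 0; i < count; i++) {
    const step = firstStep + i;
    // A fresh encoder per step, since it cannot be reused after finish().
    const commandEncoder = device.createCommandEncoder();
    const passEncoder = commandEncoder.beginComputePass();
    passEncoder.setPipeline(computePipeline);
    passEncoder.setBindGroup(0, parametersBindGroup);
    // Alternate bind groups based on the global step, not the loop index.
    passEncoder.setBindGroup(1, waveFunctionBindGroup[step % 2]);
    passEncoder.dispatchWorkgroups(workgroupCountX);
    passEncoder.end();
    device.queue.submit([commandEncoder.finish()]);
  }
  return firstStep + count; // pass this back in as firstStep on the next call
}
```

A caller would keep the returned value between animation frames, for example stepCount = runSteps(device, computePipeline, parametersBindGroup, waveFunctionBindGroup, workgroupCountX, count, stepCount). The newest wave function is then in waveFunctionBuffer1 when stepCount is odd and in waveFunctionBuffer0 when it is even.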
