Created
September 18, 2025 19:05
-
-
Save schwa/ef6158e8813bc49a14a31ec930b89a0f to your computer and use it in GitHub Desktop.
Metal Visible Functions:
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// https://developer.apple.com/videos/play/wwdc2020/10013
// Using Metal Visible Functions:
// 1. Define a function pointer type alias in the kernel (e.g., `using OperationFunction = float(float, constant void *)`)
// 2. Kernel takes a `visible_function_table<FunctionType>` parameter
// 3. Mark functions as `[[visible]]` in Metal shader code
// 4. Use MTLLinkedFunctions to link visible functions when creating the compute pipeline
// 5. Create MTLVisibleFunctionTable and populate with function handles from the pipeline
// 6. Pass the visible function table to compute encoder with setVisibleFunctionTable()
// 7. Access functions in kernel via table index: `operationFunctions[index]`

import Metal

/// Reasons the visible-function demo can fail while it is being set up.
enum VisibleFunctionDemoError: Error {
    /// A Metal call returned `nil` or a capability was missing; the payload names what was requested.
    case unavailable(String)
}

/// Returns `value` unwrapped, or throws `VisibleFunctionDemoError.unavailable` naming what was missing.
///
/// - Parameters:
///   - value: The optional produced by a Metal factory call.
///   - what: Human-readable name of the requested object, used in the thrown error.
/// - Throws: `VisibleFunctionDemoError.unavailable` when `value` is `nil`.
private func unwrap<T>(_ value: T?, orThrow what: String) throws -> T {
    guard let value = value else {
        throw VisibleFunctionDemoError.unavailable(what)
    }
    return value
}

/// Builds a compute pipeline whose kernel calls through a visible function table, then runs the
/// kernel once per table entry over a ten-element buffer, printing the buffer after each pass.
///
/// - Throws: Shader compilation and pipeline creation errors from Metal, the command buffer's
///   error if a pass fails, or `VisibleFunctionDemoError` when a Metal object cannot be created.
private func runVisibleFunctionDemo() throws {
    let device = try unwrap(MTLCreateSystemDefaultDevice(), orThrow: "a default Metal device")
    // Fail early with a clear message instead of deep inside pipeline creation.
    guard device.supportsFunctionPointers else {
        throw VisibleFunctionDemoError.unavailable("function pointer support on \(device.name)")
    }

    // MARK: Setup input buffer with data.
    let count = 10
    let buffer = try unwrap(
        device.makeBuffer(length: count * MemoryLayout<Float>.stride, options: .storageModeShared),
        orThrow: "the shared data buffer"
    )
    let values = UnsafeMutableBufferPointer(
        start: buffer.contents().bindMemory(to: Float.self, capacity: count),
        count: count
    )
    for i in values.indices {
        values[i] = Float(i + 1)
    }
    print("Before: \(Array(values))")

    // MARK: Compile the compute kernel that uses a visible function table.
    let kernelSource = """
        #include <metal_stdlib>
        using namespace metal;

        using OperationFunction = float(float, constant void *);

        kernel void doubleValues(
            device float* buffer [[buffer(0)]],
            visible_function_table<OperationFunction> operationFunctions [[buffer(1)]],
            constant int &functionIndex [[buffer(2)]],
            constant void *functionParameter [[buffer(3)]],
            uint index [[thread_position_in_grid]]
        ) {
            auto operationFunction = operationFunctions[functionIndex];
            buffer[index] = operationFunction(buffer[index], functionParameter);
        }
        """
    let kernelLibrary = try device.makeLibrary(source: kernelSource, options: nil)
    let kernelFunction = try unwrap(
        kernelLibrary.makeFunction(name: "doubleValues"),
        orThrow: "kernel function 'doubleValues'"
    )

    // MARK: Compile the "helper" functions that will be called via the visible function table.
    // `timesN` declares its second argument as `constant float &parameter`; the scraped source had
    // `&para` decoded as an HTML entity, leaving an identifier the Metal compiler cannot parse.
    let operationSource = """
        // Both functions need the same signature to match OperationFunction type, but we dont need to use parameter here
        [[visible]] float times2(float value, constant void *) {
            return value * 2.0;
        }

        // The kernel thinks it's a `void *` but we think it's a `float &` and through the power of no-type safety we can do what we want!
        [[visible]] float timesN(float value, constant float &parameter) {
            return value * parameter;
        }
        """
    let operationLibrary = try device.makeLibrary(source: operationSource, options: nil)
    // Array order defines the table slots: slot 0 is `times2`, slot 1 is `timesN`.
    let operations = [
        try unwrap(operationLibrary.makeFunction(name: "times2"), orThrow: "visible function 'times2'"),
        try unwrap(operationLibrary.makeFunction(name: "timesN"), orThrow: "visible function 'timesN'"),
    ]
    let linkedFunctions = MTLLinkedFunctions()
    linkedFunctions.functions = operations

    // MARK: Create the pipeline with the kernel function and the linked functions.
    let pipelineDescriptor = MTLComputePipelineDescriptor()
    pipelineDescriptor.computeFunction = kernelFunction
    pipelineDescriptor.linkedFunctions = linkedFunctions
    let pipeline = try device.makeComputePipelineState(
        descriptor: pipelineDescriptor,
        options: [],
        reflection: nil
    )

    // MARK: Create and populate the visible function table. We'll be choosing at runtime which function to call.
    let tableDescriptor = MTLVisibleFunctionTableDescriptor()
    tableDescriptor.functionCount = operations.count
    let functionTable = try unwrap(
        pipeline.makeVisibleFunctionTable(descriptor: tableDescriptor),
        orThrow: "the visible function table"
    )
    for (slot, function) in operations.enumerated() {
        let handle = try unwrap(
            pipeline.functionHandle(function: function),
            orThrow: "a function handle for '\(function.name)'"
        )
        functionTable.setFunction(handle, index: slot)
    }

    // The queue is loop-invariant, so create it once rather than once per pass.
    let commandQueue = try unwrap(device.makeCommandQueue(), orThrow: "a command queue")
    // Argument handed to the selected function; only `timesN` reads it.
    var multiplier: Float = 10.0

    // MARK: Start the compute pass.
    for index in operations.indices {
        let commandBuffer = try unwrap(commandQueue.makeCommandBuffer(), orThrow: "a command buffer")

        // MARK: Set up the compute encoder. Give it our visible function table.
        let encoder = try unwrap(commandBuffer.makeComputeCommandEncoder(), orThrow: "a compute command encoder")
        encoder.setComputePipelineState(pipeline)
        encoder.setBuffer(buffer, offset: 0, index: 0)
        encoder.setVisibleFunctionTable(functionTable, bufferIndex: 1)
        // The kernel declares `constant int &functionIndex`, so upload a signed 32-bit value.
        var functionIndex = Int32(index)
        encoder.setBytes(&functionIndex, length: MemoryLayout<Int32>.size, index: 2)
        encoder.setBytes(&multiplier, length: MemoryLayout<Float>.size, index: 3)

        // MARK: Dispatch the compute work.
        let threadsPerGrid = MTLSize(width: count, height: 1, depth: 1)
        let threadsPerThreadgroup = MTLSize(width: pipeline.maxTotalThreadsPerThreadgroup, height: 1, depth: 1)
        encoder.dispatchThreads(threadsPerGrid, threadsPerThreadgroup: threadsPerThreadgroup)

        // MARK: End the compute pass and wait for it to finish.
        encoder.endEncoding()
        commandBuffer.commit()
        commandBuffer.waitUntilCompleted()
        // Surface GPU-side failures instead of printing stale buffer contents.
        if let error = commandBuffer.error {
            throw error
        }

        // MARK: Print the results.
        print("After operation #\(index): \(Array(values))")
    }
}

/// Entry point: runs the visible-function demo and aborts with a descriptive message on failure.
func main() {
    do {
        try runVisibleFunctionDemo()
    } catch {
        fatalError("Metal visible function demo failed: \(error)")
    }
}

main()
Author
schwa
commented
Sep 18, 2025
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment