Skip to content

Instantly share code, notes, and snippets.

@schwa
Created September 18, 2025 19:05
Show Gist options
  • Select an option

  • Save schwa/ef6158e8813bc49a14a31ec930b89a0f to your computer and use it in GitHub Desktop.

Select an option

Save schwa/ef6158e8813bc49a14a31ec930b89a0f to your computer and use it in GitHub Desktop.
Metal Visible Functions:
// https://developer.apple.com/videos/play/wwdc2020/10013
// Using Metal Visible Functions:
// 1. Define a function pointer type alias in the kernel (e.g., `using OperationFunction = float(float, constant void *)`)
// 2. Kernel takes a `visible_function_table<FunctionType>` parameter
// 3. Mark functions as `[[visible]]` in Metal shader code
// 4. Use MTLLinkedFunctions to link visible functions when creating the compute pipeline
// 5. Create MTLVisibleFunctionTable and populate with function handles from the pipeline
// 6. Pass the visible function table to compute encoder with setVisibleFunctionTable()
// 7. Access functions in kernel via table index: `operationFunctions[index]`
import Metal
/// Demonstrates calling Metal visible functions through a `visible_function_table`.
///
/// Fills a shared buffer with 1...10, then dispatches the same compute kernel once per
/// entry in the visible function table. Each pass passes a different table index to the
/// kernel: index 0 selects `times2`, index 1 selects `timesN`, which multiplies by the
/// `Float` supplied in buffer slot 3. The buffer contents are printed before the first
/// pass and after each pass.
func main() {
    guard let device = MTLCreateSystemDefaultDevice() else {
        fatalError("No Metal device available.")
    }
    // Visible function tables require function-pointer support; fail with a clear
    // message up front instead of an obscure failure during pipeline/table creation.
    guard device.supportsFunctionPointers else {
        fatalError("\(device.name) does not support Metal function pointers / visible function tables.")
    }

    // MARK: Setup input buffer with data.
    let count = 10
    let bufferSize = count * MemoryLayout<Float>.stride
    guard let buffer = device.makeBuffer(length: bufferSize, options: .storageModeShared) else {
        fatalError("Failed to allocate a \(bufferSize)-byte buffer.")
    }
    let pointer = buffer.contents().bindMemory(to: Float.self, capacity: count)
    for i in 0..<count {
        pointer[i] = Float(i + 1)
    }
    /// Copies the current buffer contents into an array for printing.
    func currentValues() -> [Float] {
        Array(UnsafeBufferPointer(start: pointer, count: count))
    }
    print("Before: \(currentValues())")

    // MARK: Compile the compute kernel that uses a visible function table.
    let kernelSource = """
    #include <metal_stdlib>
    using namespace metal;
    using OperationFunction = float(float, constant void *);
    kernel void doubleValues(
    device float* buffer [[buffer(0)]],
    visible_function_table<OperationFunction> operationFunctions [[buffer(1)]],
    constant int &functionIndex [[buffer(2)]],
    constant void *functionParameter [[buffer(3)]],
    uint index [[thread_position_in_grid]]
    ) {
    auto operationFunction = operationFunctions[functionIndex];
    buffer[index] = operationFunction(buffer[index], functionParameter);
    }
    """
    let kernelLibrary = try! device.makeLibrary(source: kernelSource, options: nil)
    guard let kernelFunction = kernelLibrary.makeFunction(name: "doubleValues") else {
        fatalError("Kernel function doubleValues not found.")
    }

    // MARK: Compile the "helper" functions that will be called via the visible function table.
    let operationSource = """
    // Both functions need the same signature to match OperationFunction type, but we dont need to use parameter here
    [[visible]] float times2(float value, constant void *) {
    return value * 2.0;
    }
    // The kernel thinks it's a `void *` but we think it's a `float &` and through the power of no-type safety we can do what we want!
    [[visible]] float timesN(float value, constant float &parameter) {
    return value * parameter;
    }
    """
    let operationLibrary = try! device.makeLibrary(source: operationSource, options: nil)
    // The position of each function in this array is its index in the visible function
    // table, so the table size and its contents are derived from a single list.
    let operationFunctions: [MTLFunction] = ["times2", "timesN"].map { name in
        guard let function = operationLibrary.makeFunction(name: name) else {
            fatalError("Visible function \(name) not found.")
        }
        return function
    }
    let linkedFunctions = MTLLinkedFunctions()
    linkedFunctions.functions = operationFunctions

    // MARK: Create the pipeline with the kernel function and the linked functions.
    let computePipelineDescriptor = MTLComputePipelineDescriptor()
    computePipelineDescriptor.computeFunction = kernelFunction
    computePipelineDescriptor.linkedFunctions = linkedFunctions
    let computePipeline = try! device.makeComputePipelineState(descriptor: computePipelineDescriptor, options: [], reflection: nil)

    // MARK: Create and populate the visible function table. We'll be choosing at runtime which function to call.
    let visibleFunctionTableDescriptor = MTLVisibleFunctionTableDescriptor()
    visibleFunctionTableDescriptor.functionCount = operationFunctions.count
    guard let visibleFunctionTable = computePipeline.makeVisibleFunctionTable(descriptor: visibleFunctionTableDescriptor) else {
        fatalError("Failed to create visible function table.")
    }
    for (tableIndex, function) in operationFunctions.enumerated() {
        // A handle only exists for functions that were linked into this pipeline.
        guard let handle = computePipeline.functionHandle(function: function) else {
            fatalError("No function handle for \(function.name); is it in linkedFunctions?")
        }
        visibleFunctionTable.setFunction(handle, index: tableIndex)
    }

    // MARK: Start the compute passes.
    // The queue is reusable across passes, so create it once rather than per iteration.
    guard let commandQueue = device.makeCommandQueue() else {
        fatalError("Failed to create command queue.")
    }
    // Read by `timesN` through the kernel's untyped `constant void *` parameter (buffer 3).
    let multiplier: Float = 10.0
    let threadsPerGrid = MTLSize(width: count, height: 1, depth: 1)
    let threadsPerThreadgroup = MTLSize(width: computePipeline.maxTotalThreadsPerThreadgroup, height: 1, depth: 1)

    for index in operationFunctions.indices {
        guard let commandBuffer = commandQueue.makeCommandBuffer(),
              let computeEncoder = commandBuffer.makeComputeCommandEncoder() else {
            fatalError("Failed to create command buffer or compute encoder.")
        }
        // MARK: Set up the compute encoder. Give it our visible function table.
        computeEncoder.setComputePipelineState(computePipeline)
        computeEncoder.setBuffer(buffer, offset: 0, index: 0)
        computeEncoder.setVisibleFunctionTable(visibleFunctionTable, bufferIndex: 1)
        // Int32 matches the kernel's `constant int &functionIndex`.
        withUnsafeBytes(of: Int32(index)) { bytes in
            computeEncoder.setBytes(bytes.baseAddress!, length: bytes.count, index: 2)
        }
        withUnsafeBytes(of: multiplier) { bytes in
            computeEncoder.setBytes(bytes.baseAddress!, length: bytes.count, index: 3)
        }

        // MARK: Dispatch the compute work, end the pass and wait for it to finish.
        computeEncoder.dispatchThreads(threadsPerGrid, threadsPerThreadgroup: threadsPerThreadgroup)
        computeEncoder.endEncoding()
        commandBuffer.commit()
        commandBuffer.waitUntilCompleted()

        // MARK: Print the results.
        print("After operation #\(index): \(currentValues())")
    }
}
main()
@schwa
Copy link
Author

schwa commented Sep 18, 2025

This should print:

Before: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
After operation #0: [2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0]
After operation #1: [20.0, 40.0, 60.0, 80.0, 100.0, 120.0, 140.0, 160.0, 180.0, 200.0]
Program ended with exit code: 0

@schwa
Copy link
Author

schwa commented Sep 18, 2025

Alternatively, we could do something like:

ComputePipeline(…) {
     FunctionTable(…) { 
         ComputeDispatch(…)
         .parameter(functionTable…)
    }
}

@schwa
Copy link
Author

schwa commented Sep 18, 2025

It gets complex. We have to do some work when building the compute pipeline, and then store the result (a "compiled" visible function table). Then we need to bind it to a buffer index in the encoder.

Stitched function graphs may be able to avoid this complexity.

@schwa
Copy link
Author

schwa commented Sep 18, 2025

diff --git a/Sources/Ultraviolence/Metal/ComputePass.swift b/Sources/Ultraviolence/Metal/ComputePass.swift
index fa04f686a2..0f5e471293 100644
--- a/Sources/Ultraviolence/Metal/ComputePass.swift
+++ b/Sources/Ultraviolence/Metal/ComputePass.swift
@@ -41,7 +41,7 @@
     }
 
     func setupEnter(_ node: Node) throws {
-        let device = try node.environmentValues.device.orThrow(.missingEnvironment(\.device))
+        // Populate pipeline descriptor
         let descriptor = MTLComputePipelineDescriptor()
         if let label {
             descriptor.label = label
@@ -50,8 +50,26 @@
         if let linkedFunctions = node.environmentValues.linkedFunctions {
             descriptor.linkedFunctions = linkedFunctions
         }
+
+        let device = try node.environmentValues.device.orThrow(.missingEnvironment(\.device))
         let (computePipelineState, reflection) = try device.makeComputePipelineState(descriptor: descriptor, options: .bindingInfo)
         node.environmentValues.reflection = Reflection(try reflection.orThrow(.resourceCreationFailure("Failed to create reflection.")))
+
+        if let namedFunctionTable = node.environmentValues.namedFunctionTable {
+            let visibleFunctionTableDescriptor = MTLVisibleFunctionTableDescriptor()
+            visibleFunctionTableDescriptor.functionCount = namedFunctionTable.functions.count
+            let visibleFunctionTable = computePipelineState.makeVisibleFunctionTable(descriptor: visibleFunctionTableDescriptor)!
+            for (index, function) in namedFunctionTable.functions.enumerated() {
+                guard let functionHandle = computePipelineState.functionHandle(function: function) else {
+                    fatalError("Failed to get function handle for function \(function.name) in pipeline (is it nore in linkedFunctions?)")
+                }
+                visibleFunctionTable.setFunction(functionHandle, index:0)
+            }
+        }
+        // TODO: Now how do I get this into the encoder????
+
+
+        // Compile descriptor into pipeline state
         node.environmentValues.computePipelineState = computePipelineState
     }
 
@@ -62,3 +80,29 @@
         false
     }
 }
+
+
+// TODO: MOVE
+struct NamedFunctionTable {
+    var identifier: String
+    var functions: [MTLFunction]
+
+}
+
+// TODO: This limits us to a single function table per pipeline for now.
+
+extension UVEnvironmentValues {
+    @UVEntry
+    var namedFunctionTable: NamedFunctionTable?
+}
+
+public extension Element {
+    func namedFunctionTable(_ identifier: String, functions: [MTLFunction]) -> some Element {
+
+//        transformEnvironment(\.namedFunctionTables) { namedFunctionTables in
+//            namedFunctionTables[identifier] = NamedFunctionTable(identifier: identifier, functions: functions)
+//        }
+        environment(\.namedFunctionTable, NamedFunctionTable(identifier: identifier, functions: functions))
+    }
+}
+
diff --git a/Sources/Ultraviolence/Metal/ParameterValue.swift b/Sources/Ultraviolence/Metal/ParameterValue.swift
index c0ae45cb27..d64ac64183 100644
--- a/Sources/Ultraviolence/Metal/ParameterValue.swift
+++ b/Sources/Ultraviolence/Metal/ParameterValue.swift
@@ -6,6 +6,7 @@
     case buffer(MTLBuffer?, Int)
     case array([T])
     case value(T)
+    case visibleFunctionTable(MTLVisibleFunctionTable)
 }
 
 extension ParameterValue: CustomDebugStringConvertible {
@@ -21,6 +22,8 @@
             return "Array"
         case .value(let value):
             return "\(value)"
+        case .visibleFunctionTable:
+            return "VisibleFunctionTable()"
         }
     }
 }
@@ -37,18 +40,16 @@
         switch value {
         case .texture(let texture):
             setTexture(texture, index: index, functionType: functionType)
-
         case .samplerState(let samplerState):
             setSamplerState(samplerState, index: index, functionType: functionType)
-
         case .buffer(let buffer, let offset):
             setBuffer(buffer, offset: offset, index: index, functionType: functionType)
-
         case .array(let array):
             setUnsafeBytes(of: array, index: index, functionType: functionType)
-
         case .value(let value):
             setUnsafeBytes(of: value, index: index, functionType: functionType)
+        case .visibleFunctionTable(let table):
+            setVisibleFunctionTable(table, bufferIndex: index, functionType: functionType)
         }
     }
 }
@@ -58,18 +59,16 @@
         switch value {
         case .texture(let texture):
             setTexture(texture, index: index)
-
         case .samplerState(let samplerState):
             setSamplerState(samplerState, index: index)
-
         case .buffer(let buffer, let offset):
             setBuffer(buffer, offset: offset, index: index)
-
         case .array(let array):
             setUnsafeBytes(of: array, index: index)
-
         case .value(let value):
             setUnsafeBytes(of: value, index: index)
+        case .visibleFunctionTable(let table):
+            setVisibleFunctionTable(table, bufferIndex: index)
         }
     }
 }
diff --git a/Sources/Ultraviolence/Metal/Parameters.swift b/Sources/Ultraviolence/Metal/Parameters.swift
index 0a77ff4698..064914b8ef 100644
--- a/Sources/Ultraviolence/Metal/Parameters.swift
+++ b/Sources/Ultraviolence/Metal/Parameters.swift
@@ -123,6 +123,10 @@
         assert(isPOD(value), "Parameter value must be a POD type.")
         return ParameterElementModifier(functionType: functionType, name: name, value: .value(value), content: self)
     }
+
+    func parameter(_ name: String, functionType: MTLFunctionType? = nil, visibleFunctionTable: MTLVisibleFunctionTable) -> some Element {
+        return ParameterElementModifier(functionType: functionType, name: name, value: ParameterValue<()>.visibleFunctionTable(visibleFunctionTable), content: self)
+    }
 }
 
 extension String {
diff --git a/Sources/UltraviolenceExampleShaders/ColorAdjust.metal b/Sources/UltraviolenceExampleShaders/ColorAdjust.metal
index c762efafbb..3e2db801f3 100644
--- a/Sources/UltraviolenceExampleShaders/ColorAdjust.metal
+++ b/Sources/UltraviolenceExampleShaders/ColorAdjust.metal
@@ -3,8 +3,7 @@
 
 using namespace metal;
 
-[[ visible ]]
-float4 adjustColor(float4 inputColor, constant void *parameters);
+using AdjustColorFunction = float4(float4 inputColor, constant void *parameters);
 
 namespace ColorAdjust {
 
@@ -48,16 +47,14 @@
     kernel void colorAdjust(
         constant Texture2DSpecifierArgumentBuffer &inputSpecifier [[buffer(0)]],
         constant void *inputParameters [[buffer(1)]],
+        visible_function_table<AdjustColorFunction> adjustColorFunction [[buffer(2)]],
+
         texture2d<float, access::read_write> outputTexture [[texture(0)]]
     ) {
         bool discard = false;
         const float2 textureCoordinate = textureCoordinateForPixel(inputSpecifier, thread_position_in_grid);
         const float4 inputColor = resolveSpecifiedColor(inputSpecifier, textureCoordinate, discard);
-        // TODO: Make this a function pointer
-//        float4 newColor = pow(inputColor, 50.0);;
-
-        float4 newColor = adjustColor(inputColor, inputParameters);
-
+        float4 newColor = adjustColorFunction[0](inputColor, inputParameters);
         outputTexture.write(newColor, thread_position_in_grid);
     }
 }
diff --git a/Sources/UltraviolenceExamples/Demos/ColorAdjustDemo/ColorAdjustComputePipeline.swift b/Sources/UltraviolenceExamples/Demos/ColorAdjustDemo/ColorAdjustComputePipeline.swift
index 525506ae72..a377501417 100644
--- a/Sources/UltraviolenceExamples/Demos/ColorAdjustDemo/ColorAdjustComputePipeline.swift
+++ b/Sources/UltraviolenceExamples/Demos/ColorAdjustDemo/ColorAdjustComputePipeline.swift
@@ -6,30 +6,47 @@
     let inputSpecifier: Texture2DSpecifier
     let inputParameters: T
     let outputTexture: MTLTexture
-    var kernel: ComputeKernel
+    let kernel: ComputeKernel
+    let linkedFunctions: MTLLinkedFunctions
 
-    init(inputSpecifier: Texture2DSpecifier, inputParameters: T, outputTexture: MTLTexture) {
+    init(inputSpecifier: Texture2DSpecifier, inputParameters: T, outputTexture: MTLTexture, adjustColorFunction: MTLFunction) {
         self.inputSpecifier = inputSpecifier
         self.inputParameters = inputParameters
         self.outputTexture = outputTexture
         let shaderLibrary = try! ShaderLibrary(bundle: .ultraviolenceExampleShaders().orFatalError(), namespace: "ColorAdjust")
         self.kernel = try! shaderLibrary.colorAdjust
+        linkedFunctions = MTLLinkedFunctions()
+        linkedFunctions.privateFunctions = [adjustColorFunction]
+
+        let visibleFunctionTableDescriptor = MTLVisibleFunctionTableDescriptor()
+        visibleFunctionTableDescriptor.functionCount = 1
+
+        // Find and set functions by handle
+//        let functionHandle = pipeline.functionHandle(function: spot)!
+//        lightingFunctionTable.setFunction(functionHandle, index:0)
     }
 
     public var body: some Element {
         get throws {
-            try ComputePipeline(
-                computeKernel: kernel
-            ) {
+
+            try ComputePipeline(computeKernel: kernel) {
                 try ComputeDispatch(threadsPerGrid: [outputTexture.width, outputTexture.height, 1], threadsPerThreadgroup: [16, 16, 1])
                     // TODO: #280 Maybe a .argumentBuffer() is a better solution
                     .parameter("inputSpecifier", value: inputSpecifier.toTexture2DSpecifierArgmentBuffer())
                     .useComputeResource(inputSpecifier.texture2D, usage: .read)
                     .useComputeResource(inputSpecifier.textureCube, usage: .read)
                     .useComputeResource(inputSpecifier.depth2D, usage: .read)
+//                computeCommandEncoder.setVisibleFunctionTable(lightingFunctionTable, bufferIndex:1)
+
+//                    .parameter("xxx", visibleFunctionTable: visibleFunctionTable)
+
                     .parameter("inputParameters", value: inputParameters)
                     .parameter("outputTexture", texture: outputTexture)
+
+
             }
+            .environment(\.linkedFunctions, linkedFunctions)
         }
     }
 }
+
diff --git a/Sources/UltraviolenceExamples/Demos/ColorAdjustDemo/ColorAdjustDemoView.swift b/Sources/UltraviolenceExamples/Demos/ColorAdjustDemo/ColorAdjustDemoView.swift
index 07a329bb63..149844bcfe 100644
--- a/Sources/UltraviolenceExamples/Demos/ColorAdjustDemo/ColorAdjustDemoView.swift
+++ b/Sources/UltraviolenceExamples/Demos/ColorAdjustDemo/ColorAdjustDemoView.swift
@@ -19,7 +19,7 @@
     }
     """
 
-    let linkedFunctions: MTLLinkedFunctions
+    let adjustColorFunction: MTLFunction
 
     public init() {
         let device = _MTLCreateSystemDefaultDevice()
@@ -39,19 +39,14 @@
         // TODO: #278 Use Ultraviolence's normal shader loading capabilities
         // TODO: #279 Use proper Metal function loading - this one requires all functions to be named the same.
         let sourceLibrary = try! device.makeLibrary(source: adjustSource, options: nil)
-        let adjustColorFunction = sourceLibrary.makeFunction(name: "adjustColor")!
-        let linkedFunctions = MTLLinkedFunctions()
-        linkedFunctions.privateFunctions = [adjustColorFunction]
-
-        self.linkedFunctions = linkedFunctions
+        adjustColorFunction = sourceLibrary.makeFunction(name: "adjustColor")!
     }
 
     public var body: some View {
         RenderView { _, _ in
             try ComputePass(label: "ColorAdjust") {
-                ColorAdjustComputePipeline(inputSpecifier: .texture2D(sourceTexture, nil), inputParameters: Float(0.5), outputTexture: adjustedTexture)
+                ColorAdjustComputePipeline(inputSpecifier: .texture2D(sourceTexture, nil), inputParameters: Float(0.5), outputTexture: adjustedTexture, adjustColorFunction: adjustColorFunction)
             }
-            .environment(\.linkedFunctions, linkedFunctions)
 
             try RenderPass {
                 try BillboardRenderPipeline(specifier: .texture2D(adjustedTexture))
diff --git a/Sources/UltraviolenceExamples/Demos/DepthDemo/DepthDemoView.swift b/Sources/UltraviolenceExamples/Demos/DepthDemo/DepthDemoView.swift
index b75ca7e647..cfb5f18fef 100644
--- a/Sources/UltraviolenceExamples/Demos/DepthDemo/DepthDemoView.swift
+++ b/Sources/UltraviolenceExamples/Demos/DepthDemo/DepthDemoView.swift
@@ -33,6 +33,8 @@
 
     let teapot = MTKMesh.teapot()
 
+    let adjustColorFunction: MTLFunction
+
     let adjustSource = """
     #include <metal_stdlib>
     using namespace metal;
@@ -43,8 +45,6 @@
     }
     """
 
-    let linkedFunctions: MTLLinkedFunctions
-
     public init() {
         let device = _MTLCreateSystemDefaultDevice()
 
@@ -64,12 +64,8 @@
         stitchedLibraryDescriptor.functions = [sourceLibrary.makeFunction(name: "node")!]
         stitchedLibraryDescriptor.functionGraphs = [graph]
         let stitchedLibrary = try! device.makeLibrary(stitchedDescriptor: stitchedLibraryDescriptor)
-        let stitchedFunction = stitchedLibrary.makeFunction(name: "adjustColor")!
-
-        let linkedFunctions = MTLLinkedFunctions()
-        linkedFunctions.privateFunctions = [stitchedFunction]
-
-        self.linkedFunctions = linkedFunctions
+        adjustColorFunction = stitchedLibrary.makeFunction(name: "adjustColor")!
+
     }
 
     public var body: some View {
@@ -91,9 +87,8 @@
                     }
 
                     try ComputePass(label: "ColorAdjust") {
-                        ColorAdjustComputePipeline(inputSpecifier: .depth2D(depthTexture, nil), inputParameters: exponent, outputTexture: adjustedDepthTexture)
+                        ColorAdjustComputePipeline(inputSpecifier: .depth2D(depthTexture, nil), inputParameters: exponent, outputTexture: adjustedDepthTexture, adjustColorFunction: adjustColorFunction)
                     }
-                    .environment(\.linkedFunctions, linkedFunctions)
 
                     try RenderPass(label: "Depth to Screen Pass") {
                         try BillboardRenderPipeline(specifier: showDepthMap ? .texture2D(adjustedDepthTexture, nil) : .texture2D(colorTexture, nil), flippedY: true)
diff --git a/Sources/UltraviolenceSupport/MetalSupport.swift b/Sources/UltraviolenceSupport/MetalSupport.swift
index c3655441f0..c5c99172d8 100644
--- a/Sources/UltraviolenceSupport/MetalSupport.swift
+++ b/Sources/UltraviolenceSupport/MetalSupport.swift
@@ -607,14 +607,24 @@
         switch functionType {
         case .vertex:
             setVertexSamplerState(sampler, index: index)
-
         case .fragment:
             setFragmentSamplerState(sampler, index: index)
-
-        default:
-            fatalError("Unimplemented")
-        }
-    }
+        default:
+            fatalError("Unimplemented")
+        }
+    }
+
+    func setVisibleFunctionTable(_ functionTable: MTLVisibleFunctionTable?, bufferIndex: Int, functionType: MTLFunctionType) {
+        switch functionType {
+        case .vertex:
+            setVertexVisibleFunctionTable(functionTable, bufferIndex: bufferIndex)
+        case .fragment:
+            setFragmentVisibleFunctionTable(functionTable, bufferIndex: bufferIndex)
+        default:
+            fatalError("Unimplemented")
+        }
+    }
+
 }
 
 public extension MTLComputeCommandEncoder {

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment