Created
September 18, 2025 19:05
-
-
Save schwa/ef6158e8813bc49a14a31ec930b89a0f to your computer and use it in GitHub Desktop.
Metal Visible Functions:
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// https://developer.apple.com/videos/play/wwdc2020/10013
// Using Metal Visible Functions:
// 1. Define a function type alias in the shader (e.g., `using OperationFunction = float(float, constant void *)`).
// 2. The kernel takes a `visible_function_table<OperationFunction>` parameter.
// 3. Mark the callable functions as `[[visible]]` in Metal shader code.
// 4. Use MTLLinkedFunctions to link the visible functions when creating the compute pipeline.
// 5. Create an MTLVisibleFunctionTable and populate it with function handles obtained from the pipeline.
// 6. Pass the visible function table to the compute encoder with setVisibleFunctionTable().
// 7. Call functions in the kernel via the table index: `operationFunctions[index]`.
import Metal

/// Demonstrates calling Metal visible functions through a visible function table.
///
/// Fills a shared buffer with 1...10, then runs one compute pass per table entry.
/// Each pass rewrites the buffer in place, so pass #1 operates on the output of pass #0.
/// Crashes (via `!`/`try!`/`precondition`) if any Metal object cannot be created,
/// a shader fails to compile, or a command buffer does not complete.
func main() {
    let device = MTLCreateSystemDefaultDevice()!

    // MARK: Setup input buffer with data.
    let count = 10
    let bufferSize = count * MemoryLayout<Float>.stride
    let buffer = device.makeBuffer(length: bufferSize, options: .storageModeShared)!
    let pointer = buffer.contents().bindMemory(to: Float.self, capacity: count)
    for i in 0..<count {
        pointer[i] = Float(i + 1)
    }
    print("Before: \((0..<count).map { pointer[$0] })")

    // MARK: Compile the compute kernel that uses a visible function table.
    let kernelSource = """
    #include <metal_stdlib>
    using namespace metal;
    using OperationFunction = float(float, constant void *);
    kernel void doubleValues(
    device float* buffer [[buffer(0)]],
    visible_function_table<OperationFunction> operationFunctions [[buffer(1)]],
    constant int &functionIndex [[buffer(2)]],
    constant void *functionParameter [[buffer(3)]],
    uint index [[thread_position_in_grid]]
    ) {
    auto operationFunction = operationFunctions[functionIndex];
    buffer[index] = operationFunction(buffer[index], functionParameter);
    }
    """
    let kernelLibrary = try! device.makeLibrary(source: kernelSource, options: nil)
    let kernelFunction = kernelLibrary.makeFunction(name: "doubleValues")!

    // MARK: Compile the "helper" functions that will be called via the visible function table.
    // `timesN` must declare `&parameter`; the scraped `¶meter` (HTML-entity mojibake) is not a
    // valid MSL identifier and leaves `parameter` undeclared, so the library would fail to compile.
    let operationSource = """
    // Both functions need the same signature to match OperationFunction type, but we dont need to use parameter here
    [[visible]] float times2(float value, constant void *) {
    return value * 2.0;
    }
    // The kernel thinks it's a `void *` but we think it's a `float &` and through the power of no-type safety we can do what we want!
    [[visible]] float timesN(float value, constant float &parameter) {
    return value * parameter;
    }
    """
    let operationLibrary = try! device.makeLibrary(source: operationSource, options: nil)

    // Single source of truth: array order is the table slot the kernel selects via `functionIndex`.
    let operationFunctions = [
        operationLibrary.makeFunction(name: "times2")!,
        operationLibrary.makeFunction(name: "timesN")!,
    ]
    let linkedFunctions = MTLLinkedFunctions()
    linkedFunctions.functions = operationFunctions

    // MARK: Create the pipeline with the kernel function and the linked functions.
    let computePipelineDescriptor = MTLComputePipelineDescriptor()
    computePipelineDescriptor.computeFunction = kernelFunction
    computePipelineDescriptor.linkedFunctions = linkedFunctions
    let computePipeline = try! device.makeComputePipelineState(descriptor: computePipelineDescriptor, options: [], reflection: nil)

    // MARK: Create and populate the visible function table. We'll be choosing at runtime which function to call.
    let visibleFunctionTableDescriptor = MTLVisibleFunctionTableDescriptor()
    visibleFunctionTableDescriptor.functionCount = operationFunctions.count
    let visibleFunctionTable = computePipeline.makeVisibleFunctionTable(descriptor: visibleFunctionTableDescriptor)!
    for (slot, function) in operationFunctions.enumerated() {
        // A handle exists only for functions linked into this pipeline.
        visibleFunctionTable.setFunction(computePipeline.functionHandle(function: function)!, index: slot)
    }

    // Command queues are long-lived; create one and reuse it for every pass.
    let commandQueue = device.makeCommandQueue()!
    // Read by `timesN` as `constant float &parameter`; ignored by `times2`.
    var functionParameter: Float = 10.0
    let threadsPerGrid = MTLSize(width: count, height: 1, depth: 1)
    let threadsPerThreadgroup = MTLSize(width: min(count, computePipeline.maxTotalThreadsPerThreadgroup), height: 1, depth: 1)

    // MARK: Run one compute pass per table entry.
    for index in operationFunctions.indices {
        let commandBuffer = commandQueue.makeCommandBuffer()!

        // MARK: Set up the compute encoder. Give it our visible function table.
        let computeEncoder = commandBuffer.makeComputeCommandEncoder()!
        computeEncoder.setComputePipelineState(computePipeline)
        computeEncoder.setBuffer(buffer, offset: 0, index: 0)
        computeEncoder.setVisibleFunctionTable(visibleFunctionTable, bufferIndex: 1)
        // Int32 matches the kernel's `constant int &functionIndex`.
        var functionIndex = Int32(index)
        computeEncoder.setBytes(&functionIndex, length: MemoryLayout<Int32>.stride, index: 2)
        computeEncoder.setBytes(&functionParameter, length: MemoryLayout<Float>.stride, index: 3)

        // MARK: Dispatch the compute work.
        computeEncoder.dispatchThreads(threadsPerGrid, threadsPerThreadgroup: threadsPerThreadgroup)

        // MARK: End the compute pass and wait for it to finish.
        computeEncoder.endEncoding()
        commandBuffer.commit()
        commandBuffer.waitUntilCompleted()
        // Don't print stale data as if the pass had succeeded.
        precondition(commandBuffer.status == .completed, "Command buffer failed: \(String(describing: commandBuffer.error))")

        // MARK: Print the results.
        print("After operation #\(index): \((0..<count).map { pointer[$0] })")
    }
}

main()
Author
Author
Alternatively, we could do something like:
ComputePipeline(…) {
FunctionTable(…) {
ComputeDispatch(…)
.parameter(functionTable…)
}
}
Author
It gets complex: we have to do some work when building the compute pipeline, then store the result (a "compiled" visible function table), and then bind that table to a buffer index in the encoder.
Stitched function graphs may be able to avoid this complexity.
Author
diff --git a/Sources/Ultraviolence/Metal/ComputePass.swift b/Sources/Ultraviolence/Metal/ComputePass.swift
index fa04f686a2..0f5e471293 100644
--- a/Sources/Ultraviolence/Metal/ComputePass.swift
+++ b/Sources/Ultraviolence/Metal/ComputePass.swift
@@ -41,7 +41,7 @@
}
func setupEnter(_ node: Node) throws {
- let device = try node.environmentValues.device.orThrow(.missingEnvironment(\.device))
+ // Populate pipeline descriptor
let descriptor = MTLComputePipelineDescriptor()
if let label {
descriptor.label = label
@@ -50,8 +50,26 @@
if let linkedFunctions = node.environmentValues.linkedFunctions {
descriptor.linkedFunctions = linkedFunctions
}
+
+ let device = try node.environmentValues.device.orThrow(.missingEnvironment(\.device))
let (computePipelineState, reflection) = try device.makeComputePipelineState(descriptor: descriptor, options: .bindingInfo)
node.environmentValues.reflection = Reflection(try reflection.orThrow(.resourceCreationFailure("Failed to create reflection.")))
+
+ if let namedFunctionTable = node.environmentValues.namedFunctionTable { // build one table slot per named function
+ let visibleFunctionTableDescriptor = MTLVisibleFunctionTableDescriptor()
+ visibleFunctionTableDescriptor.functionCount = namedFunctionTable.functions.count
+ let visibleFunctionTable = computePipelineState.makeVisibleFunctionTable(descriptor: visibleFunctionTableDescriptor)!
+ for (index, function) in namedFunctionTable.functions.enumerated() {
+ guard let functionHandle = computePipelineState.functionHandle(function: function) else {
+ fatalError("Failed to get function handle for function \(function.name) in pipeline (is it not in linkedFunctions?)")
+ }
+ visibleFunctionTable.setFunction(functionHandle, index: index) // slot `index` holds functions[index]; a constant 0 would overwrite one slot
+ }
+ }
+ // TODO: Now how do I get this into the encoder????
+
+
+ // Compile descriptor into pipeline state
node.environmentValues.computePipelineState = computePipelineState
}
@@ -62,3 +80,29 @@
false
}
}
+
+
+// TODO: MOVE
+struct NamedFunctionTable {
+ var identifier: String
+ var functions: [MTLFunction]
+
+}
+
+// TODO: This limits us to a single function table per pipeline for now.
+
+extension UVEnvironmentValues {
+ @UVEntry
+ var namedFunctionTable: NamedFunctionTable?
+}
+
+public extension Element {
+ func namedFunctionTable(_ identifier: String, functions: [MTLFunction]) -> some Element {
+
+// transformEnvironment(\.namedFunctionTables) { namedFunctionTables in
+// namedFunctionTables[identifier] = NamedFunctionTable(identifier: identifier, functions: functions)
+// }
+ environment(\.namedFunctionTable, NamedFunctionTable(identifier: identifier, functions: functions))
+ }
+}
+
diff --git a/Sources/Ultraviolence/Metal/ParameterValue.swift b/Sources/Ultraviolence/Metal/ParameterValue.swift
index c0ae45cb27..d64ac64183 100644
--- a/Sources/Ultraviolence/Metal/ParameterValue.swift
+++ b/Sources/Ultraviolence/Metal/ParameterValue.swift
@@ -6,6 +6,7 @@
case buffer(MTLBuffer?, Int)
case array([T])
case value(T)
+ case visibleFunctionTable(MTLVisibleFunctionTable)
}
extension ParameterValue: CustomDebugStringConvertible {
@@ -21,6 +22,8 @@
return "Array"
case .value(let value):
return "\(value)"
+ case .visibleFunctionTable:
+ return "VisibleFunctionTable()"
}
}
}
@@ -37,18 +40,16 @@
switch value {
case .texture(let texture):
setTexture(texture, index: index, functionType: functionType)
-
case .samplerState(let samplerState):
setSamplerState(samplerState, index: index, functionType: functionType)
-
case .buffer(let buffer, let offset):
setBuffer(buffer, offset: offset, index: index, functionType: functionType)
-
case .array(let array):
setUnsafeBytes(of: array, index: index, functionType: functionType)
-
case .value(let value):
setUnsafeBytes(of: value, index: index, functionType: functionType)
+ case .visibleFunctionTable(let table):
+ setVisibleFunctionTable(table, bufferIndex: index, functionType: functionType)
}
}
}
@@ -58,18 +59,16 @@
switch value {
case .texture(let texture):
setTexture(texture, index: index)
-
case .samplerState(let samplerState):
setSamplerState(samplerState, index: index)
-
case .buffer(let buffer, let offset):
setBuffer(buffer, offset: offset, index: index)
-
case .array(let array):
setUnsafeBytes(of: array, index: index)
-
case .value(let value):
setUnsafeBytes(of: value, index: index)
+ case .visibleFunctionTable(let table):
+ setVisibleFunctionTable(table, bufferIndex: index)
}
}
}
diff --git a/Sources/Ultraviolence/Metal/Parameters.swift b/Sources/Ultraviolence/Metal/Parameters.swift
index 0a77ff4698..064914b8ef 100644
--- a/Sources/Ultraviolence/Metal/Parameters.swift
+++ b/Sources/Ultraviolence/Metal/Parameters.swift
@@ -123,6 +123,10 @@
assert(isPOD(value), "Parameter value must be a POD type.")
return ParameterElementModifier(functionType: functionType, name: name, value: .value(value), content: self)
}
+
+ func parameter(_ name: String, functionType: MTLFunctionType? = nil, visibleFunctionTable: MTLVisibleFunctionTable) -> some Element {
+ return ParameterElementModifier(functionType: functionType, name: name, value: ParameterValue<()>.visibleFunctionTable(visibleFunctionTable), content: self)
+ }
}
extension String {
diff --git a/Sources/UltraviolenceExampleShaders/ColorAdjust.metal b/Sources/UltraviolenceExampleShaders/ColorAdjust.metal
index c762efafbb..3e2db801f3 100644
--- a/Sources/UltraviolenceExampleShaders/ColorAdjust.metal
+++ b/Sources/UltraviolenceExampleShaders/ColorAdjust.metal
@@ -3,8 +3,7 @@
using namespace metal;
-[[ visible ]]
-float4 adjustColor(float4 inputColor, constant void *parameters);
+using AdjustColorFunction = float4(float4 inputColor, constant void *parameters);
namespace ColorAdjust {
@@ -48,16 +47,14 @@
kernel void colorAdjust(
constant Texture2DSpecifierArgumentBuffer &inputSpecifier [[buffer(0)]],
constant void *inputParameters [[buffer(1)]],
+ visible_function_table<AdjustColorFunction> adjustColorFunction [[buffer(2)]],
+
texture2d<float, access::read_write> outputTexture [[texture(0)]]
) {
bool discard = false;
const float2 textureCoordinate = textureCoordinateForPixel(inputSpecifier, thread_position_in_grid);
const float4 inputColor = resolveSpecifiedColor(inputSpecifier, textureCoordinate, discard);
- // TODO: Make this a function pointer
-// float4 newColor = pow(inputColor, 50.0);;
-
- float4 newColor = adjustColor(inputColor, inputParameters);
-
+ float4 newColor = adjustColorFunction[0](inputColor, inputParameters);
outputTexture.write(newColor, thread_position_in_grid);
}
}
diff --git a/Sources/UltraviolenceExamples/Demos/ColorAdjustDemo/ColorAdjustComputePipeline.swift b/Sources/UltraviolenceExamples/Demos/ColorAdjustDemo/ColorAdjustComputePipeline.swift
index 525506ae72..a377501417 100644
--- a/Sources/UltraviolenceExamples/Demos/ColorAdjustDemo/ColorAdjustComputePipeline.swift
+++ b/Sources/UltraviolenceExamples/Demos/ColorAdjustDemo/ColorAdjustComputePipeline.swift
@@ -6,30 +6,47 @@
let inputSpecifier: Texture2DSpecifier
let inputParameters: T
let outputTexture: MTLTexture
- var kernel: ComputeKernel
+ let kernel: ComputeKernel
+ let linkedFunctions: MTLLinkedFunctions
- init(inputSpecifier: Texture2DSpecifier, inputParameters: T, outputTexture: MTLTexture) {
+ init(inputSpecifier: Texture2DSpecifier, inputParameters: T, outputTexture: MTLTexture, adjustColorFunction: MTLFunction) {
self.inputSpecifier = inputSpecifier
self.inputParameters = inputParameters
self.outputTexture = outputTexture
let shaderLibrary = try! ShaderLibrary(bundle: .ultraviolenceExampleShaders().orFatalError(), namespace: "ColorAdjust")
self.kernel = try! shaderLibrary.colorAdjust
+ linkedFunctions = MTLLinkedFunctions()
+ linkedFunctions.privateFunctions = [adjustColorFunction]
+
+ let visibleFunctionTableDescriptor = MTLVisibleFunctionTableDescriptor()
+ visibleFunctionTableDescriptor.functionCount = 1
+
+ // Find and set functions by handle
+// let functionHandle = pipeline.functionHandle(function: spot)!
+// lightingFunctionTable.setFunction(functionHandle, index:0)
}
public var body: some Element {
get throws {
- try ComputePipeline(
- computeKernel: kernel
- ) {
+
+ try ComputePipeline(computeKernel: kernel) {
try ComputeDispatch(threadsPerGrid: [outputTexture.width, outputTexture.height, 1], threadsPerThreadgroup: [16, 16, 1])
// TODO: #280 Maybe a .argumentBuffer() is a better solution
.parameter("inputSpecifier", value: inputSpecifier.toTexture2DSpecifierArgmentBuffer())
.useComputeResource(inputSpecifier.texture2D, usage: .read)
.useComputeResource(inputSpecifier.textureCube, usage: .read)
.useComputeResource(inputSpecifier.depth2D, usage: .read)
+// computeCommandEncoder.setVisibleFunctionTable(lightingFunctionTable, bufferIndex:1)
+
+// .parameter("xxx", visibleFunctionTable: visibleFunctionTable)
+
.parameter("inputParameters", value: inputParameters)
.parameter("outputTexture", texture: outputTexture)
+
+
}
+ .environment(\.linkedFunctions, linkedFunctions)
}
}
}
+
diff --git a/Sources/UltraviolenceExamples/Demos/ColorAdjustDemo/ColorAdjustDemoView.swift b/Sources/UltraviolenceExamples/Demos/ColorAdjustDemo/ColorAdjustDemoView.swift
index 07a329bb63..149844bcfe 100644
--- a/Sources/UltraviolenceExamples/Demos/ColorAdjustDemo/ColorAdjustDemoView.swift
+++ b/Sources/UltraviolenceExamples/Demos/ColorAdjustDemo/ColorAdjustDemoView.swift
@@ -19,7 +19,7 @@
}
"""
- let linkedFunctions: MTLLinkedFunctions
+ let adjustColorFunction: MTLFunction
public init() {
let device = _MTLCreateSystemDefaultDevice()
@@ -39,19 +39,14 @@
// TODO: #278 Use Ultraviolence's normal shader loading capabilities
// TODO: #279 Use proper Metal function loading - this one requires all functions to be named the same.
let sourceLibrary = try! device.makeLibrary(source: adjustSource, options: nil)
- let adjustColorFunction = sourceLibrary.makeFunction(name: "adjustColor")!
- let linkedFunctions = MTLLinkedFunctions()
- linkedFunctions.privateFunctions = [adjustColorFunction]
-
- self.linkedFunctions = linkedFunctions
+ adjustColorFunction = sourceLibrary.makeFunction(name: "adjustColor")!
}
public var body: some View {
RenderView { _, _ in
try ComputePass(label: "ColorAdjust") {
- ColorAdjustComputePipeline(inputSpecifier: .texture2D(sourceTexture, nil), inputParameters: Float(0.5), outputTexture: adjustedTexture)
+ ColorAdjustComputePipeline(inputSpecifier: .texture2D(sourceTexture, nil), inputParameters: Float(0.5), outputTexture: adjustedTexture, adjustColorFunction: adjustColorFunction)
}
- .environment(\.linkedFunctions, linkedFunctions)
try RenderPass {
try BillboardRenderPipeline(specifier: .texture2D(adjustedTexture))
diff --git a/Sources/UltraviolenceExamples/Demos/DepthDemo/DepthDemoView.swift b/Sources/UltraviolenceExamples/Demos/DepthDemo/DepthDemoView.swift
index b75ca7e647..cfb5f18fef 100644
--- a/Sources/UltraviolenceExamples/Demos/DepthDemo/DepthDemoView.swift
+++ b/Sources/UltraviolenceExamples/Demos/DepthDemo/DepthDemoView.swift
@@ -33,6 +33,8 @@
let teapot = MTKMesh.teapot()
+ let adjustColorFunction: MTLFunction
+
let adjustSource = """
#include <metal_stdlib>
using namespace metal;
@@ -43,8 +45,6 @@
}
"""
- let linkedFunctions: MTLLinkedFunctions
-
public init() {
let device = _MTLCreateSystemDefaultDevice()
@@ -64,12 +64,8 @@
stitchedLibraryDescriptor.functions = [sourceLibrary.makeFunction(name: "node")!]
stitchedLibraryDescriptor.functionGraphs = [graph]
let stitchedLibrary = try! device.makeLibrary(stitchedDescriptor: stitchedLibraryDescriptor)
- let stitchedFunction = stitchedLibrary.makeFunction(name: "adjustColor")!
-
- let linkedFunctions = MTLLinkedFunctions()
- linkedFunctions.privateFunctions = [stitchedFunction]
-
- self.linkedFunctions = linkedFunctions
+ adjustColorFunction = stitchedLibrary.makeFunction(name: "adjustColor")!
+
}
public var body: some View {
@@ -91,9 +87,8 @@
}
try ComputePass(label: "ColorAdjust") {
- ColorAdjustComputePipeline(inputSpecifier: .depth2D(depthTexture, nil), inputParameters: exponent, outputTexture: adjustedDepthTexture)
+ ColorAdjustComputePipeline(inputSpecifier: .depth2D(depthTexture, nil), inputParameters: exponent, outputTexture: adjustedDepthTexture, adjustColorFunction: adjustColorFunction)
}
- .environment(\.linkedFunctions, linkedFunctions)
try RenderPass(label: "Depth to Screen Pass") {
try BillboardRenderPipeline(specifier: showDepthMap ? .texture2D(adjustedDepthTexture, nil) : .texture2D(colorTexture, nil), flippedY: true)
diff --git a/Sources/UltraviolenceSupport/MetalSupport.swift b/Sources/UltraviolenceSupport/MetalSupport.swift
index c3655441f0..c5c99172d8 100644
--- a/Sources/UltraviolenceSupport/MetalSupport.swift
+++ b/Sources/UltraviolenceSupport/MetalSupport.swift
@@ -607,14 +607,24 @@
switch functionType {
case .vertex:
setVertexSamplerState(sampler, index: index)
-
case .fragment:
setFragmentSamplerState(sampler, index: index)
-
- default:
- fatalError("Unimplemented")
- }
- }
+ default:
+ fatalError("Unimplemented")
+ }
+ }
+
+ func setVisibleFunctionTable(_ functionTable: MTLVisibleFunctionTable?, bufferIndex: Int, functionType: MTLFunctionType) {
+ switch functionType {
+ case .vertex:
+ setVertexVisibleFunctionTable(functionTable, bufferIndex: bufferIndex)
+ case .fragment:
+ setFragmentVisibleFunctionTable(functionTable, bufferIndex: bufferIndex)
+ default:
+ fatalError("Unimplemented")
+ }
+ }
+
}
public extension MTLComputeCommandEncoder {
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This should print:
Before: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
After operation #0: [2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0]
After operation #1: [20.0, 40.0, 60.0, 80.0, 100.0, 120.0, 140.0, 160.0, 180.0, 200.0]
Program ended with exit code: 0