-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtexture_inference_common.hlsl
More file actions
88 lines (73 loc) · 3 KB
/
Copy pathtexture_inference_common.hlsl
File metadata and controls
88 lines (73 loc) · 3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
/*!
\file texture_inference_common.hlsl
\author Sho Ikeda
\brief Shared inference step for texture MLP (GPU + C++ fallback)
\copyright Copyright (c) 2026 Advanced Micro Devices, Inc. All Rights Reserved.
SPDX-License-Identifier: MIT
Provides the common inference step (forward pass) used by both
01_texture_inference.comp and the C++ fallback inference kernel.
Requires mlp.hlsl to be included before this file.
*/
#ifndef MINIDXNN_TEXTURE_INFERENCE_COMMON_HLSL
#define MINIDXNN_TEXTURE_INFERENCE_COMMON_HLSL 1
#if defined(MINIDXNN_CPP_HLSL_COMPAT_HPP) && (defined(__GNUC__) || defined(__clang__))
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wold-style-cast"
#endif
namespace texkernel {

/*!
  \brief Runs one MLP forward pass for a single task and stores the result.

  Loads a 2-component vector of \c Type from \c uvBuffer, evaluates the network
  through \c mininn::forward, and writes the resulting 2-component vector of
  \c Type to \c outputBuffer. Both buffers are addressed as tightly packed
  arrays indexed by \c threadId, i.e. at byte offset
  <tt>threadId * 2 * sizeof(Type)</tt>.

  \tparam Type               Scalar type of the input and output vectors.
  \tparam NUM_LAYERS         Forwarded to the mininn layer data type.
  \tparam HIDDEN_DIM         Forwarded to the mininn layer data type.
  \tparam ELEM_TYPE          Forwarded to every component-type slot of the
                             mininn layer data type.
  \tparam LAYOUT             Forwarded to the mininn layer data type.
  \tparam ActivationHiddenT  Forwarded to the mininn layer data type.
  \tparam ActivationLastT    Forwarded to the mininn layer data type.
  \tparam W_ALIGN            Forwarded to the mininn layer data type.
  \tparam VS_ALIGN           Forwarded to the mininn layer data type.
  \tparam B_ALIGN            Forwarded to the mininn layer data type.
  \tparam HAS_BIAS           When true the with-bias layer data type is used and
                             \c biasBuffer is bound; otherwise the no-bias
                             variant is used and \c biasBuffer is not read here.

  \param threadId          Index of the task handled by this invocation.
  \param uvBuffer          Source buffer holding one input vector per task.
  \param outputBuffer      Destination buffer receiving one output vector per task.
  \param weightBuffer      Buffer handed to \c setWeightData.
  \param biasBuffer        Buffer handed to \c setBiasData (only if \c HAS_BIAS).
  \param weightMatrixSize  Passed to \c setWeightData together with \c weightBuffer.
  \param numTasks          Number of valid tasks; invocations with
                           <tt>threadId >= numTasks</tt> return without touching
                           any buffer.
*/
template <typename Type, uint NUM_LAYERS, int HIDDEN_DIM,
          dx::linalg::ComponentEnum ELEM_TYPE,
          dx::linalg::MatrixLayoutEnum LAYOUT,
          typename ActivationHiddenT, typename ActivationLastT,
          uint W_ALIGN, uint VS_ALIGN, uint B_ALIGN,
          bool HAS_BIAS>
void inferenceStep(
    const uint threadId,
    ByteAddressBuffer uvBuffer,
    RWByteAddressBuffer outputBuffer,
    ByteAddressBuffer weightBuffer,
    ByteAddressBuffer biasBuffer,
    const uint2 weightMatrixSize,
    const uint numTasks)
{
  // Out-of-range invocations have no task to process.
  if (threadId >= numTasks)
    return;

  // Fixed vector widths of the network input and output.
  const int inputDim = 2;
  const int outputDim = 2;

  using VectorBufferAccessorT = mininn::impl::VectorBufferAccessor<Type>;
  using InputVecT = vector<Type, inputDim>;
  using OutputVecT = vector<Type, outputDim>;

  // Layer data flavours; HAS_BIAS picks which one is actually populated below.
  using BiasedLayersT = mininn::InferenceLayerDataRef<
      NUM_LAYERS, HIDDEN_DIM,
      ELEM_TYPE, LAYOUT, ELEM_TYPE, ELEM_TYPE,
      ActivationHiddenT, ActivationLastT, ELEM_TYPE,
      W_ALIGN, VS_ALIGN, B_ALIGN>;
  using UnbiasedLayersT = mininn::InferenceLayerDataRefNoBias<
      NUM_LAYERS, HIDDEN_DIM,
      ELEM_TYPE, LAYOUT, ELEM_TYPE,
      ActivationHiddenT, ActivationLastT, ELEM_TYPE,
      W_ALIGN, VS_ALIGN, B_ALIGN>;

  // Byte offsets of this task's input and output vectors.
  const uint elemBytes = (uint)(sizeof(Type));
  const uint uvOffset = threadId * (inputDim * elemBytes);
  const uint resultOffset = threadId * (outputDim * elemBytes);

  const InputVecT uv = VectorBufferAccessorT::template load<inputDim>(uvBuffer, uvOffset);
  OutputVecT result = (OutputVecT)0;

  if (!HAS_BIAS) {
    UnbiasedLayersT layers;
    layers.setWeightData(weightBuffer, weightMatrixSize);
    mininn::forward(result, uv, layers);
  } else {
    BiasedLayersT layers;
    layers.setWeightData(weightBuffer, weightMatrixSize);
    layers.setBiasData(biasBuffer);
    mininn::forward(result, uv, layers);
  }

  VectorBufferAccessorT::template store<outputDim>(outputBuffer, resultOffset, result);
}

} // namespace texkernel
#if defined(MINIDXNN_CPP_HLSL_COMPAT_HPP) && (defined(__GNUC__) || defined(__clang__))
#pragma GCC diagnostic pop
#endif
#endif // MINIDXNN_TEXTURE_INFERENCE_COMMON_HLSL