-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtexture_training_common.hlsl
More file actions
97 lines (81 loc) · 3.42 KB
/
Copy pathtexture_training_common.hlsl
File metadata and controls
97 lines (81 loc) · 3.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
/*!
\file texture_training_common.hlsl
\author Sho Ikeda
\brief Shared training step for texture MLP (GPU + C++ fallback)
\copyright Copyright (c) 2026 Advanced Micro Devices, Inc. All Rights Reserved.
SPDX-License-Identifier: MIT
Provides the common training step (forward + MSE loss + backward) used by both
02_texture_training.comp and the C++ fallback training kernel.
Requires mlp.hlsl to be included before this file.
*/
#ifndef MINIDXNN_TEXTURE_TRAINING_COMMON_HLSL
#define MINIDXNN_TEXTURE_TRAINING_COMMON_HLSL 1
// This file uses C-style casts such as (Type)x and (float)(...). When it is
// compiled by GCC or Clang with MINIDXNN_CPP_HLSL_COMPAT_HPP defined
// (presumably the include guard of the C++ HLSL compatibility header -- confirm),
// save the diagnostic state and silence -Wold-style-cast for the code below.
#if defined(MINIDXNN_CPP_HLSL_COMPAT_HPP) && (defined(__GNUC__) || defined(__clang__))
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wold-style-cast"
#endif
namespace texkernel {

/*!
  \brief Runs one training step (forward pass, MSE loss, backward pass) for one sample.

  Each invocation handles the sample selected by \p threadId inside the batch
  selected by \p batchIndex. It loads a 2-component input vector, evaluates the
  network with mininn::forward, loads the 2-component target vector, atomically
  adds the sample's mean squared error to \p lossBuffer at offset 0, and then
  calls mininn::backward with the loss gradient
  (2 / (outputDim * currentBatchSize)) * (output - target).

  Invocations with \p threadId >= \p currentBatchSize return immediately without
  touching any buffer. This also means \p currentBatchSize is never zero when
  the gradient scale is computed.

  \tparam Type Scalar type of the input, output and target vectors.
  \tparam NUM_LAYERS Forwarded to mininn::TrainingLayerDataRef.
  \tparam HIDDEN_DIM Forwarded to mininn::TrainingLayerDataRef.
  \tparam ELEM_TYPE Component type forwarded to mininn::TrainingLayerDataRef
                    for every element-type slot.
  \tparam LAYOUT Matrix layout forwarded to mininn::TrainingLayerDataRef.
  \tparam ActivationHiddenT Forwarded to mininn::TrainingLayerDataRef.
  \tparam ActivationLastT Forwarded to mininn::TrainingLayerDataRef.
  \tparam W_ALIGN Forwarded to mininn::TrainingLayerDataRef.
  \tparam VS_ALIGN Forwarded to mininn::TrainingLayerDataRef.
  \tparam B_ALIGN Forwarded to mininn::TrainingLayerDataRef.

  \param [in] threadId Index of the sample within the current batch. It also
                       selects the logits cache slot (threadId * logitsStride).
  \param [in] uvBuffer Input vectors, read at sampleIndex * 2 * sizeof(Type).
  \param [in] targetBuffer Target vectors, read at sampleIndex * 2 * sizeof(Type).
  \param [in] weightBuffer Passed to the layer data via setWeightData.
  \param [in] biasBuffer Passed to the layer data via setBiasData
                         (only when LayerDataRef::HAS_BIAS).
  \param [in,out] weightGradBuffer Passed via setWeightGradientCache.
  \param [in,out] biasGradBuffer Passed via setBiasGradientCache
                                 (only when LayerDataRef::HAS_BIAS).
  \param [in,out] logitsCacheBuffer Passed via setLogitsCache together with the
                                    offset threadId * logitsStride.
  \param [in,out] lossBuffer Receives the per-sample loss through an atomic add
                             at offset 0.
  \param [in] weightMatrixSize Forwarded unchanged to setWeightData and
                               setWeightGradientCache.
  \param [in] batchSize Nominal batch size; batchIndex * batchSize is the index
                        of the first sample of the batch.
  \param [in] batchIndex Index of the batch being processed.
  \param [in] currentBatchSize Number of valid samples in this batch. It is also
                               the divisor used to average the gradient.
  \param [in] logitsStride Per-invocation stride into the logits cache
                           (presumably in bytes -- confirm against mininn).
  */
template <typename Type, uint NUM_LAYERS, int HIDDEN_DIM,
          dx::linalg::ComponentEnum ELEM_TYPE,
          dx::linalg::MatrixLayoutEnum LAYOUT,
          typename ActivationHiddenT, typename ActivationLastT,
          uint W_ALIGN, uint VS_ALIGN, uint B_ALIGN>
void trainingStep(
    const uint threadId,
    ByteAddressBuffer uvBuffer,
    ByteAddressBuffer targetBuffer,
    ByteAddressBuffer weightBuffer,
    ByteAddressBuffer biasBuffer,
    RWByteAddressBuffer weightGradBuffer,
    RWByteAddressBuffer biasGradBuffer,
    RWByteAddressBuffer logitsCacheBuffer,
    RWByteAddressBuffer lossBuffer,
    const uint2 weightMatrixSize,
    const uint batchSize,
    const uint batchIndex,
    const uint currentBatchSize,
    const uint logitsStride)
{
  const int inputDim = 2;
  const int outputDim = 2;
  const uint inputVecStride = inputDim * (uint)(sizeof(Type));
  const uint outputVecStride = outputDim * (uint)(sizeof(Type));

  using InputVecT = vector<Type, inputDim>;
  using OutputVecT = vector<Type, outputDim>;
  using VectorBufferAccessorT = mininn::impl::VectorBufferAccessor<Type>;
  using LayerDataRef = mininn::TrainingLayerDataRef<
      NUM_LAYERS, HIDDEN_DIM,
      ELEM_TYPE, LAYOUT, ELEM_TYPE,
      ELEM_TYPE, ELEM_TYPE, ELEM_TYPE, ELEM_TYPE,
      ActivationHiddenT, ActivationLastT, ELEM_TYPE,
      W_ALIGN, VS_ALIGN, B_ALIGN>;

  // Invocations beyond the valid samples of a (possibly partial) batch do nothing.
  if (currentBatchSize <= threadId)
    return;

  const uint sampleIndex = batchIndex * batchSize + threadId;
  const InputVecT uv = VectorBufferAccessorT::template load<inputDim>(uvBuffer, sampleIndex * inputVecStride);

  // Bind the network parameters, their gradient caches and this invocation's
  // logits cache slot.
  LayerDataRef layerData;
  layerData.setWeightData(weightBuffer, weightMatrixSize);
  layerData.setWeightGradientCache(weightGradBuffer, weightMatrixSize);
  if (LayerDataRef::HAS_BIAS) {
    layerData.setBiasData(biasBuffer);
    layerData.setBiasGradientCache(biasGradBuffer);
  }
  layerData.setLogitsCache(logitsCacheBuffer, threadId * logitsStride);

  // Forward pass
  OutputVecT output = (OutputVecT)0;
  mininn::forward(output, uv, layerData);

  const OutputVecT target = VectorBufferAccessorT::template load<outputDim>(targetBuffer, sampleIndex * outputVecStride);

  // MSE loss of this sample, accumulated into lossBuffer[0]. The sum is not
  // divided by the batch size here.
  const OutputVecT diff = output - target;
  const float loss = (float)(dot(diff, diff)) / (float)(outputDim);
  mininn::impl::atomicFetchAdd(lossBuffer, 0, loss);

  // Loss gradient: d(loss)/d(output) = (2 / outputDim) * diff, averaged over the
  // batch. The combined factor is evaluated in float and narrowed to Type only
  // once. Narrowing currentBatchSize to a 16-bit Type first would turn batch
  // sizes beyond the finite range of that type into +inf (making the gradient
  // exactly zero) and would round the batch size before taking the reciprocal.
  // For a 32-bit float Type the value is unchanged.
  const float gradScale = 2.0f / ((float)(outputDim) * (float)(currentBatchSize));
  const Type lossGradScale = (Type)(gradScale);
  const OutputVecT lossGrad = lossGradScale * diff;

  // Backward pass
  mininn::backward(lossGrad, uv, layerData);
}

} // namespace texkernel
// Restore the diagnostic state saved by the matching
// "#pragma GCC diagnostic push" near the top of this file.
#if defined(MINIDXNN_CPP_HLSL_COMPAT_HPP) && (defined(__GNUC__) || defined(__clang__))
#pragma GCC diagnostic pop
#endif
#endif // MINIDXNN_TEXTURE_TRAINING_COMMON_HLSL