Skip to content

Commit ab79785

Browse files
committed
gen - add support for mixed precision CUDA operators
1 parent 65d1306 commit ab79785

33 files changed

+3093
-379
lines changed

backends/cuda-gen/ceed-cuda-gen-operator-build.cpp

Lines changed: 45 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1285,7 +1285,7 @@ extern "C" int CeedOperatorBuildKernel_Cuda_gen(CeedOperator op, bool *is_good_b
12851285
code << tab << "// s_G_[in,out]_i: Gradient matrix, shared memory\n";
12861286
code << tab << "// -----------------------------------------------------------------------------\n";
12871287
code << tab << "extern \"C\" __global__ void " << operator_name
1288-
<< "(CeedInt num_elem, void* ctx, FieldsInt_Cuda indices, Fields_Cuda fields, Fields_Cuda B, Fields_Cuda G, CeedScalar *W, Points_Cuda "
1288+
<< "(CeedInt num_elem, void* ctx, FieldsInt_Cuda indices, Fields_Cuda fields, Fields_Cuda B, Fields_Cuda G, CeedScalarCPU *W, Points_Cuda "
12891289
"points) {\n";
12901290
tab.push();
12911291

@@ -1295,11 +1295,11 @@ extern "C" int CeedOperatorBuildKernel_Cuda_gen(CeedOperator op, bool *is_good_b
12951295

12961296
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
12971297
if (eval_mode != CEED_EVAL_WEIGHT) { // Skip CEED_EVAL_WEIGHT
1298-
code << tab << "const CeedScalar *__restrict__ d_in_" << i << " = fields.inputs[" << i << "];\n";
1298+
code << tab << "const CeedScalarCPU *__restrict__ d_in_" << i << " = fields.inputs[" << i << "];\n";
12991299
}
13001300
}
13011301
for (CeedInt i = 0; i < num_output_fields; i++) {
1302-
code << tab << "CeedScalar *__restrict__ d_out_" << i << " = fields.outputs[" << i << "];\n";
1302+
code << tab << "CeedScalarCPU *__restrict__ d_out_" << i << " = fields.outputs[" << i << "];\n";
13031303
}
13041304

13051305
code << tab << "const CeedInt max_dim = " << max_dim << ";\n";
@@ -1574,9 +1574,18 @@ extern "C" int CeedOperatorBuildKernel_Cuda_gen(CeedOperator op, bool *is_good_b
15741574
{
15751575
bool is_compile_good = false;
15761576
const CeedInt T_1d = CeedIntMax(is_all_tensor ? Q_1d : Q, data->max_P_1d);
1577+
bool use_mixed_precision;
1578+
1579+
// Check for mixed precision
1580+
CeedCallBackend(CeedOperatorGetMixedPrecision(op, &use_mixed_precision));
15771581

15781582
data->thread_1d = T_1d;
1579-
CeedCallBackend(CeedTryCompile_Cuda(ceed, code.str().c_str(), &is_compile_good, &data->module, 1, "OP_T_1D", T_1d));
1583+
if (use_mixed_precision) {
1584+
CeedCallBackend(
1585+
CeedTryCompile_Cuda(ceed, code.str().c_str(), &is_compile_good, &data->module, 2, "OP_T_1D", T_1d, "CEED_JIT_MIXED_PRECISION", 1));
1586+
} else {
1587+
CeedCallBackend(CeedTryCompile_Cuda(ceed, code.str().c_str(), &is_compile_good, &data->module, 1, "OP_T_1D", T_1d));
1588+
}
15801589
if (is_compile_good) {
15811590
*is_good_build = true;
15821591
CeedCallBackend(CeedGetKernel_Cuda(ceed, data->module, operator_name.c_str(), &data->op));
@@ -1689,8 +1698,8 @@ static int CeedOperatorBuildKernelAssemblyAtPoints_Cuda_gen(CeedOperator op, boo
16891698
code << tab << "// s_G_[in,out]_i: Gradient matrix, shared memory\n";
16901699
code << tab << "// -----------------------------------------------------------------------------\n";
16911700
code << tab << "extern \"C\" __global__ void " << operator_name
1692-
<< "(CeedInt num_elem, void* ctx, FieldsInt_Cuda indices, Fields_Cuda fields, Fields_Cuda B, Fields_Cuda G, CeedScalar *W, Points_Cuda "
1693-
"points, CeedScalar *__restrict__ values_array) {\n";
1701+
<< "(CeedInt num_elem, void* ctx, FieldsInt_Cuda indices, Fields_Cuda fields, Fields_Cuda B, Fields_Cuda G, CeedScalarCPU *W, Points_Cuda "
1702+
"points, CeedScalarCPU *__restrict__ values_array) {\n";
16941703
tab.push();
16951704

16961705
// Scratch buffers
@@ -1699,11 +1708,11 @@ static int CeedOperatorBuildKernelAssemblyAtPoints_Cuda_gen(CeedOperator op, boo
16991708

17001709
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
17011710
if (eval_mode != CEED_EVAL_WEIGHT) { // Skip CEED_EVAL_WEIGHT
1702-
code << tab << "const CeedScalar *__restrict__ d_in_" << i << " = fields.inputs[" << i << "];\n";
1711+
code << tab << "const CeedScalarCPU *__restrict__ d_in_" << i << " = fields.inputs[" << i << "];\n";
17031712
}
17041713
}
17051714
for (CeedInt i = 0; i < num_output_fields; i++) {
1706-
code << tab << "CeedScalar *__restrict__ d_out_" << i << " = fields.outputs[" << i << "];\n";
1715+
code << tab << "CeedScalarCPU *__restrict__ d_out_" << i << " = fields.outputs[" << i << "];\n";
17071716
}
17081717

17091718
code << tab << "const CeedInt max_dim = " << max_dim << ";\n";
@@ -2045,10 +2054,20 @@ static int CeedOperatorBuildKernelAssemblyAtPoints_Cuda_gen(CeedOperator op, boo
20452054
{
20462055
bool is_compile_good = false;
20472056
const CeedInt T_1d = CeedIntMax(is_all_tensor ? Q_1d : Q, data->max_P_1d);
2057+
bool use_mixed_precision;
2058+
2059+
// Check for mixed precision
2060+
CeedCallBackend(CeedOperatorGetMixedPrecision(op, &use_mixed_precision));
20482061

20492062
data->thread_1d = T_1d;
2050-
CeedCallBackend(CeedTryCompile_Cuda(ceed, code.str().c_str(), &is_compile_good,
2051-
is_full ? &data->module_assemble_full : &data->module_assemble_diagonal, 1, "OP_T_1D", T_1d));
2063+
if (use_mixed_precision) {
2064+
CeedCallBackend(CeedTryCompile_Cuda(ceed, code.str().c_str(), &is_compile_good,
2065+
is_full ? &data->module_assemble_full : &data->module_assemble_diagonal, 2, "OP_T_1D", T_1d,
2066+
"CEED_JIT_MIXED_PRECISION", 1));
2067+
} else {
2068+
CeedCallBackend(CeedTryCompile_Cuda(ceed, code.str().c_str(), &is_compile_good,
2069+
is_full ? &data->module_assemble_full : &data->module_assemble_diagonal, 1, "OP_T_1D", T_1d));
2070+
}
20522071
if (is_compile_good) {
20532072
*is_good_build = true;
20542073
CeedCallBackend(CeedGetKernel_Cuda(ceed, is_full ? data->module_assemble_full : data->module_assemble_diagonal, operator_name.c_str(),
@@ -2221,8 +2240,8 @@ extern "C" int CeedOperatorBuildKernelLinearAssembleQFunction_Cuda_gen(CeedOpera
22212240
code << tab << "// s_G_[in,out]_i: Gradient matrix, shared memory\n";
22222241
code << tab << "// -----------------------------------------------------------------------------\n";
22232242
code << tab << "extern \"C\" __global__ void " << operator_name
2224-
<< "(CeedInt num_elem, void* ctx, FieldsInt_Cuda indices, Fields_Cuda fields, Fields_Cuda B, Fields_Cuda G, CeedScalar *W, Points_Cuda "
2225-
"points, CeedScalar *__restrict__ values_array) {\n";
2243+
<< "(CeedInt num_elem, void* ctx, FieldsInt_Cuda indices, Fields_Cuda fields, Fields_Cuda B, Fields_Cuda G, CeedScalarCPU *W, Points_Cuda "
2244+
"points, CeedScalarCPU *__restrict__ values_array) {\n";
22262245
tab.push();
22272246

22282247
// Scratch buffers
@@ -2231,11 +2250,11 @@ extern "C" int CeedOperatorBuildKernelLinearAssembleQFunction_Cuda_gen(CeedOpera
22312250

22322251
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
22332252
if (eval_mode != CEED_EVAL_WEIGHT) { // Skip CEED_EVAL_WEIGHT
2234-
code << tab << "const CeedScalar *__restrict__ d_in_" << i << " = fields.inputs[" << i << "];\n";
2253+
code << tab << "const CeedScalarCPU *__restrict__ d_in_" << i << " = fields.inputs[" << i << "];\n";
22352254
}
22362255
}
22372256
for (CeedInt i = 0; i < num_output_fields; i++) {
2238-
code << tab << "CeedScalar *__restrict__ d_out_" << i << " = fields.outputs[" << i << "];\n";
2257+
code << tab << "CeedScalarCPU *__restrict__ d_out_" << i << " = fields.outputs[" << i << "];\n";
22392258
}
22402259

22412260
code << tab << "const CeedInt max_dim = " << max_dim << ";\n";
@@ -2485,8 +2504,8 @@ extern "C" int CeedOperatorBuildKernelLinearAssembleQFunction_Cuda_gen(CeedOpera
24852504
CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[f], &field_size));
24862505
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[f], &eval_mode));
24872506
if (eval_mode == CEED_EVAL_GRAD) {
2488-
code << tab << "CeedScalar r_q_in_" << f << "[num_comp_in_" << f << "*" << "dim_in_" << f << "*"
2489-
<< (is_all_tensor && (max_dim >= 3) ? "Q_1d" : "1") << "] = {0.};\n";
2507+
code << tab << "CeedScalar r_q_in_" << f << "[num_comp_in_" << f << "*"
2508+
<< "dim_in_" << f << "*" << (is_all_tensor && (max_dim >= 3) ? "Q_1d" : "1") << "] = {0.};\n";
24902509
} else {
24912510
code << tab << "CeedScalar r_q_in_" << f << "[num_comp_in_" << f << "*" << (is_all_tensor && (max_dim >= 3) ? "Q_1d" : "1") << "] = {0.};\n";
24922511
}
@@ -2625,9 +2644,18 @@ extern "C" int CeedOperatorBuildKernelLinearAssembleQFunction_Cuda_gen(CeedOpera
26252644
{
26262645
bool is_compile_good = false;
26272646
const CeedInt T_1d = CeedIntMax(is_all_tensor ? Q_1d : Q, data->max_P_1d);
2647+
bool use_mixed_precision;
2648+
2649+
// Check for mixed precision
2650+
CeedCallBackend(CeedOperatorGetMixedPrecision(op, &use_mixed_precision));
26282651

26292652
data->thread_1d = T_1d;
2630-
CeedCallBackend(CeedTryCompile_Cuda(ceed, code.str().c_str(), &is_compile_good, &data->module_assemble_qfunction, 1, "OP_T_1D", T_1d));
2653+
if (use_mixed_precision) {
2654+
CeedCallBackend(CeedTryCompile_Cuda(ceed, code.str().c_str(), &is_compile_good, &data->module_assemble_qfunction, 2, "OP_T_1D", T_1d,
2655+
"CEED_JIT_MIXED_PRECISION", 1));
2656+
} else {
2657+
CeedCallBackend(CeedTryCompile_Cuda(ceed, code.str().c_str(), &is_compile_good, &data->module_assemble_qfunction, 1, "OP_T_1D", T_1d));
2658+
}
26312659
if (is_compile_good) {
26322660
*is_good_build = true;
26332661
CeedCallBackend(CeedGetKernel_Cuda(ceed, data->module_assemble_qfunction, operator_name.c_str(), &data->assemble_qfunction));

backends/cuda-gen/ceed-cuda-gen.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ static int CeedInit_Cuda_gen(const char *resource, Ceed ceed) {
2929
CeedCallBackend(CeedCalloc(1, &data));
3030
CeedCallBackend(CeedSetData(ceed, data));
3131
CeedCallBackend(CeedInit_Cuda(ceed, resource));
32+
CeedCallBackend(CeedSetSupportsMixedPrecision(ceed, true));
3233

3334
CeedCallBackend(CeedInit("/gpu/cuda/shared", &ceed_shared));
3435
CeedCallBackend(CeedSetDelegate(ceed, ceed_shared));

include/ceed-impl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ struct Ceed_private {
128128
bool is_debug;
129129
bool has_valid_op_fallback_resource;
130130
bool is_deterministic;
131+
bool supports_mixed_precision;
131132
char err_msg[CEED_MAX_RESOURCE_LEN];
132133
FOffset *f_offsets;
133134
CeedWorkVectors work_vectors;
@@ -380,6 +381,7 @@ struct CeedOperator_private {
380381
bool is_composite;
381382
bool is_at_points;
382383
bool has_restriction;
384+
bool use_mixed_precision;
383385
CeedQFunctionAssemblyData qf_assembled;
384386
CeedOperatorAssemblyData op_assembled;
385387
CeedOperator *sub_operators;

include/ceed/backend.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ CEED_EXTERN int CeedGetOperatorFallbackResource(Ceed ceed, const char **resource
250250
CEED_EXTERN int CeedGetOperatorFallbackCeed(Ceed ceed, Ceed *fallback_ceed);
251251
CEED_EXTERN int CeedSetOperatorFallbackResource(Ceed ceed, const char *resource);
252252
CEED_EXTERN int CeedSetDeterministic(Ceed ceed, bool is_deterministic);
253+
CEED_INTERN int CeedSetSupportsMixedPrecision(Ceed ceed, bool supports_mixed_precision);
253254
CEED_EXTERN int CeedSetBackendFunctionImpl(Ceed ceed, const char *type, void *object, const char *func_name, void (*f)(void));
254255
CEED_EXTERN int CeedGetData(Ceed ceed, void *data);
255256
CEED_EXTERN int CeedSetData(Ceed ceed, void *data);

include/ceed/ceed-f32.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414

1515
/// Set base scalar type to FP32. (See CeedScalarType enum in ceed.h for all options.)
1616
#define CEED_SCALAR_TYPE CEED_SCALAR_FP32
17-
typedef float CeedScalar;
17+
typedef float CeedScalar;
18+
typedef CeedScalar CeedScalarCPU;
1819

1920
/// Machine epsilon for FP32 (CeedScalar is float in this header)
20-
#define CEED_EPSILON 6e-08
21+
#define CEED_EPSILON 0x1p-23

include/ceed/ceed-f64.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,16 @@
1414

1515
/// Set base scalar type to FP64. (See CeedScalarType enum in ceed.h for all options.)
1616
#define CEED_SCALAR_TYPE CEED_SCALAR_FP64
17-
typedef double CeedScalar;
17+
#if defined(CEED_RUNNING_JIT_PASS) && defined(CEED_JIT_MIXED_PRECISION)
18+
typedef float CeedScalar;
19+
typedef double CeedScalarCPU;
1820

1921
/// Machine epsilon
20-
#define CEED_EPSILON 1e-16
22+
#define CEED_EPSILON 0x1p-23
23+
#else
24+
typedef double CeedScalar;
25+
typedef CeedScalar CeedScalarCPU;
26+
27+
/// Machine epsilon for FP64 (CeedScalar is double when the mixed-precision JIT pass is not active)
28+
#define CEED_EPSILON 0x1p-52
29+
#endif // CEED_RUNNING_JIT_PASS && CEED_JIT_MIXED_PRECISION

include/ceed/ceed.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ CEED_EXTERN int CeedSetStream(Ceed ceed, void *handle);
106106
CEED_EXTERN int CeedReferenceCopy(Ceed ceed, Ceed *ceed_copy);
107107
CEED_EXTERN int CeedGetResource(Ceed ceed, const char **resource);
108108
CEED_EXTERN int CeedIsDeterministic(Ceed ceed, bool *is_deterministic);
109+
CEED_EXTERN int CeedGetSupportsMixedPrecision(Ceed ceed, bool *supports_mixed_precision);
109110
CEED_EXTERN int CeedAddJitSourceRoot(Ceed ceed, const char *jit_source_root);
110111
CEED_EXTERN int CeedAddJitDefine(Ceed ceed, const char *jit_define);
111112
CEED_EXTERN int CeedView(Ceed ceed, FILE *stream);
@@ -426,6 +427,8 @@ CEED_EXTERN int CeedOperatorCheckReady(CeedOperator op);
426427
CEED_EXTERN int CeedOperatorGetActiveVectorLengths(CeedOperator op, CeedSize *input_size, CeedSize *output_size);
427428
CEED_EXTERN int CeedOperatorSetQFunctionAssemblyReuse(CeedOperator op, bool reuse_assembly_data);
428429
CEED_EXTERN int CeedOperatorSetQFunctionAssemblyDataUpdateNeeded(CeedOperator op, bool needs_data_update);
430+
CEED_EXTERN int CeedOperatorSetMixedPrecision(CeedOperator op);
431+
CEED_EXTERN int CeedOperatorGetMixedPrecision(CeedOperator op, bool *use_mixed_precision);
429432
CEED_EXTERN int CeedOperatorLinearAssembleQFunction(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request);
430433
CEED_EXTERN int CeedOperatorLinearAssembleQFunctionBuildOrUpdate(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr,
431434
CeedRequest *request);

0 commit comments

Comments
 (0)