From 00fc8daf56d0a070dc75036cb13ffbfc7a6567c6 Mon Sep 17 00:00:00 2001 From: ReinUsesLisp Date: Sun, 28 Oct 2018 13:44:12 -0300 Subject: [PATCH] Use variant instead of creating an object for literals --- include/sirit/sirit.h | 22 +++--- src/CMakeLists.txt | 1 - src/insts/annotation.cpp | 13 ++-- src/insts/constant.cpp | 15 ++-- src/insts/debug.cpp | 2 +- src/insts/flow.cpp | 23 +++---- src/insts/function.cpp | 9 ++- src/insts/type.cpp | 144 +++++++++++++++++++-------------------- src/literal.cpp | 26 ------- src/op.cpp | 28 ++++++++ src/op.h | 4 ++ src/sirit.cpp | 3 +- 12 files changed, 146 insertions(+), 144 deletions(-) delete mode 100644 src/literal.cpp diff --git a/include/sirit/sirit.h b/include/sirit/sirit.h index 61c21a1..8ddf9b5 100644 --- a/include/sirit/sirit.h +++ b/include/sirit/sirit.h @@ -11,6 +11,7 @@ #include #include #include +#include #include namespace Sirit { @@ -20,7 +21,9 @@ constexpr std::uint32_t GENERATOR_MAGIC_NUMBER = 0; class Op; class Operand; -typedef const Op* Ref; +using Literal = std::variant; +using Ref = const Op*; class Module { public: @@ -135,7 +138,7 @@ class Module { Ref ConstantFalse(Ref result_type); /// Returns a numeric scalar constant. - Ref Constant(Ref result_type, Operand* literal); + Ref Constant(Ref result_type, const Literal& literal); /// Returns a numeric scalar constant. Ref ConstantComposite(Ref result_type, @@ -201,18 +204,11 @@ class Module { /// Add a decoration to target. Ref Decorate(Ref target, spv::Decoration decoration, - const std::vector& literals = {}); + const std::vector& literals = {}); - Ref MemberDecorate(Ref structure_type, Operand* member, spv::Decoration decoration, - const std::vector& literals = {}); - - // Literals - static Operand* Literal(std::uint32_t value); - static Operand* Literal(std::uint64_t value); - static Operand* Literal(std::int32_t value); - static Operand* Literal(std::int64_t value); - static Operand* Literal(float value); - static Operand* Literal(double value); + Ref MemberDecorate(Ref structure_type, Literal member, + spv::Decoration decoration, + const std::vector& literals = {}); private: Ref AddCode(Op* op); diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 8602c9f..db94e83 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -7,7 +7,6 @@ add_library(sirit stream.h operand.cpp operand.h - literal.cpp literal-number.cpp literal-number.h literal-string.cpp diff --git a/src/insts/annotation.cpp b/src/insts/annotation.cpp index c54adfc..485ad6b 100644 --- a/src/insts/annotation.cpp +++ b/src/insts/annotation.cpp @@ -10,21 +10,22 @@ namespace Sirit { Ref Module::Decorate(Ref target, spv::Decoration decoration, - const std::vector& literals) { + const std::vector& literals) { auto op{new Op(spv::Op::OpDecorate)}; op->Add(target); AddEnum(op, decoration); - op->Sink(literals); + op->Add(literals); return AddAnnotation(op); } -Ref Module::MemberDecorate(Ref structure_type, Operand* member, spv::Decoration decoration, - const std::vector& literals) { +Ref Module::MemberDecorate(Ref structure_type, Literal member, + spv::Decoration decoration, + const std::vector& literals) { auto op{new Op(spv::Op::OpMemberDecorate)}; op->Add(structure_type); - op->Sink(member); + op->Add(member); AddEnum(op, decoration); - op->Sink(literals); + op->Add(literals); return AddAnnotation(op); } diff --git a/src/insts/constant.cpp b/src/insts/constant.cpp index 3360603..d5e8802 100644 --- a/src/insts/constant.cpp +++ b/src/insts/constant.cpp @@ -4,9 +4,9 @@ * Lesser General Public License version 2.1 or any later version. */ -#include -#include "sirit/sirit.h" #include "insts.h" +#include "sirit/sirit.h" +#include namespace Sirit { @@ -18,20 +18,23 @@ Ref Module::ConstantFalse(Ref result_type) { return AddDeclaration(new Op(spv::Op::OpConstantFalse, bound, result_type)); } -Ref Module::Constant(Ref result_type, Operand* literal) { +Ref Module::Constant(Ref result_type, const Literal& literal) { auto op{new Op(spv::Op::OpConstant, bound, result_type)}; op->Add(literal); return AddDeclaration(op); } -Ref Module::ConstantComposite(Ref result_type, const std::vector& constituents) { +Ref Module::ConstantComposite(Ref result_type, + const std::vector& constituents) { auto op{new Op(spv::Op::OpConstantComposite, bound, result_type)}; op->Add(constituents); return AddDeclaration(op); } -Ref Module::ConstantSampler(Ref result_type, spv::SamplerAddressingMode addressing_mode, - bool normalized, spv::SamplerFilterMode filter_mode) { +Ref Module::ConstantSampler(Ref result_type, + spv::SamplerAddressingMode addressing_mode, + bool normalized, + spv::SamplerFilterMode filter_mode) { AddCapability(spv::Capability::LiteralSampler); AddCapability(spv::Capability::Kernel); auto op{new Op(spv::Op::OpConstantSampler, bound, result_type)}; diff --git a/src/insts/debug.cpp b/src/insts/debug.cpp index c557511..3822dcc 100644 --- a/src/insts/debug.cpp +++ b/src/insts/debug.cpp @@ -4,8 +4,8 @@ * Lesser General Public License version 2.1 or any later version. */ -#include "sirit/sirit.h" #include "insts.h" +#include "sirit/sirit.h" namespace Sirit { diff --git a/src/insts/flow.cpp b/src/insts/flow.cpp index 056e82e..c88df40 100644 --- a/src/insts/flow.cpp +++ b/src/insts/flow.cpp @@ -4,12 +4,13 @@ * Lesser General Public License version 2.1 or any later version. */ -#include "sirit/sirit.h" #include "insts.h" +#include "sirit/sirit.h" namespace Sirit { -Ref Module::LoopMerge(Ref merge_block, Ref continue_target, spv::LoopControlMask loop_control, +Ref Module::LoopMerge(Ref merge_block, Ref continue_target, + spv::LoopControlMask loop_control, const std::vector& literals) { auto op{new Op(spv::Op::OpLoopMerge)}; op->Add(merge_block); @@ -19,16 +20,15 @@ Ref Module::LoopMerge(Ref merge_block, Ref continue_target, spv::LoopControlMask return AddCode(op); } -Ref Module::SelectionMerge(Ref merge_block, spv::SelectionControlMask selection_control) { +Ref Module::SelectionMerge(Ref merge_block, + spv::SelectionControlMask selection_control) { auto op{new Op(spv::Op::OpSelectionMerge)}; op->Add(merge_block); AddEnum(op, selection_control); return AddCode(op); } -Ref Module::Label() { - return AddCode(spv::Op::OpLabel, bound++); -} +Ref Module::Label() { return AddCode(spv::Op::OpLabel, bound++); } Ref Module::Branch(Ref target_label) { auto op{new Op(spv::Op::OpBranch)}; @@ -37,20 +37,19 @@ Ref Module::Branch(Ref target_label) { } Ref Module::BranchConditional(Ref condition, Ref true_label, Ref false_label, - std::uint32_t true_weight, std::uint32_t false_weight) { + std::uint32_t true_weight, + std::uint32_t false_weight) { auto op{new Op(spv::Op::OpBranchConditional)}; op->Add(condition); op->Add(true_label); op->Add(false_label); if (true_weight != 0 || false_weight != 0) { - op->Add(Literal(true_weight)); - op->Add(Literal(false_weight)); + op->Add(true_weight); + op->Add(false_weight); } return AddCode(op); } -Ref Module::Return() { - return AddCode(spv::Op::OpReturn); -} +Ref Module::Return() { return AddCode(spv::Op::OpReturn); } } // namespace Sirit diff --git a/src/insts/function.cpp b/src/insts/function.cpp index 9b8ee0f..efcc2c6 100644 --- a/src/insts/function.cpp +++ b/src/insts/function.cpp @@ -4,20 +4,19 @@ * Lesser General Public License version 2.1 or any later version. */ -#include "sirit/sirit.h" #include "insts.h" +#include "sirit/sirit.h" namespace Sirit { -Ref Module::Function(Ref result_type, spv::FunctionControlMask function_control, Ref function_type) { +Ref Module::Function(Ref result_type, spv::FunctionControlMask function_control, + Ref function_type) { auto op{new Op{spv::Op::OpFunction, bound++, result_type}}; op->Add(static_cast(function_control)); op->Add(function_type); return AddCode(op); } -Ref Module::FunctionEnd() { - return AddCode(spv::Op::OpFunctionEnd); -} +Ref Module::FunctionEnd() { return AddCode(spv::Op::OpFunctionEnd); } } // namespace Sirit diff --git a/src/insts/type.cpp b/src/insts/type.cpp index 4a2e1a5..2587ff4 100644 --- a/src/insts/type.cpp +++ b/src/insts/type.cpp @@ -7,8 +7,8 @@ #include #include -#include "sirit/sirit.h" #include "insts.h" +#include "sirit/sirit.h" namespace Sirit { @@ -62,68 +62,68 @@ Ref Module::TypeMatrix(Ref column_type, int column_count) { return AddDeclaration(op); } -Ref Module::TypeImage(Ref sampled_type, spv::Dim dim, int depth, bool arrayed, bool ms, - int sampled, spv::ImageFormat image_format, +Ref Module::TypeImage(Ref sampled_type, spv::Dim dim, int depth, bool arrayed, + bool ms, int sampled, spv::ImageFormat image_format, std::optional access_qualifier) { switch (dim) { - case spv::Dim::Dim1D: - AddCapability(spv::Capability::Sampled1D); - break; - case spv::Dim::Cube: - AddCapability(spv::Capability::Shader); - break; - case spv::Dim::Rect: - AddCapability(spv::Capability::SampledRect); - break; - case spv::Dim::Buffer: - AddCapability(spv::Capability::SampledBuffer); - break; - case spv::Dim::SubpassData: - AddCapability(spv::Capability::InputAttachment); - break; + case spv::Dim::Dim1D: + AddCapability(spv::Capability::Sampled1D); + break; + case spv::Dim::Cube: + AddCapability(spv::Capability::Shader); + break; + case spv::Dim::Rect: + AddCapability(spv::Capability::SampledRect); + break; + case spv::Dim::Buffer: + AddCapability(spv::Capability::SampledBuffer); + break; + case spv::Dim::SubpassData: + AddCapability(spv::Capability::InputAttachment); + break; } switch (image_format) { - case spv::ImageFormat::Rgba32f: - case spv::ImageFormat::Rgba16f: - case spv::ImageFormat::R32f: - case spv::ImageFormat::Rgba8: - case spv::ImageFormat::Rgba8Snorm: - case spv::ImageFormat::Rgba32i: - case spv::ImageFormat::Rgba16i: - case spv::ImageFormat::Rgba8i: - case spv::ImageFormat::R32i: - case spv::ImageFormat::Rgba32ui: - case spv::ImageFormat::Rgba16ui: - case spv::ImageFormat::Rgba8ui: - case spv::ImageFormat::R32ui: - AddCapability(spv::Capability::Shader); - break; - case spv::ImageFormat::Rg32f: - case spv::ImageFormat::Rg16f: - case spv::ImageFormat::R11fG11fB10f: - case spv::ImageFormat::R16f: - case spv::ImageFormat::Rgba16: - case spv::ImageFormat::Rgb10A2: - case spv::ImageFormat::Rg16: - case spv::ImageFormat::Rg8: - case spv::ImageFormat::R16: - case spv::ImageFormat::R8: - case spv::ImageFormat::Rgba16Snorm: - case spv::ImageFormat::Rg16Snorm: - case spv::ImageFormat::Rg8Snorm: - case spv::ImageFormat::Rg32i: - case spv::ImageFormat::Rg16i: - case spv::ImageFormat::Rg8i: - case spv::ImageFormat::R16i: - case spv::ImageFormat::R8i: - case spv::ImageFormat::Rgb10a2ui: - case spv::ImageFormat::Rg32ui: - case spv::ImageFormat::Rg16ui: - case spv::ImageFormat::Rg8ui: - case spv::ImageFormat::R16ui: - case spv::ImageFormat::R8ui: - AddCapability(spv::Capability::StorageImageExtendedFormats); - break; + case spv::ImageFormat::Rgba32f: + case spv::ImageFormat::Rgba16f: + case spv::ImageFormat::R32f: + case spv::ImageFormat::Rgba8: + case spv::ImageFormat::Rgba8Snorm: + case spv::ImageFormat::Rgba32i: + case spv::ImageFormat::Rgba16i: + case spv::ImageFormat::Rgba8i: + case spv::ImageFormat::R32i: + case spv::ImageFormat::Rgba32ui: + case spv::ImageFormat::Rgba16ui: + case spv::ImageFormat::Rgba8ui: + case spv::ImageFormat::R32ui: + AddCapability(spv::Capability::Shader); + break; + case spv::ImageFormat::Rg32f: + case spv::ImageFormat::Rg16f: + case spv::ImageFormat::R11fG11fB10f: + case spv::ImageFormat::R16f: + case spv::ImageFormat::Rgba16: + case spv::ImageFormat::Rgb10A2: + case spv::ImageFormat::Rg16: + case spv::ImageFormat::Rg8: + case spv::ImageFormat::R16: + case spv::ImageFormat::R8: + case spv::ImageFormat::Rgba16Snorm: + case spv::ImageFormat::Rg16Snorm: + case spv::ImageFormat::Rg8Snorm: + case spv::ImageFormat::Rg32i: + case spv::ImageFormat::Rg16i: + case spv::ImageFormat::Rg8i: + case spv::ImageFormat::R16i: + case spv::ImageFormat::R8i: + case spv::ImageFormat::Rgb10a2ui: + case spv::ImageFormat::Rg32ui: + case spv::ImageFormat::Rg16ui: + case spv::ImageFormat::Rg8ui: + case spv::ImageFormat::R16ui: + case spv::ImageFormat::R8ui: + AddCapability(spv::Capability::StorageImageExtendedFormats); + break; } auto op{new Op(spv::Op::OpTypeImage, bound)}; op->Add(sampled_type); @@ -179,19 +179,19 @@ Ref Module::TypeOpaque(const std::string& name) { Ref Module::TypePointer(spv::StorageClass storage_class, Ref type) { switch (storage_class) { - case spv::StorageClass::Uniform: - case spv::StorageClass::Output: - case spv::StorageClass::Private: - case spv::StorageClass::PushConstant: - case spv::StorageClass::StorageBuffer: - AddCapability(spv::Capability::Shader); - break; - case spv::StorageClass::Generic: - AddCapability(spv::Capability::GenericPointer); - break; - case spv::StorageClass::AtomicCounter: - AddCapability(spv::Capability::AtomicStorage); - break; + case spv::StorageClass::Uniform: + case spv::StorageClass::Output: + case spv::StorageClass::Private: + case spv::StorageClass::PushConstant: + case spv::StorageClass::StorageBuffer: + AddCapability(spv::Capability::Shader); + break; + case spv::StorageClass::Generic: + AddCapability(spv::Capability::GenericPointer); + break; + case spv::StorageClass::AtomicCounter: + AddCapability(spv::Capability::AtomicStorage); + break; } auto op{new Op(spv::Op::OpTypePointer, bound)}; op->Add(static_cast(storage_class)); diff --git a/src/literal.cpp b/src/literal.cpp deleted file mode 100644 index ba8738a..0000000 --- a/src/literal.cpp +++ /dev/null @@ -1,26 +0,0 @@ -/* This file is part of the sirit project. - * Copyright (c) 2018 ReinUsesLisp - * This software may be used and distributed according to the terms of the GNU - * Lesser General Public License version 2.1 or any later version. - */ - -#include "common_types.h" -#include "literal-number.h" -#include "operand.h" -#include "sirit/sirit.h" - -namespace Sirit { - -#define DEFINE_LITERAL(type) \ - Operand* Module::Literal(type value) { \ - return LiteralNumber::Create(value); \ - } - -DEFINE_LITERAL(u32) -DEFINE_LITERAL(u64) -DEFINE_LITERAL(s32) -DEFINE_LITERAL(s64) -DEFINE_LITERAL(f32) -DEFINE_LITERAL(f64) - -} // namespace Sirit diff --git a/src/op.cpp b/src/op.cpp index c1d3e19..ea228ee 100644 --- a/src/op.cpp +++ b/src/op.cpp @@ -71,6 +71,34 @@ void Op::Sink(const std::vector& operands) { } } +void Op::Add(const Literal& literal) { + Operand* operand = [&]() { + switch (literal.index()) { + case 0: + return LiteralNumber::Create(std::get<0>(literal)); + case 1: + return LiteralNumber::Create(std::get<1>(literal)); + case 2: + return LiteralNumber::Create(std::get<2>(literal)); + case 3: + return LiteralNumber::Create(std::get<3>(literal)); + case 4: + return LiteralNumber::Create(std::get<4>(literal)); + case 5: + return LiteralNumber::Create(std::get<5>(literal)); + default: + assert(!"invalid literal type"); + } + }(); + Sink(operand); +} + +void Op::Add(const std::vector& literals) { + for (const auto& literal : literals) { + Add(literal); + } +} + void Op::Add(const Operand* operand) { operands.push_back(operand); } void Op::Add(u32 integer) { Sink(LiteralNumber::Create(integer)); } diff --git a/src/op.h b/src/op.h index 87e51c1..0e13cdd 100644 --- a/src/op.h +++ b/src/op.h @@ -31,6 +31,10 @@ class Op : public Operand { void Sink(const std::vector& operands); + void Add(const Literal& literal); + + void Add(const std::vector& literals); + void Add(const Operand* operand); void Add(u32 integer); diff --git a/src/sirit.cpp b/src/sirit.cpp index fee7325..9132958 100644 --- a/src/sirit.cpp +++ b/src/sirit.cpp @@ -20,8 +20,7 @@ static void WriteEnum(Stream& stream, spv::Op opcode, T value) { op.Write(stream); } -template -static void WriteSet(Stream& stream, const T& set) { +template static void WriteSet(Stream& stream, const T& set) { for (const auto& item : set) { item->Write(stream); }