diff --git a/include/sirit/sirit.h b/include/sirit/sirit.h index 887bf17..85aa71c 100644 --- a/include/sirit/sirit.h +++ b/include/sirit/sirit.h @@ -87,7 +87,22 @@ public: /// Returns type sampled image. const Op* TypeSampledImage(const Op* image_type); - /// Returns a function type. + /// Returns type array. + const Op* TypeArray(const Op* element_type, const Op* length); + + /// Returns type runtime array. + const Op* TypeRuntimeArray(const Op* element_type); + + /// Returns type struct. + const Op* TypeStruct(const std::vector& members = {}); + + /// Returns type opaque. + const Op* TypeOpaque(const std::string& name); + + /// Returns type pointer. + const Op* TypePointer(spv::StorageClass storage_class, const Op* type); + + /// Returns type function. const Op* TypeFunction(const Op* return_type, const std::vector& arguments = {}); // Function diff --git a/src/opcodes/type.cpp b/src/opcodes/type.cpp index 403213e..4588506 100644 --- a/src/opcodes/type.cpp +++ b/src/opcodes/type.cpp @@ -148,6 +148,55 @@ const Op* Module::TypeSampledImage(const Op* image_type) { return AddDeclaration(op); } +const Op* Module::TypeArray(const Op* element_type, const Op* length) { + Op* op{new Op(spv::Op::OpTypeArray, bound)}; + op->Add(element_type); + op->Add(length); + return AddDeclaration(op); +} + +const Op* Module::TypeRuntimeArray(const Op* element_type) { + AddCapability(spv::Capability::Shader); + Op* op{new Op(spv::Op::OpTypeRuntimeArray, bound)}; + op->Add(element_type); + return AddDeclaration(op); +} + +const Op* Module::TypeStruct(const std::vector& members) { + Op* op{new Op(spv::Op::OpTypeStruct, bound)}; + op->Add(members); + return AddDeclaration(op); +} + +const Op* Module::TypeOpaque(const std::string& name) { + AddCapability(spv::Capability::Kernel); + Op* op{new Op(spv::Op::OpTypeOpaque, bound)}; + op->Add(name); + return AddDeclaration(op); +} + +const Op* Module::TypePointer(spv::StorageClass storage_class, const Op* type) { + switch (storage_class) { + case spv::StorageClass::Uniform: + case spv::StorageClass::Output: + case spv::StorageClass::Private: + case spv::StorageClass::PushConstant: + case spv::StorageClass::StorageBuffer: + AddCapability(spv::Capability::Shader); + break; + case spv::StorageClass::Generic: + AddCapability(spv::Capability::GenericPointer); + break; + case spv::StorageClass::AtomicCounter: + AddCapability(spv::Capability::AtomicStorage); + break; + } + Op* op{new Op(spv::Op::OpTypePointer, bound)}; + op->Add(static_cast(storage_class)); + op->Add(type); + return AddDeclaration(op); +} + const Op* Module::TypeFunction(const Op* return_type, const std::vector& arguments) { Op* type_func{new Op(spv::Op::OpTypeFunction, bound)}; type_func->Add(return_type); diff --git a/tests/main.cpp b/tests/main.cpp index e0a7de0..b801239 100644 --- a/tests/main.cpp +++ b/tests/main.cpp @@ -28,13 +28,16 @@ public: TypeFloat(64); TypeVector(TypeBool(), 4); TypeVector(TypeBool(), 3); - TypeVector(TypeVector(TypeFloat(32), 4), 3); - TypeVector(TypeVector(TypeFloat(32), 4), 3); TypeMatrix(TypeVector(TypeFloat(32), 4), 4); TypeImage(TypeFloat(32), spv::Dim::Dim2D, 0, false, false, 0, - spv::ImageFormat::Rg32f, spv::AccessQualifier::ReadOnly); + spv::ImageFormat::Rg32f); TypeSampledImage(TypeImage(TypeFloat(32), spv::Dim::Rect, 0, false, false, 0, - spv::ImageFormat::Rg32f, spv::AccessQualifier::ReadOnly)); + spv::ImageFormat::Rg32f)); + TypeVector(TypeInt(32, true), 4); + TypeVector(TypeInt(64, true), 4); + TypeRuntimeArray(TypeInt(32, true)); + TypeStruct({TypeInt(32, true), TypeFloat(64)}); + TypePointer(spv::StorageClass::Private, TypeFloat(16)); auto main_type{TypeFunction(TypeVoid())}; auto main_func{Emit(Function(TypeVoid(), spv::FunctionControlMask::MaskNone, main_type))};