diff --git a/system/vulkan_enc/AndroidHardwareBuffer.cpp b/system/vulkan_enc/AndroidHardwareBuffer.cpp
index 05f14f1f3a3f9246592492a7dd94cd8c73c7f0d1..84086e1c2d0ac49766174dc7ffe32738e0af1392 100644
--- a/system/vulkan_enc/AndroidHardwareBuffer.cpp
+++ b/system/vulkan_enc/AndroidHardwareBuffer.cpp
@@ -59,8 +59,8 @@ VkResult getAndroidHardwareBufferPropertiesANDROID(
     VkAndroidHardwareBufferPropertiesANDROID* pProperties) {
 
     VkAndroidHardwareBufferFormatPropertiesANDROID* ahbFormatProps =
-        (VkAndroidHardwareBufferFormatPropertiesANDROID*)vk_find_struct(
-            (vk_struct_common*)pProperties->pNext,
+        vk_find_struct<VkAndroidHardwareBufferFormatPropertiesANDROID>(
+            pProperties,
             VK_STRUCTURE_TYPE_ANDROID_HARDWARE_BUFFER_FORMAT_PROPERTIES_ANDROID);
 
     if (ahbFormatProps) {
diff --git a/system/vulkan_enc/ResourceTracker.cpp b/system/vulkan_enc/ResourceTracker.cpp
index 631f24d1a1cceeed1420f6e1adf620e5cb88b9cc..af532b99d42950453afe5e6f80ddaba918d87e31 100644
--- a/system/vulkan_enc/ResourceTracker.cpp
+++ b/system/vulkan_enc/ResourceTracker.cpp
@@ -1550,24 +1550,24 @@ public:
             (vk_struct_common*)(&finalAllocInfo));
         structChain->pNext = nullptr;
 
-        VkExportMemoryAllocateInfo* exportAllocateInfoPtr =
-            (VkExportMemoryAllocateInfo*)vk_find_struct((vk_struct_common*)pAllocateInfo,
+        const VkExportMemoryAllocateInfo* exportAllocateInfoPtr =
+            vk_find_struct<VkExportMemoryAllocateInfo>(pAllocateInfo,
                 VK_STRUCTURE_TYPE_EXPORT_MEMORY_ALLOCATE_INFO);
 
-        VkImportAndroidHardwareBufferInfoANDROID* importAhbInfoPtr =
-            (VkImportAndroidHardwareBufferInfoANDROID*)vk_find_struct((vk_struct_common*)pAllocateInfo,
+        const VkImportAndroidHardwareBufferInfoANDROID* importAhbInfoPtr =
+            vk_find_struct<VkImportAndroidHardwareBufferInfoANDROID>(pAllocateInfo,
                 VK_STRUCTURE_TYPE_IMPORT_ANDROID_HARDWARE_BUFFER_INFO_ANDROID);
 
-        VkImportMemoryBufferCollectionFUCHSIA* importBufferCollectionInfoPtr =
-            (VkImportMemoryBufferCollectionFUCHSIA*)vk_find_struct((vk_struct_common*)pAllocateInfo,
+        const VkImportMemoryBufferCollectionFUCHSIA* importBufferCollectionInfoPtr =
+            vk_find_struct<VkImportMemoryBufferCollectionFUCHSIA>(pAllocateInfo,
                 VK_STRUCTURE_TYPE_IMPORT_MEMORY_BUFFER_COLLECTION_FUCHSIA);
 
-        VkImportMemoryZirconHandleInfoFUCHSIA* importVmoInfoPtr =
-            (VkImportMemoryZirconHandleInfoFUCHSIA*)vk_find_struct((vk_struct_common*)pAllocateInfo,
+        const VkImportMemoryZirconHandleInfoFUCHSIA* importVmoInfoPtr =
+            vk_find_struct<VkImportMemoryZirconHandleInfoFUCHSIA>(pAllocateInfo,
                 VK_STRUCTURE_TYPE_TEMP_IMPORT_MEMORY_ZIRCON_HANDLE_INFO_FUCHSIA);
 
-        VkMemoryDedicatedAllocateInfo* dedicatedAllocInfoPtr =
-            (VkMemoryDedicatedAllocateInfo*)vk_find_struct((vk_struct_common*)pAllocateInfo,
+        const VkMemoryDedicatedAllocateInfo* dedicatedAllocInfoPtr =
+            vk_find_struct<VkMemoryDedicatedAllocateInfo>(pAllocateInfo,
                 VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO);
 
         bool shouldPassThroughDedicatedAllocInfo =
@@ -2125,10 +2125,8 @@ public:
         transformExternalResourceMemoryRequirementsForGuest(&reqs2->memoryRequirements);
 
         VkMemoryDedicatedRequirements* dedicatedReqs =
-            (VkMemoryDedicatedRequirements*)
-            vk_find_struct(
-                (vk_struct_common*)reqs2,
-                VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS);
+            vk_find_struct<VkMemoryDedicatedRequirements>(
+                reqs2, VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS);
 
         if (!dedicatedReqs) return;
 
@@ -2157,10 +2155,8 @@ public:
         transformExternalResourceMemoryRequirementsForGuest(&reqs2->memoryRequirements);
 
         VkMemoryDedicatedRequirements* dedicatedReqs =
-            (VkMemoryDedicatedRequirements*)
-            vk_find_struct(
-                (vk_struct_common*)reqs2,
-                VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS);
+            vk_find_struct<VkMemoryDedicatedRequirements>(
+                reqs2, VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS);
 
         if (!dedicatedReqs) return;
 
@@ -2181,20 +2177,18 @@ public:
 
         VkImageCreateInfo* pCreateInfo_mut = &localCreateInfo;
 
-        VkNativeBufferANDROID* anbInfoPtr =
-            (VkNativeBufferANDROID*)
-            vk_find_struct(
-                (vk_struct_common*)pCreateInfo_mut,
+        const VkNativeBufferANDROID* anbInfoPtr =
+            vk_find_struct<VkNativeBufferANDROID>(
+                pCreateInfo,
                 VK_STRUCTURE_TYPE_NATIVE_BUFFER_ANDROID);
 
         if (anbInfoPtr) {
             localAnb = *anbInfoPtr;
         }
 
-        VkExternalMemoryImageCreateInfo* extImgCiPtr =
-            (VkExternalMemoryImageCreateInfo*)
-            vk_find_struct(
-                (vk_struct_common*)pCreateInfo_mut,
+        const VkExternalMemoryImageCreateInfo* extImgCiPtr =
+            vk_find_struct<VkExternalMemoryImageCreateInfo>(
+                pCreateInfo,
                 VK_STRUCTURE_TYPE_EXTERNAL_MEMORY_IMAGE_CREATE_INFO);
 
         if (extImgCiPtr) {
@@ -2203,22 +2197,20 @@ public:
 
 #ifdef VK_USE_PLATFORM_ANDROID_KHR
         VkExternalFormatANDROID localExtFormatAndroid;
-        VkExternalFormatANDROID* extFormatAndroidPtr =
-        (VkExternalFormatANDROID*)
-        vk_find_struct(
-            (vk_struct_common*)pCreateInfo_mut,
-            VK_STRUCTURE_TYPE_EXTERNAL_FORMAT_ANDROID);
+        const VkExternalFormatANDROID* extFormatAndroidPtr =
+            vk_find_struct<VkExternalFormatANDROID>(
+                pCreateInfo,
+                VK_STRUCTURE_TYPE_EXTERNAL_FORMAT_ANDROID);
         if (extFormatAndroidPtr) {
             localExtFormatAndroid = *extFormatAndroidPtr;
         }
 #endif
 
 #ifdef VK_USE_PLATFORM_FUCHSIA
-        VkBufferCollectionImageCreateInfoFUCHSIA* extBufferCollectionPtr =
-        (VkBufferCollectionImageCreateInfoFUCHSIA*)
-        vk_find_struct(
-            (vk_struct_common*)pCreateInfo_mut,
-            VK_STRUCTURE_TYPE_BUFFER_COLLECTION_IMAGE_CREATE_INFO_FUCHSIA);
+        const VkBufferCollectionImageCreateInfoFUCHSIA* extBufferCollectionPtr =
+            vk_find_struct<VkBufferCollectionImageCreateInfoFUCHSIA>(
+                pCreateInfo,
+                VK_STRUCTURE_TYPE_BUFFER_COLLECTION_IMAGE_CREATE_INFO_FUCHSIA);
 #endif
 
         vk_struct_common* structChain =
@@ -2329,11 +2321,10 @@ public:
 
 #ifdef VK_USE_PLATFORM_ANDROID_KHR
         VkExternalFormatANDROID localExtFormatAndroid;
-        VkExternalFormatANDROID* extFormatAndroidPtr =
-        (VkExternalFormatANDROID*)
-        vk_find_struct(
-            (vk_struct_common*)pCreateInfo_mut,
-            VK_STRUCTURE_TYPE_EXTERNAL_FORMAT_ANDROID);
+        const VkExternalFormatANDROID* extFormatAndroidPtr =
+            vk_find_struct<VkExternalFormatANDROID>(
+                pCreateInfo,
+                VK_STRUCTURE_TYPE_EXTERNAL_FORMAT_ANDROID);
         if (extFormatAndroidPtr) {
             localExtFormatAndroid = *extFormatAndroidPtr;
         }
@@ -2368,11 +2359,10 @@ public:
 
 #ifdef VK_USE_PLATFORM_ANDROID_KHR
         VkExternalFormatANDROID localExtFormatAndroid;
-        VkExternalFormatANDROID* extFormatAndroidPtr =
-        (VkExternalFormatANDROID*)
-        vk_find_struct(
-            (vk_struct_common*)pCreateInfo_mut,
-            VK_STRUCTURE_TYPE_EXTERNAL_FORMAT_ANDROID);
+        const VkExternalFormatANDROID* extFormatAndroidPtr =
+            vk_find_struct<VkExternalFormatANDROID>(
+                pCreateInfo,
+                VK_STRUCTURE_TYPE_EXTERNAL_FORMAT_ANDROID);
         if (extFormatAndroidPtr) {
             localExtFormatAndroid = *extFormatAndroidPtr;
         }
@@ -2473,8 +2463,8 @@ public:
         info.createInfo = *pCreateInfo;
         info.createInfo.pNext = nullptr;
 
-        VkExternalMemoryBufferCreateInfo* extBufCi =
-            (VkExternalMemoryBufferCreateInfo*)vk_find_struct((vk_struct_common*)pCreateInfo,
+        const VkExternalMemoryBufferCreateInfo* extBufCi =
+            vk_find_struct<VkExternalMemoryBufferCreateInfo>(pCreateInfo,
                 VK_STRUCTURE_TYPE_EXTERNAL_MEMORY_BUFFER_CREATE_INFO);
 
         if (!extBufCi) return res;
@@ -2565,9 +2555,9 @@ public:
 
         VkSemaphoreCreateInfo finalCreateInfo = *pCreateInfo;
 
-        VkExportSemaphoreCreateInfoKHR* exportSemaphoreInfoPtr =
-            (VkExportSemaphoreCreateInfoKHR*)vk_find_struct(
-                (vk_struct_common*)pCreateInfo,
+        const VkExportSemaphoreCreateInfoKHR* exportSemaphoreInfoPtr =
+            vk_find_struct<VkExportSemaphoreCreateInfoKHR>(
+                pCreateInfo,
                 VK_STRUCTURE_TYPE_EXPORT_SEMAPHORE_CREATE_INFO_KHR);
 
 #ifdef VK_USE_PLATFORM_FUCHSIA
@@ -3137,8 +3127,8 @@ public:
         (void)input_result;
 
         VkAndroidHardwareBufferUsageANDROID* output_ahw_usage =
-            (VkAndroidHardwareBufferUsageANDROID*)vk_find_struct(
-                (vk_struct_common*)pImageFormatProperties,
+            vk_find_struct<VkAndroidHardwareBufferUsageANDROID>(
+                pImageFormatProperties,
                 VK_STRUCTURE_TYPE_ANDROID_HARDWARE_BUFFER_USAGE_ANDROID);
 
         VkResult hostRes;
diff --git a/system/vulkan_enc/vk_util.h b/system/vulkan_enc/vk_util.h
index da394f4ebfc4271b2be790c1c7d1cb89fc05e2c6..3638043b487d6ecb92e18c2d7fd45cd5777af4f6 100644
--- a/system/vulkan_enc/vk_util.h
+++ b/system/vulkan_enc/vk_util.h
@@ -192,11 +192,24 @@ __vk_find_struct(void *start, VkStructureType sType)
    return NULL;
 }
 
-#define vk_find_struct(__start, __sType) \
-   __vk_find_struct((__start), __sType)
+template <class T> void vk_is_vk_struct(T *s)
+{
+    static_assert(sizeof(s->sType) == sizeof(VkStructureType));
+    static_assert(sizeof(s->pNext) == sizeof(void*));
+}
 
-#define vk_find_struct_const(__start, __sType) \
-   (const void *)__vk_find_struct((void *)(__start), __sType)
+template <class T, class H> T* vk_find_struct(H* head, VkStructureType sType)
+{
+    vk_is_vk_struct(head);
+    return static_cast<T*>(__vk_find_struct(static_cast<void*>(head), sType));
+}
+
+template <class T, class H> const T* vk_find_struct(const H* head, VkStructureType sType)
+{
+    vk_is_vk_struct(head);
+    return static_cast<const T*>(__vk_find_struct(const_cast<void*>(static_cast<const void*>(head)),
+                                 sType));
+}
 
 uint32_t vk_get_driver_version(void);