diff --git a/tests/policy.py b/tests/policy.py
index 15a537ec39bcb00b08fd2925965401b25c6ca162..b8a3621298ed8e11fdae32e7f0def3f5d70b8448 100644
--- a/tests/policy.py
+++ b/tests/policy.py
@@ -124,8 +124,8 @@ class Policy:
             yield Rule
 
 
-    def GetAllTypes(self):
-        TypeIterP = self.__libsepolwrap.init_type_iter(self.__policydbP, None, False)
+    def GetAllTypes(self, isAttr):
+        TypeIterP = self.__libsepolwrap.init_type_iter(self.__policydbP, None, isAttr)
         if (TypeIterP == None):
             sys.exit("Failed to initialize type iterator")
         buf = create_string_buffer(self.__BUFSIZE)
diff --git a/tests/sepol_wrap.cpp b/tests/sepol_wrap.cpp
index cd5336795e49315aed9ad690060a634af75e75da..8fea2d5b458c65fcaac64b07dfb431c60d22e940 100644
--- a/tests/sepol_wrap.cpp
+++ b/tests/sepol_wrap.cpp
@@ -98,13 +98,15 @@ int get_type(char *out, size_t max_size, void *policydbp, void *type_iterp)
             break;
         }
     }
-    if (i->bit >= i->length)
-        return 1;
-    while ((i->alltypes == TYPE_ITER_ALLATTRS
+    while (i->bit < i->length &&
+           ((i->alltypes == TYPE_ITER_ALLATTRS
             && db->type_val_to_struct[i->bit]->flavor != TYPE_ATTRIB)
             || (i->alltypes == TYPE_ITER_ALLTYPES
-            && db->type_val_to_struct[i->bit]->flavor != TYPE_TYPE))
+            && db->type_val_to_struct[i->bit]->flavor != TYPE_TYPE))) {
         i->bit++;
+    }
+    if (i->bit >= i->length)
+        return 1;
     len = snprintf(out, max_size, "%s", db->p_type_val_to_name[i->bit]);
     if (len >= max_size) {
         std::cerr << "type name exceeds buffer size." << std::endl;
diff --git a/tests/treble_sepolicy_tests.py b/tests/treble_sepolicy_tests.py
index c48066d1b27b7fb817afd74f96a69f6afec8713d..58fd85bc3e2b5e9c1f1eec91e5fbb148ccdcd74f 100644
--- a/tests/treble_sepolicy_tests.py
+++ b/tests/treble_sepolicy_tests.py
@@ -155,8 +155,8 @@ def GetAttributes(pol):
 def GetAllTypes(pol, oldpol):
     global alltypes
     global oldalltypes
-    alltypes = pol.GetAllTypes()
-    oldalltypes = oldpol.GetAllTypes()
+    alltypes = pol.GetAllTypes(False)
+    oldalltypes = oldpol.GetAllTypes(False)
 
 def setup(pol):
     GetAllDomains(pol)