diff --git a/backend/app/tests/api/endpoints/base_api_test.py b/backend/app/tests/api/endpoints/base_api_test.py index 322fdd31c6e5cfc1dd8fe98ee55c1da1dd1a7a3c..f97a9a9a76f622337fdd709aafdab7173b725981 100644 --- a/backend/app/tests/api/endpoints/base_api_test.py +++ b/backend/app/tests/api/endpoints/base_api_test.py @@ -42,3 +42,8 @@ class BaseApiTests(unittest.TestCase): @classmethod def tearDownClass(cls) -> None: clear_di() + + def _delete_entry(self, entry): + self.session.rollback() + self.session.delete(entry) + self.session.commit() diff --git a/backend/app/tests/api/endpoints/test_catalogs.py b/backend/app/tests/api/endpoints/test_catalogs.py index a1aa79f7145e3e699378a74074d5406adc90ac4f..51279c207b1c22ac9fbf174dc0b16647ffd45205 100644 --- a/backend/app/tests/api/endpoints/test_catalogs.py +++ b/backend/app/tests/api/endpoints/test_catalogs.py @@ -42,15 +42,17 @@ class TestCatalogs(BaseApiTests): cat_name = "created_catalog" json_input = {"name": cat_name} expected_data = {"name": cat_name, "doi": None} - # When - response = self.client.post("/api/catalogs/", json=json_input) - # Then - self.assertEqual(response.status_code, 200) - self.assertDictEqual(response.json(), expected_data) - # Finally delete item - cat = self.session.exec(select(Catalog).where(Catalog.name == cat_name)).one() - self.session.delete(cat) - self.session.commit() + try: + # When + response = self.client.post("/api/catalogs/", json=json_input) + # Then + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), expected_data) + finally: + cat = self.session.exec( + select(Catalog).where(Catalog.name == cat_name) + ).one() + self._delete_entry(cat) def test_create_existing_catalog(self): # Given @@ -95,11 +97,11 @@ class TestCatalogs(BaseApiTests): self.session.commit() json_input = {"name": name_to_update, "doi": "12345"} expected_data = json_input - # When - response = self.client.put(f"/api/catalogs/", json=json_input) - # Then - self.assertEqual(response.status_code, 200) - self.assertDictEqual(response.json(), expected_data) - # Finally - self.session.delete(catalog_to_update) - self.session.commit() + try: + # When + response = self.client.put(f"/api/catalogs/", json=json_input) + # Then + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), expected_data) + finally: + self._delete_entry(catalog_to_update) diff --git a/backend/app/tests/api/endpoints/test_keggs.py b/backend/app/tests/api/endpoints/test_keggs.py index de41fafd57a82418d79e38e6e1a8ced61161599b..4a9b18b1ffbfd1c7ecb63916678b9c53d0f1502f 100644 --- a/backend/app/tests/api/endpoints/test_keggs.py +++ b/backend/app/tests/api/endpoints/test_keggs.py @@ -43,15 +43,15 @@ class TestKeggs(BaseApiTests): kegg_name = "created_kegg" json_input = {"kegg_id": kegg_id, "name": kegg_name} expected_data = {"kegg_id": kegg_id, "name": kegg_name} - # When - response = self.client.post("/api/keggs/", json=json_input) - # Then - self.assertEqual(response.status_code, 200) - self.assertDictEqual(response.json(), expected_data) - # Finally delete item - kegg = self.session.exec(select(Kegg).where(Kegg.kegg_id == kegg_id)).one() - self.session.delete(kegg) - self.session.commit() + try: + # When + response = self.client.post("/api/keggs/", json=json_input) + # Then + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), expected_data) + finally: + kegg = self.session.exec(select(Kegg).where(Kegg.kegg_id == kegg_id)).one() + self._delete_entry(kegg) def test_create_existing_kegg(self): # Given @@ -100,11 +100,11 @@ class TestKeggs(BaseApiTests): self.session.commit() json_input = {"kegg_id": kegg_id, "name": new_name} expected_data = json_input - # When - response = self.client.put(f"/api/keggs/", json=json_input) - # Then - self.assertEqual(response.status_code, 200) - self.assertDictEqual(response.json(), expected_data) - # Finally - self.session.delete(kegg_to_update) - self.session.commit() + try: + # When + response = self.client.put(f"/api/keggs/", json=json_input) + # Then + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), expected_data) + finally: + self._delete_entry(kegg_to_update) diff --git a/backend/app/tests/api/endpoints/test_ncbi_taxonomy.py b/backend/app/tests/api/endpoints/test_ncbi_taxonomy.py index 8bc7356a2954b6ed02db8fb9f7db8a878fc41c37..a0e7377e12bccbee88b550da0ef80ef3d47694e6 100644 --- a/backend/app/tests/api/endpoints/test_ncbi_taxonomy.py +++ b/backend/app/tests/api/endpoints/test_ncbi_taxonomy.py @@ -58,17 +58,17 @@ class TestNcbiTaxonomys(BaseApiTests): "name": ncbi_taxonomy_name, "rank": rank_created, } - # When - response = self.client.post("/api/ncbi_taxonomy/", json=json_input) - # Then - self.assertEqual(response.status_code, 200) - self.assertDictEqual(response.json(), expected_data) - # Finally delete item - ncbi_tax = self.session.exec( - select(NcbiTaxonomy).where(NcbiTaxonomy.tax_id == tax_id) - ).one() - self.session.delete(ncbi_tax) - self.session.commit() + try: + # When + response = self.client.post("/api/ncbi_taxonomy/", json=json_input) + # Then + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), expected_data) + finally: + ncbi_tax = self.session.exec( + select(NcbiTaxonomy).where(NcbiTaxonomy.tax_id == tax_id) + ).one() + self._delete_entry(ncbi_tax) def test_create_existing_ncbi_taxonomy(self): # Given @@ -124,16 +124,16 @@ class TestNcbiTaxonomys(BaseApiTests): self.session.commit() json_input = {"tax_id": tax_id, "name": new_name} expected_data = {"tax_id": tax_id, "name": new_name, "rank": rank} - # When - response = self.client.put( - f"/api/ncbi_taxonomy/?exclude_none=true", json=json_input - ) - # Then - self.assertEqual(response.status_code, 200) - self.assertDictEqual(response.json(), expected_data) - # Finally - self.session.delete(ncbi_taxonomy_to_update) - self.session.commit() + try: + # When + response = self.client.put( + f"/api/ncbi_taxonomy/?exclude_none=true", json=json_input + ) + # Then + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), expected_data) + finally: + self._delete_entry(ncbi_taxonomy_to_update) def test_update_ncbi_taxonomy_no_rank_exclude_none_false(self): # Given @@ -154,8 +154,5 @@ class TestNcbiTaxonomys(BaseApiTests): f"/api/ncbi_taxonomy/?exclude_none=false", json=json_input ) self.assertEqual(response.status_code, 422) - # Finally finally: - self.session.rollback() - self.session.delete(ncbi_taxonomy_to_update) - self.session.commit() + self._delete_entry(ncbi_taxonomy_to_update)