From 64a6e409eff570c3dfa732bd834a2d68f4443eb9 Mon Sep 17 00:00:00 2001 From: Stefan Garlonta Date: Fri, 6 Jun 2025 19:53:33 +0200 Subject: [PATCH] :bug: Refactor data loaders to be lazy and use generators to prevent memory problems --- pystreamapi/loaders/__csv/__csv_loader.py | 65 +++++--- pystreamapi/loaders/__json/__json_loader.py | 55 +++--- pystreamapi/loaders/__xml/__xml_loader.py | 54 +++--- pystreamapi/loaders/__yaml/__yaml_loader.py | 25 +-- tests/_loaders/test_csv_loader.py | 175 +++++++++++++------- tests/_loaders/test_json_loader.py | 27 +-- tests/_loaders/test_xml_loader.py | 120 +++++++++----- tests/_loaders/test_yaml_loader.py | 26 ++- 8 files changed, 340 insertions(+), 207 deletions(-) diff --git a/pystreamapi/loaders/__csv/__csv_loader.py b/pystreamapi/loaders/__csv/__csv_loader.py index f009ad4..85d8b6e 100644 --- a/pystreamapi/loaders/__csv/__csv_loader.py +++ b/pystreamapi/loaders/__csv/__csv_loader.py @@ -1,41 +1,62 @@ from collections import namedtuple from csv import reader +from io import StringIO +from typing import Any, Iterator from pystreamapi.loaders.__loader_utils import LoaderUtils -from pystreamapi.loaders.__lazy_file_iterable import LazyFileIterable -def csv(file_path: str, cast_types=True, delimiter=',', encoding="utf-8") -> LazyFileIterable: +def csv( + src: str, read_from_src=False, cast_types=True, delimiter=',', encoding="utf-8" +) -> Iterator[Any]: """ - Loads a CSV file and converts it into a list of namedtuples. - - Returns: - list: A list of namedtuples, where each namedtuple represents a row in the CSV. - :param cast_types: Set as False to disable casting of values to int, bool or float. - :param encoding: The encoding of the CSV file. - :param file_path: The path to the CSV file. - :param delimiter: The delimiter used in the CSV file. + Lazily loads CSV data from either a path or a string and yields namedtuples. + + Args: + src (str): Either the path to a CSV file or a CSV string. + read_from_src (bool): If True, src is treated as a CSV string. + If False, src is treated as a path to a CSV file. + cast_types (bool): Set as False to disable casting of values to int, bool or float. + delimiter (str): The delimiter used in the CSV data. + encoding (str): The encoding of the CSV file (only used when reading from file). + + Yields: + namedtuple: Each row in the CSV as a namedtuple. 
""" - file_path = LoaderUtils.validate_path(file_path) - return LazyFileIterable(lambda: __load_csv(file_path, cast_types, delimiter, encoding)) + if not read_from_src: + src = LoaderUtils.validate_path(src) + return __load_csv_from_file(src, cast_types, delimiter, encoding) + return __load_csv_from_string(src, cast_types, delimiter) -def __load_csv(file_path, cast, delimiter, encoding): - """Load a CSV file and convert it into a list of namedtuples""" +def __load_csv_from_file(file_path, cast, delimiter, encoding): + """Load a CSV file and convert it into a generator of namedtuples""" # skipcq: PTC-W6004 with open(file_path, mode='r', newline='', encoding=encoding) as csvfile: - csvreader = reader(csvfile, delimiter=delimiter) + yield from __process_csv(csvfile, cast, delimiter) + + +def __load_csv_from_string(csv_string, cast, delimiter): + """Load a CSV from string and convert it into a generator of namedtuples""" + with StringIO(csv_string) as csvfile: + yield from __process_csv(csvfile, cast, delimiter) + - # Create a namedtuple type, casting the header values to int or float if possible - header = __get_csv_header(csvreader) +def __process_csv(csvfile, cast, delimiter): + """Process CSV data and yield namedtuples""" + csvreader = reader(csvfile, delimiter=delimiter) - Row = namedtuple('Row', list(header)) + # Create a namedtuple type, casting the header values to int or float if possible + header = __get_csv_header(csvreader) + if not header: + return - mapper = LoaderUtils.try_cast if cast else lambda x: x + Row = namedtuple('Row', list(header)) + mapper = LoaderUtils.try_cast if cast else lambda x: x - # Process the data, casting values to int or float if possible - data = [Row(*[mapper(value) for value in row]) for row in csvreader] - return data + # Yield the data row by row, casting values to int or float if possible + for row in csvreader: + yield Row(*[mapper(value) for value in row]) def __get_csv_header(csvreader): diff --git a/pystreamapi/loaders/__json/__json_loader.py b/pystreamapi/loaders/__json/__json_loader.py index cf416ef..3a743b3 100644 --- a/pystreamapi/loaders/__json/__json_loader.py +++ b/pystreamapi/loaders/__json/__json_loader.py @@ -1,40 +1,51 @@ import json as jsonlib from collections import namedtuple +from typing import Any, Iterator -from pystreamapi.loaders.__lazy_file_iterable import LazyFileIterable from pystreamapi.loaders.__loader_utils import LoaderUtils -def json(src: str, read_from_src=False) -> LazyFileIterable: +def json(src: str, read_from_src=False) -> Iterator[Any]: """ - Loads JSON data from either a path or a string and converts it into a list of namedtuples. + Lazily loads JSON data from either a path or a string and yields namedtuples. - Returns: - list: A list of namedtuples, where each namedtuple represents an object in the JSON. - :param src: Either the path to a JSON file or a JSON string. - :param read_from_src: If True, src is treated as a JSON string. If False, src is treated as - a path to a JSON file. + Args: + src (str): Either the path to a JSON file or a JSON string. + read_from_src (bool): If True, src is treated as a JSON string. + If False, src is treated as a path to a JSON file. + + Yields: + namedtuple: Each object in the JSON as a namedtuple. 
""" if read_from_src: - return LazyFileIterable(lambda: __load_json_string(src)) + return __lazy_load_json_string(src) path = LoaderUtils.validate_path(src) - return LazyFileIterable(lambda: __load_json_file(path)) + return __lazy_load_json_file(path) + + +def __lazy_load_json_file(file_path: str) -> Iterator[Any]: + """Lazily read and parse a JSON file, yielding namedtuples.""" + + def generator(): + # skipcq: PTC-W6004 + with open(file_path, mode='r', encoding='utf-8') as jsonfile: + src = jsonfile.read() + if src == '': + return + yield from jsonlib.loads(src, object_hook=__dict_to_namedtuple) + + return generator() -def __load_json_file(file_path): - """Load a JSON file and convert it into a list of namedtuples""" - # skipcq: PTC-W6004 - with open(file_path, mode='r', encoding='utf-8') as jsonfile: - src = jsonfile.read() - if src == '': - return [] - data = jsonlib.loads(src, object_hook=__dict_to_namedtuple) - return data +def __lazy_load_json_string(json_string: str) -> Iterator[Any]: + """Lazily parse a JSON string, yielding namedtuples.""" + def generator(): + if not json_string.strip(): + return + yield from jsonlib.loads(json_string, object_hook=__dict_to_namedtuple) -def __load_json_string(json_string): - """Load JSON data from a string and convert it into a list of namedtuples""" - return jsonlib.loads(json_string, object_hook=__dict_to_namedtuple) + return generator() def __dict_to_namedtuple(d, name='Item'): diff --git a/pystreamapi/loaders/__xml/__xml_loader.py b/pystreamapi/loaders/__xml/__xml_loader.py index 98b551e..f617677 100644 --- a/pystreamapi/loaders/__xml/__xml_loader.py +++ b/pystreamapi/loaders/__xml/__xml_loader.py @@ -1,3 +1,5 @@ +from typing import Iterator, Any + try: from defusedxml import ElementTree except ImportError as exc: @@ -5,7 +7,6 @@ "Please install the xml_loader extra dependency to use the xml loader." ) from exc from collections import namedtuple -from pystreamapi.loaders.__lazy_file_iterable import LazyFileIterable from pystreamapi.loaders.__loader_utils import LoaderUtils @@ -21,14 +22,14 @@ def __init__(self): def xml(src: str, read_from_src=False, retrieve_children=True, cast_types=True, - encoding="utf-8") -> LazyFileIterable: + encoding="utf-8") -> Iterator[Any]: """ Loads XML data from either a path or a string and converts it into a list of namedtuples. Warning: This method isn't safe against malicious XML trees. Parse only safe XML from sources you trust. Returns: - LazyFileIterable: A list of namedtuples, where each namedtuple represents an XML element. + An iterator with namedtuples, where each namedtuple represents an XML element. :param retrieve_children: If true, the children of the root element are used as stream elements. :param encoding: The encoding of the XML file. 
@@ -39,32 +40,37 @@ def xml(src: str, read_from_src=False, retrieve_children=True, cast_types=True, """ config.cast_types = cast_types config.retrieve_children = retrieve_children + if read_from_src: - return LazyFileIterable(lambda: __load_xml_string(src)) + return _lazy_parse_xml_string(src) + path = LoaderUtils.validate_path(src) - return LazyFileIterable(lambda: __load_xml_file(path, encoding)) + return _lazy_parse_xml_file(path, encoding) + + +def _lazy_parse_xml_file(file_path: str, encoding: str) -> Iterator[Any]: + def generator(): + with open(file_path, mode='r', encoding=encoding) as xmlfile: + xml_string = xmlfile.read() + yield from _parse_xml_string_lazy(xml_string) + return generator() -def __load_xml_file(file_path, encoding): - """Load an XML file and convert it into a list of namedtuples.""" - # skipcq: PTC-W6004 - with open(file_path, mode='r', encoding=encoding) as xmlfile: - src = xmlfile.read() - if src: - return __parse_xml_string(src) - return [] +def _lazy_parse_xml_string(xml_string: str) -> Iterator[Any]: + def generator(): + yield from _parse_xml_string_lazy(xml_string) -def __load_xml_string(xml_string): - """Load XML data from a string and convert it into a list of namedtuples.""" - return __parse_xml_string(xml_string) + return generator() -def __parse_xml_string(xml_string): - """Parse XML string and convert it into a list of namedtuples.""" +def _parse_xml_string_lazy(xml_string: str) -> Iterator[Any]: root = ElementTree.fromstring(xml_string) - parsed_xml = __parse_xml(root) - return __flatten(parsed_xml) if config.retrieve_children else [parsed_xml] + parsed = __parse_xml(root) + if config.retrieve_children: + yield from __flatten(parsed) + else: + yield parsed def __parse_xml(element): @@ -107,11 +113,9 @@ def __filter_single_items(tag_dict): def __flatten(data): - """Flatten a list of lists.""" - res = [] + """Yield flattened elements from a possibly nested structure.""" for item in data: if isinstance(item, list): - res.extend(item) + yield from item else: - res.append(item) - return res + yield item diff --git a/pystreamapi/loaders/__yaml/__yaml_loader.py b/pystreamapi/loaders/__yaml/__yaml_loader.py index 31dbc85..c305a56 100644 --- a/pystreamapi/loaders/__yaml/__yaml_loader.py +++ b/pystreamapi/loaders/__yaml/__yaml_loader.py @@ -1,3 +1,5 @@ +from typing import Any, Iterator + try: import yaml as yaml_lib except ImportError as exc: @@ -6,11 +8,10 @@ ) from exc from collections import namedtuple -from pystreamapi.loaders.__lazy_file_iterable import LazyFileIterable from pystreamapi.loaders.__loader_utils import LoaderUtils -def yaml(src: str, read_from_src=False) -> LazyFileIterable: +def yaml(src: str, read_from_src=False) -> Iterator[Any]: """ Loads YAML data from either a path or a string and converts it into a list of namedtuples. @@ -23,26 +24,26 @@ def yaml(src: str, read_from_src=False) -> LazyFileIterable: list: A list of namedtuples, where each namedtuple represents an object in the YAML. 
""" if read_from_src: - return LazyFileIterable(lambda: __load_yaml_string(src)) + return __load_yaml_string(src) path = LoaderUtils.validate_path(src) - return LazyFileIterable(lambda: __load_yaml_file(path)) + return __load_yaml_file(path) def __load_yaml_file(file_path): """Load a YAML file and convert it into a list of namedtuples""" # skipcq: PTC-W6004 - with open(file_path, mode='r', encoding='utf-8') as yamlfile: - src = yamlfile.read() - if src == '': - return [] - data = yaml_lib.safe_load(src) - return __convert_to_namedtuples(data) + with open(file_path, 'r', encoding='utf-8') as yamlfile: + # Supports both single and multiple documents + for document in yaml_lib.safe_load_all(yamlfile): + if document: + yield from __convert_to_namedtuples(document) def __load_yaml_string(yaml_string): """Load YAML data from a string and convert it into a list of namedtuples""" - data = yaml_lib.safe_load(yaml_string) - return [] if data is None else __convert_to_namedtuples(data) + for document in yaml_lib.safe_load_all(yaml_string): + if document: + yield from __convert_to_namedtuples(document) def __convert_to_namedtuples(data, name='Item'): diff --git a/tests/_loaders/test_csv_loader.py b/tests/_loaders/test_csv_loader.py index c8d6a71..161551c 100644 --- a/tests/_loaders/test_csv_loader.py +++ b/tests/_loaders/test_csv_loader.py @@ -1,73 +1,124 @@ # pylint: disable=not-context-manager +from contextlib import contextmanager from unittest import TestCase from unittest.mock import patch, mock_open from _loaders.file_test import OPEN, PATH_EXISTS, PATH_ISFILE from pystreamapi.loaders import csv -file_content = """ -attr1,attr2 + +class TestCSVLoader(TestCase): + """Test cases for the CSV loader functionality.""" + + def setUp(self): + self.file_content = """attr1,attr2 1,2.0 -a,b -""" -file_path = 'path/to/data.csv' +a,b""" + self.file_path = 'path/to/data.csv' + @contextmanager + def mock_csv_file(self, content=None, exists=True, is_file=True): + """Context manager for mocking CSV file operations. 
-class TestCSVLoader(TestCase): + Args: + content: The content of the mocked file + exists: Whether the file exists + is_file: Whether the path points to a file + """ + content = content if content is not None else self.file_content + with (patch(OPEN, mock_open(read_data=content)), + patch(PATH_EXISTS, return_value=exists), + patch(PATH_ISFILE, return_value=is_file)): + yield + + def test_csv_loader_basic_functionality(self): + """Test basic CSV loading with type casting.""" + with self.mock_csv_file(): + data = csv(self.file_path) + + # Test first row + first = next(data) + self.assertEqual(first.attr1, 1) + self.assertIsInstance(first.attr1, int) + self.assertEqual(first.attr2, 2.0) + self.assertIsInstance(first.attr2, float) + + # Test second row + second = next(data) + self.assertEqual(second.attr1, 'a') + self.assertIsInstance(second.attr1, str) + self.assertEqual(second.attr2, 'b') + self.assertIsInstance(second.attr2, str) + + # Verify end of file + self.assertRaises(StopIteration, next, data) + + def test_csv_loader_without_type_casting(self): + """Test CSV loading with type casting disabled.""" + with self.mock_csv_file(): + data = csv(self.file_path, cast_types=False) + + # Verify all values remain as strings + first = next(data) + self.assertIsInstance(first.attr1, str) + self.assertIsInstance(first.attr2, str) + self.assertEqual(first.attr1, '1') + self.assertEqual(first.attr2, '2.0') + + def test_csv_loader_iteration(self): + """Test CSV loader's iteration capability.""" + with self.mock_csv_file(): + data = csv(self.file_path) + self.assertEqual(len(list(data)), 2) + + def test_csv_loader_custom_delimiter(self): + """Test CSV loading with a custom delimiter.""" + content_with_semicolon = self.file_content.replace(",", ";") + with self.mock_csv_file(content=content_with_semicolon): + data = csv(self.file_path, delimiter=';') + first = next(data) + self.assertEqual(first.attr1, 1) + self.assertEqual(first.attr2, 2.0) + + def test_csv_loader_edge_cases(self): + """Test CSV loader with edge cases.""" + # Empty file + with self.mock_csv_file(content=""): + data = csv(self.file_path) + self.assertEqual(len(list(data)), 0) + + # Invalid file path + with self.mock_csv_file(exists=False): + with self.assertRaises(FileNotFoundError): + csv('path/to/invalid.csv') + + # Path is not a file + with self.mock_csv_file(is_file=False): + with self.assertRaises(ValueError): + csv('../') + + def test_csv_loader_from_string(self): + """Test CSV loading from a string.""" + data = csv(self.file_content, read_from_src=True) + + # Test first row + first = next(data) + self.assertEqual(first.attr1, 1) + self.assertIsInstance(first.attr1, int) + self.assertEqual(first.attr2, 2.0) + self.assertIsInstance(first.attr2, float) + + # Test second row + second = next(data) + self.assertEqual(second.attr1, 'a') + self.assertIsInstance(second.attr1, str) + self.assertEqual(second.attr2, 'b') + self.assertIsInstance(second.attr2, str) + + # Verify end of file + self.assertRaises(StopIteration, next, data) - def test_csv_loader(self): - with (patch(OPEN, mock_open(read_data=file_content)), - patch(PATH_EXISTS, return_value=True), - patch(PATH_ISFILE, return_value=True)): - data = csv(file_path) - self.assertEqual(len(data), 2) - self.assertEqual(data[0].attr1, 1) - self.assertIsInstance(data[0].attr1, int) - self.assertEqual(data[0].attr2, 2.0) - self.assertIsInstance(data[0].attr2, float) - self.assertEqual(data[1].attr1, 'a') - self.assertIsInstance(data[1].attr1, str) - - def 
test_csv_loader_with_casting_disabled(self): - with (patch(OPEN, mock_open(read_data=file_content)), - patch(PATH_EXISTS, return_value=True), - patch(PATH_ISFILE, return_value=True)): - data = csv(file_path, cast_types=False) - self.assertEqual(len(data), 2) - self.assertEqual(data[0].attr1, '1') - self.assertIsInstance(data[0].attr1, str) - self.assertEqual(data[0].attr2, '2.0') - self.assertIsInstance(data[0].attr2, str) - self.assertEqual(data[1].attr1, 'a') - self.assertIsInstance(data[1].attr1, str) - - def test_csv_loader_is_iterable(self): - with (patch(OPEN, mock_open(read_data=file_content)), - patch(PATH_EXISTS, return_value=True), - patch(PATH_ISFILE, return_value=True)): - data = csv(file_path) - self.assertEqual(len(list(iter(data))), 2) - - def test_csv_loader_with_custom_delimiter(self): - with (patch(OPEN, mock_open(read_data=file_content.replace(",", ";"))), - patch(PATH_EXISTS, return_value=True), - patch(PATH_ISFILE, return_value=True)): - data = csv(file_path, delimiter=';') - self.assertEqual(len(data), 2) - self.assertEqual(data[0].attr1, 1) - self.assertIsInstance(data[0].attr1, int) - - def test_csv_loader_with_empty_file(self): - with (patch(OPEN, mock_open(read_data="")), - patch(PATH_EXISTS, return_value=True), - patch(PATH_ISFILE, return_value=True)): - data = csv(file_path) - self.assertEqual(len(data), 0) - - def test_csv_loader_with_invalid_path(self): - with self.assertRaises(FileNotFoundError): - csv('path/to/invalid.csv') - - def test_csv_loader_with_no_file(self): - with self.assertRaises(ValueError): - csv('../') + def test_csv_loader_from_empty_string(self): + """Test CSV loading from an empty string.""" + with self.assertRaises(StopIteration): + next(csv("", read_from_src=True)) diff --git a/tests/_loaders/test_json_loader.py b/tests/_loaders/test_json_loader.py index 20cd044..b5da983 100644 --- a/tests/_loaders/test_json_loader.py +++ b/tests/_loaders/test_json_loader.py @@ -1,5 +1,4 @@ # pylint: disable=not-context-manager -from json import JSONDecodeError from unittest import TestCase from unittest.mock import patch, mock_open @@ -46,7 +45,7 @@ def test_json_loader_with_empty_file(self): patch(PATH_EXISTS, return_value=True), patch(PATH_ISFILE, return_value=True)): data = json(file_path) - self.assertEqual(len(data), 0) + self.assertRaises(StopIteration, next, data) def test_json_loader_with_invalid_path(self): with self.assertRaises(FileNotFoundError): @@ -61,14 +60,20 @@ def test_json_loader_from_string(self): self._check_extracted_data(data) def test_json_loader_from_empty_string(self): - with self.assertRaises(JSONDecodeError): - len(json('', read_from_src=True)) + self.assertRaises(StopIteration, next, json("", read_from_src=True)) def _check_extracted_data(self, data): - self.assertEqual(len(data), 2) - self.assertEqual(data[0].attr1, 1) - self.assertIsInstance(data[0].attr1, int) - self.assertEqual(data[0].attr2, 2.0) - self.assertIsInstance(data[0].attr2, float) - self.assertIsInstance(data[1].attr1, list) - self.assertEqual(data[1].attr1[0].attr1, 'a') + # Test first row + first = next(data) + self.assertEqual(first.attr1, 1) + self.assertIsInstance(first.attr1, int) + self.assertEqual(first.attr2, 2.0) + self.assertIsInstance(first.attr2, float) + + # Test second row + second = next(data) + self.assertEqual(second.attr1[0].attr1, 'a') + self.assertIsInstance(second.attr1, list) + + # Verify end of file + self.assertRaises(StopIteration, next, data) diff --git a/tests/_loaders/test_xml_loader.py b/tests/_loaders/test_xml_loader.py index 
84d32b2..04fb10c 100644
--- a/tests/_loaders/test_xml_loader.py
+++ b/tests/_loaders/test_xml_loader.py
@@ -1,4 +1,5 @@
 # pylint: disable=not-context-manager
+from contextlib import contextmanager
 from unittest import TestCase
 from unittest.mock import patch, mock_open
 from xml.etree.ElementTree import ParseError
@@ -31,58 +32,80 @@ class TestXmlLoader(TestCase):
 
+    @contextmanager
+    def mock_xml_file(self, content=None, exists=True, is_file=True):
+        """Context manager for mocking XML file operations.
+
+        Args:
+            content: The content of the mocked file
+            exists: Whether the file exists
+            is_file: Whether the path points to a file
+        """
+        content = content if content is not None else file_content
+        with (patch(OPEN, mock_open(read_data=content)),
+              patch(PATH_EXISTS, return_value=exists),
+              patch(PATH_ISFILE, return_value=is_file)):
+            yield
+
     def test_xml_loader_from_file_children(self):
-        with (patch(OPEN, mock_open(read_data=file_content)),
-              patch(PATH_EXISTS, return_value=True),
-              patch(PATH_ISFILE, return_value=True)):
+        with self.mock_xml_file(file_content):
             data = xml(file_path)
-            self.assertEqual(len(data), 3)
-            self.assertEqual(data[0].salary, 80000)
-            self.assertIsInstance(data[0].salary, int)
-            self.assertEqual(data[1].child.name, "Frank")
-            self.assertIsInstance(data[1].child.name, str)
-            self.assertEqual(data[2].cars.car[0], 'Bugatti')
-            self.assertIsInstance(data[2].cars.car[0], str)
+
+            first = next(data)
+            self.assertEqual(first.salary, 80000)
+            self.assertIsInstance(first.salary, int)
+
+            second = next(data)
+            self.assertEqual(second.child.name, "Frank")
+            self.assertIsInstance(second.child.name, str)
+
+            third = next(data)
+            self.assertEqual(third.cars.car[0], 'Bugatti')
+            self.assertIsInstance(third.cars.car[0], str)
+
+            self.assertRaises(StopIteration, next, data)
 
     def test_xml_loader_from_file_no_children_false(self):
-        with (patch(OPEN, mock_open(read_data=file_content)),
-              patch(PATH_EXISTS, return_value=True),
-              patch(PATH_ISFILE, return_value=True)):
+        with self.mock_xml_file(file_content):
             data = xml(file_path, retrieve_children=False)
-            self.assertEqual(len(data), 1)
-            self.assertEqual(data[0].employee[0].salary, 80000)
-            self.assertIsInstance(data[0].employee[0].salary, int)
-            self.assertEqual(data[0].employee[1].child.name, "Frank")
-            self.assertIsInstance(data[0].employee[1].child.name, str)
-            self.assertEqual(data[0].founder.cars.car[0], 'Bugatti')
-            self.assertIsInstance(data[0].founder.cars.car[0], str)
+
+            first = next(data)
+            self.assertEqual(first.employee[0].salary, 80000)
+            self.assertIsInstance(first.employee[0].salary, int)
+            self.assertEqual(first.employee[1].child.name, "Frank")
+            self.assertIsInstance(first.employee[1].child.name, str)
+            self.assertEqual(first.founder.cars.car[0], 'Bugatti')
+            self.assertIsInstance(first.founder.cars.car[0], str)
+
+            self.assertRaises(StopIteration, next, data)
 
     def test_xml_loader_no_casting(self):
-        with (patch(OPEN, mock_open(read_data=file_content)),
-              patch(PATH_EXISTS, return_value=True),
-              patch(PATH_ISFILE, return_value=True)):
+        with self.mock_xml_file(file_content):
             data = xml(file_path, cast_types=False)
-            self.assertEqual(len(data), 3)
-            self.assertEqual(data[0].salary, '80000')
-            self.assertIsInstance(data[0].salary, str)
-            self.assertEqual(data[1].child.name, "Frank")
-            self.assertIsInstance(data[1].child.name, str)
-            self.assertEqual(data[2].cars.car[0], 'Bugatti')
-            self.assertIsInstance(data[2].cars.car[0], str)
+
+            first = next(data)
+            self.assertEqual(first.salary, '80000')
+            self.assertIsInstance(first.salary, str)
+
+            second = next(data)
+            self.assertEqual(second.child.name, "Frank")
+            self.assertIsInstance(second.child.name, str)
+
+            third = next(data)
+            self.assertEqual(third.cars.car[0], 'Bugatti')
+            self.assertIsInstance(third.cars.car[0], str)
+
+            self.assertRaises(StopIteration, next, data)
 
     def test_xml_loader_is_iterable(self):
-        with (patch(OPEN, mock_open(read_data=file_content)),
-              patch(PATH_EXISTS, return_value=True),
-              patch(PATH_ISFILE, return_value=True)):
+        with self.mock_xml_file(file_content):
             data = xml(file_path)
             self.assertEqual(len(list(iter(data))), 3)
 
     def test_xml_loader_with_empty_file(self):
-        with (patch(OPEN, mock_open(read_data="")),
-              patch(PATH_EXISTS, return_value=True),
-              patch(PATH_ISFILE, return_value=True)):
+        with self.mock_xml_file(''):
             data = xml(file_path)
-            self.assertEqual(len(data), 0)
+            self.assertRaises(ParseError, next, data)
 
     def test_xml_loader_with_invalid_path(self):
         with self.assertRaises(FileNotFoundError):
@@ -94,14 +117,21 @@ def test_xml_loader_with_no_file(self):
 
     def test_xml_loader_from_string(self):
         data = xml(file_content, read_from_src=True)
-        self.assertEqual(len(data), 3)
-        self.assertEqual(data[0].salary, 80000)
-        self.assertIsInstance(data[0].salary, int)
-        self.assertEqual(data[1].child.name, "Frank")
-        self.assertIsInstance(data[1].child.name, str)
-        self.assertEqual(data[2].cars.car[0], 'Bugatti')
-        self.assertIsInstance(data[2].cars.car[0], str)
+
+        first = next(data)
+        self.assertEqual(first.salary, 80000)
+        self.assertIsInstance(first.salary, int)
+
+        second = next(data)
+        self.assertEqual(second.child.name, "Frank")
+        self.assertIsInstance(second.child.name, str)
+
+        third = next(data)
+        self.assertEqual(third.cars.car[0], 'Bugatti')
+        self.assertIsInstance(third.cars.car[0], str)
+
+        self.assertRaises(StopIteration, next, data)
 
     def test_xml_loader_from_empty_string(self):
         with self.assertRaises(ParseError):
-            len(xml('', read_from_src=True))
+            list(xml('', read_from_src=True))
diff --git a/tests/_loaders/test_yaml_loader.py b/tests/_loaders/test_yaml_loader.py
index f9beee1..6326ea7 100644
--- a/tests/_loaders/test_yaml_loader.py
+++ b/tests/_loaders/test_yaml_loader.py
@@ -1,4 +1,5 @@
 # pylint: disable=not-context-manager
+from types import GeneratorType
 from unittest import TestCase
 from unittest.mock import patch, mock_open
 
@@ -37,7 +38,7 @@ def test_yaml_loader_with_empty_file(self):
             patch(PATH_EXISTS, return_value=True),
             patch(PATH_ISFILE, return_value=True)):
             data = yaml(file_path)
-            self.assertEqual(len(data), 0)
+            self.assertEqual(len(list(data)), 0)
 
     def test_yaml_loader_with_invalid_path(self):
         with self.assertRaises(FileNotFoundError):
@@ -54,11 +55,20 @@ def test_yaml_loader_from_string(self):
     def test_yaml_loader_from_empty_string(self):
         self.assertEqual(list(yaml('', read_from_src=True)), [])
 
+    def test_yaml_loader_is_lazy(self):
+        with (patch(OPEN, mock_open(read_data=file_content)),
+              patch(PATH_EXISTS, return_value=True),
+              patch(PATH_ISFILE, return_value=True)):
+            data = yaml(file_path)
+            self.assertIsInstance(data, GeneratorType)
+
     def _check_extracted_data(self, data):
-        self.assertEqual(len(data), 2)
-        self.assertEqual(data[0].attr1, 1)
-        self.assertIsInstance(data[0].attr1, int)
-        self.assertEqual(data[0].attr2, 2.0)
-        self.assertIsInstance(data[0].attr2, float)
-        self.assertIsInstance(data[1].attr1, list)
-        self.assertEqual(data[1].attr1[0].attr1, 'a')
+        first = next(data)
+        self.assertEqual(first.attr1, 1)
+        self.assertIsInstance(first.attr1, int)
+        self.assertEqual(first.attr2, 2.0)
+        self.assertIsInstance(first.attr2, float)
+        second = next(data)
+        self.assertIsInstance(second.attr1, list)
+        self.assertEqual(second.attr1[0].attr1, 'a')
+        self.assertRaises(StopIteration, next, data)
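
Example usage of the new lazy loaders (a sketch only: the file name, the name/age columns, and the Stream.of(...) pipeline are illustrative assumptions, not part of this patch; the loader import path matches the tests above):

from pystreamapi import Stream
from pystreamapi.loaders import csv

# The file is not read here: csv() now returns a generator instead of a LazyFileIterable,
# so only the path is validated up front.
rows = csv("people.csv")
first = next(rows)            # the file is opened and the first row is parsed on demand
print(first.name, first.age)  # namedtuple access; numeric values are cast by default

# The same lazy iterator can feed a stream pipeline without materializing
# the whole file in memory first.
adult_names = (Stream.of(csv("people.csv"))
               .filter(lambda row: row.age >= 18)
               .map(lambda row: row.name)
               .to_list())
print(adult_names)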