From a2f3f590f69c42832b4da244b8b8d33a56bbd24f Mon Sep 17 00:00:00 2001 From: Daniil Fajnberg Date: Sun, 26 Dec 2021 23:22:12 +0100 Subject: [PATCH] test for __main__ --- src/mwfin/__main__.py | 1 + tests/test_main.py | 87 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+) create mode 100644 tests/test_main.py diff --git a/src/mwfin/__main__.py b/src/mwfin/__main__.py index 89d39d6..9db4d13 100644 --- a/src/mwfin/__main__.py +++ b/src/mwfin/__main__.py @@ -63,6 +63,7 @@ def parse_cli() -> dict: def configure_logging(verbosity: int) -> None: root_logger = logging.getLogger() root_logger.addHandler(logging.StreamHandler()) + root_logger.setLevel(logging.CRITICAL) if verbosity > 2: root_logger.setLevel(logging.DEBUG) elif verbosity == 2: diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000..f73d916 --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,87 @@ +import logging +import json +from unittest import IsolatedAsyncioTestCase +from unittest.mock import patch +from argparse import Namespace +from io import StringIO + +from mwfin import __main__ as main_module + + +class MainModuleTestCase(IsolatedAsyncioTestCase): + + @patch.object(main_module.ArgumentParser, 'parse_args') + def test_parse_cli(self, mock_parse_args): + mock_parse_args.return_value = mock_args = Namespace(foo='a', bar='b') + expected_output = vars(mock_args) + output = main_module.parse_cli() + self.assertDictEqual(expected_output, output) + + def test_configure_logging(self): + root_logger = logging.getLogger() + root_logger.handlers = [] + main_module.configure_logging(verbosity=0) + self.assertEqual(1, len(root_logger.handlers)) + self.assertIsInstance(root_logger.handlers[0], logging.StreamHandler) + self.assertEqual(logging.CRITICAL, root_logger.level) + root_logger.handlers = [] + main_module.configure_logging(verbosity=1) + self.assertEqual(1, len(root_logger.handlers)) + self.assertIsInstance(root_logger.handlers[0], logging.StreamHandler) + self.assertEqual(logging.WARNING, root_logger.level) + root_logger.handlers = [] + main_module.configure_logging(verbosity=2) + self.assertEqual(1, len(root_logger.handlers)) + self.assertIsInstance(root_logger.handlers[0], logging.StreamHandler) + self.assertEqual(logging.INFO, root_logger.level) + root_logger.handlers = [] + main_module.configure_logging(verbosity=3) + self.assertEqual(1, len(root_logger.handlers)) + self.assertIsInstance(root_logger.handlers[0], logging.StreamHandler) + self.assertEqual(logging.DEBUG, root_logger.level) + root_logger.handlers = [] + main_module.configure_logging(verbosity=9999) + self.assertEqual(1, len(root_logger.handlers)) + self.assertIsInstance(root_logger.handlers[0], logging.StreamHandler) + self.assertEqual(logging.DEBUG, root_logger.level) + + @patch.object(main_module, 'get_all_financials') + @patch.object(main_module, 'configure_logging') + @patch.object(main_module, 'parse_cli') + async def test_main(self, mock_parse_cli, mock_configure_logging, mock_get_all_financials): + mock_parse_cli.return_value = args = { + main_module.VERBOSE: 'foo', + main_module.TICKER_SYMBOL: ['bar', 'baz'], + main_module.QUARTERLY: 'perhaps', + main_module.BATCH_SIZE: 'xyz', + main_module.TO_FILE: None, + main_module.JSON_INDENT: 42, + } + mock_get_all_financials.return_value = mock_data = {'data': 'something cool'} + + # To stdout: + with patch.object(main_module, 'print') as mock_print: + await main_module.main() + mock_parse_cli.assert_called_once_with() + mock_configure_logging.assert_called_once_with(args[main_module.VERBOSE]) + mock_get_all_financials.assert_awaited_once_with(*args[main_module.TICKER_SYMBOL], + quarterly=args[main_module.QUARTERLY], + concurrent_batch_size=args[main_module.BATCH_SIZE]) + mock_print.assert_called_once_with(json.dumps(mock_data, indent=args[main_module.JSON_INDENT])) + mock_parse_cli.reset_mock() + mock_configure_logging.reset_mock() + mock_get_all_financials.reset_mock() + + # To file: + args[main_module.TO_FILE] = 'some_file' + with patch.object(main_module, 'open') as mock_open: + mock_open.return_value.__enter__.return_value = mock_file = StringIO() + await main_module.main() + mock_parse_cli.assert_called_once_with() + mock_configure_logging.assert_called_once_with(args[main_module.VERBOSE]) + mock_get_all_financials.assert_awaited_once_with(*args[main_module.TICKER_SYMBOL], + quarterly=args[main_module.QUARTERLY], + concurrent_batch_size=args[main_module.BATCH_SIZE]) + expected_contents = json.dumps(mock_data, indent=args[main_module.JSON_INDENT]) + mock_file.seek(0) + self.assertEqual(expected_contents, mock_file.read())