test for __main__

This commit is contained in:
Daniil Fajnberg 2021-12-26 23:22:12 +01:00
parent 3276dbfae3
commit a2f3f590f6
2 changed files with 88 additions and 0 deletions

View File

@ -63,6 +63,7 @@ def parse_cli() -> dict:
def configure_logging(verbosity: int) -> None: def configure_logging(verbosity: int) -> None:
root_logger = logging.getLogger() root_logger = logging.getLogger()
root_logger.addHandler(logging.StreamHandler()) root_logger.addHandler(logging.StreamHandler())
root_logger.setLevel(logging.CRITICAL)
if verbosity > 2: if verbosity > 2:
root_logger.setLevel(logging.DEBUG) root_logger.setLevel(logging.DEBUG)
elif verbosity == 2: elif verbosity == 2:

87
tests/test_main.py Normal file
View File

@ -0,0 +1,87 @@
import logging
import json
from unittest import IsolatedAsyncioTestCase
from unittest.mock import patch
from argparse import Namespace
from io import StringIO
from mwfin import __main__ as main_module
class MainModuleTestCase(IsolatedAsyncioTestCase):
@patch.object(main_module.ArgumentParser, 'parse_args')
def test_parse_cli(self, mock_parse_args):
mock_parse_args.return_value = mock_args = Namespace(foo='a', bar='b')
expected_output = vars(mock_args)
output = main_module.parse_cli()
self.assertDictEqual(expected_output, output)
def test_configure_logging(self):
root_logger = logging.getLogger()
root_logger.handlers = []
main_module.configure_logging(verbosity=0)
self.assertEqual(1, len(root_logger.handlers))
self.assertIsInstance(root_logger.handlers[0], logging.StreamHandler)
self.assertEqual(logging.CRITICAL, root_logger.level)
root_logger.handlers = []
main_module.configure_logging(verbosity=1)
self.assertEqual(1, len(root_logger.handlers))
self.assertIsInstance(root_logger.handlers[0], logging.StreamHandler)
self.assertEqual(logging.WARNING, root_logger.level)
root_logger.handlers = []
main_module.configure_logging(verbosity=2)
self.assertEqual(1, len(root_logger.handlers))
self.assertIsInstance(root_logger.handlers[0], logging.StreamHandler)
self.assertEqual(logging.INFO, root_logger.level)
root_logger.handlers = []
main_module.configure_logging(verbosity=3)
self.assertEqual(1, len(root_logger.handlers))
self.assertIsInstance(root_logger.handlers[0], logging.StreamHandler)
self.assertEqual(logging.DEBUG, root_logger.level)
root_logger.handlers = []
main_module.configure_logging(verbosity=9999)
self.assertEqual(1, len(root_logger.handlers))
self.assertIsInstance(root_logger.handlers[0], logging.StreamHandler)
self.assertEqual(logging.DEBUG, root_logger.level)
@patch.object(main_module, 'get_all_financials')
@patch.object(main_module, 'configure_logging')
@patch.object(main_module, 'parse_cli')
async def test_main(self, mock_parse_cli, mock_configure_logging, mock_get_all_financials):
mock_parse_cli.return_value = args = {
main_module.VERBOSE: 'foo',
main_module.TICKER_SYMBOL: ['bar', 'baz'],
main_module.QUARTERLY: 'perhaps',
main_module.BATCH_SIZE: 'xyz',
main_module.TO_FILE: None,
main_module.JSON_INDENT: 42,
}
mock_get_all_financials.return_value = mock_data = {'data': 'something cool'}
# To stdout:
with patch.object(main_module, 'print') as mock_print:
await main_module.main()
mock_parse_cli.assert_called_once_with()
mock_configure_logging.assert_called_once_with(args[main_module.VERBOSE])
mock_get_all_financials.assert_awaited_once_with(*args[main_module.TICKER_SYMBOL],
quarterly=args[main_module.QUARTERLY],
concurrent_batch_size=args[main_module.BATCH_SIZE])
mock_print.assert_called_once_with(json.dumps(mock_data, indent=args[main_module.JSON_INDENT]))
mock_parse_cli.reset_mock()
mock_configure_logging.reset_mock()
mock_get_all_financials.reset_mock()
# To file:
args[main_module.TO_FILE] = 'some_file'
with patch.object(main_module, 'open') as mock_open:
mock_open.return_value.__enter__.return_value = mock_file = StringIO()
await main_module.main()
mock_parse_cli.assert_called_once_with()
mock_configure_logging.assert_called_once_with(args[main_module.VERBOSE])
mock_get_all_financials.assert_awaited_once_with(*args[main_module.TICKER_SYMBOL],
quarterly=args[main_module.QUARTERLY],
concurrent_batch_size=args[main_module.BATCH_SIZE])
expected_contents = json.dumps(mock_data, indent=args[main_module.JSON_INDENT])
mock_file.seek(0)
self.assertEqual(expected_contents, mock_file.read())