test for __main__
This commit is contained in:
parent
3276dbfae3
commit
a2f3f590f6
@ -63,6 +63,7 @@ def parse_cli() -> dict:
|
||||
def configure_logging(verbosity: int) -> None:
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.addHandler(logging.StreamHandler())
|
||||
root_logger.setLevel(logging.CRITICAL)
|
||||
if verbosity > 2:
|
||||
root_logger.setLevel(logging.DEBUG)
|
||||
elif verbosity == 2:
|
||||
|
87
tests/test_main.py
Normal file
87
tests/test_main.py
Normal file
@ -0,0 +1,87 @@
|
||||
import logging
|
||||
import json
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
from unittest.mock import patch
|
||||
from argparse import Namespace
|
||||
from io import StringIO
|
||||
|
||||
from mwfin import __main__ as main_module
|
||||
|
||||
|
||||
class MainModuleTestCase(IsolatedAsyncioTestCase):
|
||||
|
||||
@patch.object(main_module.ArgumentParser, 'parse_args')
|
||||
def test_parse_cli(self, mock_parse_args):
|
||||
mock_parse_args.return_value = mock_args = Namespace(foo='a', bar='b')
|
||||
expected_output = vars(mock_args)
|
||||
output = main_module.parse_cli()
|
||||
self.assertDictEqual(expected_output, output)
|
||||
|
||||
def test_configure_logging(self):
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.handlers = []
|
||||
main_module.configure_logging(verbosity=0)
|
||||
self.assertEqual(1, len(root_logger.handlers))
|
||||
self.assertIsInstance(root_logger.handlers[0], logging.StreamHandler)
|
||||
self.assertEqual(logging.CRITICAL, root_logger.level)
|
||||
root_logger.handlers = []
|
||||
main_module.configure_logging(verbosity=1)
|
||||
self.assertEqual(1, len(root_logger.handlers))
|
||||
self.assertIsInstance(root_logger.handlers[0], logging.StreamHandler)
|
||||
self.assertEqual(logging.WARNING, root_logger.level)
|
||||
root_logger.handlers = []
|
||||
main_module.configure_logging(verbosity=2)
|
||||
self.assertEqual(1, len(root_logger.handlers))
|
||||
self.assertIsInstance(root_logger.handlers[0], logging.StreamHandler)
|
||||
self.assertEqual(logging.INFO, root_logger.level)
|
||||
root_logger.handlers = []
|
||||
main_module.configure_logging(verbosity=3)
|
||||
self.assertEqual(1, len(root_logger.handlers))
|
||||
self.assertIsInstance(root_logger.handlers[0], logging.StreamHandler)
|
||||
self.assertEqual(logging.DEBUG, root_logger.level)
|
||||
root_logger.handlers = []
|
||||
main_module.configure_logging(verbosity=9999)
|
||||
self.assertEqual(1, len(root_logger.handlers))
|
||||
self.assertIsInstance(root_logger.handlers[0], logging.StreamHandler)
|
||||
self.assertEqual(logging.DEBUG, root_logger.level)
|
||||
|
||||
@patch.object(main_module, 'get_all_financials')
|
||||
@patch.object(main_module, 'configure_logging')
|
||||
@patch.object(main_module, 'parse_cli')
|
||||
async def test_main(self, mock_parse_cli, mock_configure_logging, mock_get_all_financials):
|
||||
mock_parse_cli.return_value = args = {
|
||||
main_module.VERBOSE: 'foo',
|
||||
main_module.TICKER_SYMBOL: ['bar', 'baz'],
|
||||
main_module.QUARTERLY: 'perhaps',
|
||||
main_module.BATCH_SIZE: 'xyz',
|
||||
main_module.TO_FILE: None,
|
||||
main_module.JSON_INDENT: 42,
|
||||
}
|
||||
mock_get_all_financials.return_value = mock_data = {'data': 'something cool'}
|
||||
|
||||
# To stdout:
|
||||
with patch.object(main_module, 'print') as mock_print:
|
||||
await main_module.main()
|
||||
mock_parse_cli.assert_called_once_with()
|
||||
mock_configure_logging.assert_called_once_with(args[main_module.VERBOSE])
|
||||
mock_get_all_financials.assert_awaited_once_with(*args[main_module.TICKER_SYMBOL],
|
||||
quarterly=args[main_module.QUARTERLY],
|
||||
concurrent_batch_size=args[main_module.BATCH_SIZE])
|
||||
mock_print.assert_called_once_with(json.dumps(mock_data, indent=args[main_module.JSON_INDENT]))
|
||||
mock_parse_cli.reset_mock()
|
||||
mock_configure_logging.reset_mock()
|
||||
mock_get_all_financials.reset_mock()
|
||||
|
||||
# To file:
|
||||
args[main_module.TO_FILE] = 'some_file'
|
||||
with patch.object(main_module, 'open') as mock_open:
|
||||
mock_open.return_value.__enter__.return_value = mock_file = StringIO()
|
||||
await main_module.main()
|
||||
mock_parse_cli.assert_called_once_with()
|
||||
mock_configure_logging.assert_called_once_with(args[main_module.VERBOSE])
|
||||
mock_get_all_financials.assert_awaited_once_with(*args[main_module.TICKER_SYMBOL],
|
||||
quarterly=args[main_module.QUARTERLY],
|
||||
concurrent_batch_size=args[main_module.BATCH_SIZE])
|
||||
expected_contents = json.dumps(mock_data, indent=args[main_module.JSON_INDENT])
|
||||
mock_file.seek(0)
|
||||
self.assertEqual(expected_contents, mock_file.read())
|
Loading…
x
Reference in New Issue
Block a user