From 68f3e428ab41f34988c9e9221ecbfb7cff182854 Mon Sep 17 00:00:00 2001 From: Maximilian Fajnberg Date: Fri, 3 Dec 2021 11:46:39 +0100 Subject: [PATCH] plan to refactor get_company_financials; tests changed accordingly --- src/mwfin/functions.py | 6 +++++ tests/test_functions.py | 54 ++++++++++++++++++----------------------- 2 files changed, 29 insertions(+), 31 deletions(-) diff --git a/src/mwfin/functions.py b/src/mwfin/functions.py index e48490f..4454204 100644 --- a/src/mwfin/functions.py +++ b/src/mwfin/functions.py @@ -135,6 +135,12 @@ async def get_cash_flow_statement(*ticker_symbols: str, quarterly: bool = False, } +@in_async_session +async def _get_single_company_financials(ticker_symbol: str, quarterly: bool = False, + session: ClientSession = None) -> Dict[str, ResultDict]: + pass + + @in_async_session async def get_company_financials(*ticker_symbols: str, quarterly: bool = False, session: ClientSession = None) -> Union[Dict[str, ResultDict], diff --git a/tests/test_functions.py b/tests/test_functions.py index 47b91c4..250b57a 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -159,45 +159,37 @@ class FunctionsTestCase(IsolatedAsyncioTestCase): async def test_get_cash_flow_statement(self, mock__get_financial_statement): await self._helper_test_get_any_statement(CF, mock__get_financial_statement) - @patch.object(functions, 'get_cash_flow_statement') - @patch.object(functions, 'get_income_statement') - @patch.object(functions, 'get_balance_sheet') - async def test_get_company_financials(self, mock_get_bs, mock_get_is, mock_get_cf): - mock_end_dates = ('bar', 'baz') - mock_get_bs.return_value = {END_DATE: mock_end_dates, 'a': (1, 2)} - mock_get_is.return_value = {END_DATE: mock_end_dates, 'b': (2, 3)} - mock_get_cf.return_value = {END_DATE: mock_end_dates, 'c': (3, 4)} - expected_output = { - BS: {END_DATE: mock_end_dates, 'a': (1, 2)}, - IS: {END_DATE: mock_end_dates, 'b': (2, 3)}, - CF: {END_DATE: mock_end_dates, 'c': (3, 4)} - } + @patch.object(functions, '_get_financial_statement') + async def test__get_single_company_financials(self, mock__get_financial_statement): + symbol, quarterly, mock_session = 'foo', False, MagicMock() + mock__get_financial_statement.return_value = bar = 'bar' + expected_output = {BS: bar, IS: bar, CF: bar} + output = await functions._get_single_company_financials(symbol, quarterly, mock_session) + self.assertDictEqual(expected_output, output) + mock__get_financial_statement.assert_has_calls([ + call(BS, symbol, quarterly, mock_session), + call(IS, symbol, quarterly, mock_session), + call(CF, symbol, quarterly, mock_session) + ]) + + @patch.object(functions, '_get_single_company_financials') + async def test_get_company_financials(self, mock__get_single_company_financials): + mock__get_single_company_financials.return_value = expected_output = 'baz' symbol, quarterly, mock_session = 'foo', False, MagicMock() output = await functions.get_company_financials(symbol, quarterly=quarterly, session=mock_session) - self.assertDictEqual(expected_output, output) - mock_get_bs.assert_called_once_with(symbol, quarterly=quarterly, session=mock_session) - mock_get_is.assert_called_once_with(symbol, quarterly=quarterly, session=mock_session) - mock_get_cf.assert_called_once_with(symbol, quarterly=quarterly, session=mock_session) - mock_get_bs.reset_mock() - mock_get_is.reset_mock() - mock_get_cf.reset_mock() + self.assertEqual(expected_output, output) + mock__get_single_company_financials.assert_called_once_with(symbol, quarterly, mock_session) + mock__get_single_company_financials.reset_mock() + # keep test_symbol1, test_symbol2 = 'x', 'y' expected_output = {test_symbol1: expected_output, test_symbol2: expected_output} output = await functions.get_company_financials(test_symbol1, test_symbol2, quarterly=quarterly, session=mock_session) self.assertDictEqual(expected_output, output) - mock_get_bs.assert_has_calls([ - call(test_symbol1, quarterly=quarterly, session=mock_session), - call(test_symbol2, quarterly=quarterly, session=mock_session) - ]) - mock_get_is.assert_has_calls([ - call(test_symbol1, quarterly=quarterly, session=mock_session), - call(test_symbol2, quarterly=quarterly, session=mock_session) - ]) - mock_get_cf.assert_has_calls([ - call(test_symbol1, quarterly=quarterly, session=mock_session), - call(test_symbol2, quarterly=quarterly, session=mock_session) + mock__get_single_company_financials.assert_has_calls([ + call(test_symbol1, quarterly, mock_session), + call(test_symbol2, quarterly, mock_session) ]) @patch.object(functions, 'ClientSession')