diff --git a/tests/test_functions.py b/tests/test_functions.py index 5bac56f..3e46364 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -141,9 +141,9 @@ class FunctionsTestCase(IsolatedAsyncioTestCase): call(statement, symbol2, quarterly, mock_session) ]) - async def _helper_test_get_any_statement(self, statement: str, mock__get_single_company_fin_stmt): - symbol, quarterly, mock_session = 'foo', False, MagicMock() - mock__get_single_company_fin_stmt.return_value = expected_output = 'bar' + async def _helper_test_get_any_statement(self, statement: str, mock__get_multi_companies_fin_stmt): + symbol1, symbol2, quarterly, mock_session = 'foo', 'bar', False, MagicMock() + mock__get_multi_companies_fin_stmt.return_value = expected_output = 'baz' if statement == BS: function = functions.get_balance_sheet elif statement == IS: @@ -152,31 +152,21 @@ class FunctionsTestCase(IsolatedAsyncioTestCase): function = functions.get_cash_flow_statement else: raise ValueError - output = await function(symbol, quarterly=quarterly, session=mock_session) - self.assertEqual(expected_output, output) - mock__get_single_company_fin_stmt.assert_called_once_with(statement, symbol, quarterly, mock_session) - mock__get_single_company_fin_stmt.reset_mock() - - symbol1, symbol2 = 'x', 'y' - expected_output = {symbol1: expected_output, symbol2: expected_output} output = await function(symbol1, symbol2, quarterly=quarterly, session=mock_session) - self.assertDictEqual(expected_output, output) - mock__get_single_company_fin_stmt.assert_has_calls([ - call(statement, symbol1, quarterly, mock_session), - call(statement, symbol2, quarterly, mock_session), - ]) + self.assertEqual(expected_output, output) + mock__get_multi_companies_fin_stmt.assert_called_once_with(statement, symbol1, symbol2, quarterly, mock_session) - @patch.object(functions, '_get_single_company_fin_stmt') - async def test_get_balance_sheet(self, mock__get_single_company_fin_stmt): - await self._helper_test_get_any_statement(BS, mock__get_single_company_fin_stmt) + @patch.object(functions, '_get_multi_companies_fin_stmt') + async def test_get_balance_sheet(self, mock__get_multi_companies_fin_stmt): + await self._helper_test_get_any_statement(BS, mock__get_multi_companies_fin_stmt) - @patch.object(functions, '_get_single_company_fin_stmt') - async def test_get_income_statement(self, mock__get_single_company_fin_stmt): - await self._helper_test_get_any_statement(IS, mock__get_single_company_fin_stmt) + @patch.object(functions, '_get_multi_companies_fin_stmt') + async def test_get_income_statement(self, mock__get_multi_companies_fin_stmt): + await self._helper_test_get_any_statement(IS, mock__get_multi_companies_fin_stmt) - @patch.object(functions, '_get_single_company_fin_stmt') - async def test_get_cash_flow_statement(self, mock__get_single_company_fin_stmt): - await self._helper_test_get_any_statement(CF, mock__get_single_company_fin_stmt) + @patch.object(functions, '_get_multi_companies_fin_stmt') + async def test_get_cash_flow_statement(self, mock__get_multi_companies_fin_stmt): + await self._helper_test_get_any_statement(CF, mock__get_multi_companies_fin_stmt) @patch.object(functions, '_get_single_company_fin_stmt') async def test__get_single_company_financials(self, mock__get_single_company_fin_stmt):