Compare commits

..

2 Commits

2 changed files with 34 additions and 43 deletions

View File

@ -135,6 +135,16 @@ async def get_cash_flow_statement(*ticker_symbols: str, quarterly: bool = False,
} }
@in_async_session
async def _get_single_company_financials(ticker_symbol: str, quarterly: bool = False,
session: ClientSession = None) -> Dict[str, ResultDict]:
return {
constants.BS: await _get_financial_statement(constants.BS, ticker_symbol, quarterly, session),
constants.IS: await _get_financial_statement(constants.IS, ticker_symbol, quarterly, session),
constants.CF: await _get_financial_statement(constants.CF, ticker_symbol, quarterly, session)
}
@in_async_session @in_async_session
async def get_company_financials(*ticker_symbols: str, quarterly: bool = False, async def get_company_financials(*ticker_symbols: str, quarterly: bool = False,
session: ClientSession = None) -> Union[Dict[str, ResultDict], session: ClientSession = None) -> Union[Dict[str, ResultDict],
@ -143,15 +153,5 @@ async def get_company_financials(*ticker_symbols: str, quarterly: bool = False,
Returns all fundamentals (balance sheet, income statement and cash flow statement) of the specified company. Returns all fundamentals (balance sheet, income statement and cash flow statement) of the specified company.
""" """
if len(ticker_symbols) == 1: if len(ticker_symbols) == 1:
return { return await _get_single_company_financials(ticker_symbols[0], quarterly, session)
constants.BS: await get_balance_sheet(ticker_symbols[0], quarterly=quarterly, session=session), return {symbol: await _get_single_company_financials(symbol, quarterly, session) for symbol in ticker_symbols}
constants.IS: await get_income_statement(ticker_symbols[0], quarterly=quarterly, session=session),
constants.CF: await get_cash_flow_statement(ticker_symbols[0], quarterly=quarterly, session=session)
}
return {
sym: {
constants.BS: await get_balance_sheet(sym, quarterly=quarterly, session=session),
constants.IS: await get_income_statement(sym, quarterly=quarterly, session=session),
constants.CF: await get_cash_flow_statement(sym, quarterly=quarterly, session=session)
} for sym in ticker_symbols
}

View File

@ -159,45 +159,36 @@ class FunctionsTestCase(IsolatedAsyncioTestCase):
async def test_get_cash_flow_statement(self, mock__get_financial_statement): async def test_get_cash_flow_statement(self, mock__get_financial_statement):
await self._helper_test_get_any_statement(CF, mock__get_financial_statement) await self._helper_test_get_any_statement(CF, mock__get_financial_statement)
@patch.object(functions, 'get_cash_flow_statement') @patch.object(functions, '_get_financial_statement')
@patch.object(functions, 'get_income_statement') async def test__get_single_company_financials(self, mock__get_financial_statement):
@patch.object(functions, 'get_balance_sheet') symbol, quarterly, mock_session = 'foo', False, MagicMock()
async def test_get_company_financials(self, mock_get_bs, mock_get_is, mock_get_cf): mock__get_financial_statement.return_value = bar = 'bar'
mock_end_dates = ('bar', 'baz') expected_output = {BS: bar, IS: bar, CF: bar}
mock_get_bs.return_value = {END_DATE: mock_end_dates, 'a': (1, 2)} output = await functions._get_single_company_financials(symbol, quarterly, mock_session)
mock_get_is.return_value = {END_DATE: mock_end_dates, 'b': (2, 3)} self.assertDictEqual(expected_output, output)
mock_get_cf.return_value = {END_DATE: mock_end_dates, 'c': (3, 4)} mock__get_financial_statement.assert_has_calls([
expected_output = { call(BS, symbol, quarterly, mock_session),
BS: {END_DATE: mock_end_dates, 'a': (1, 2)}, call(IS, symbol, quarterly, mock_session),
IS: {END_DATE: mock_end_dates, 'b': (2, 3)}, call(CF, symbol, quarterly, mock_session)
CF: {END_DATE: mock_end_dates, 'c': (3, 4)} ])
}
@patch.object(functions, '_get_single_company_financials')
async def test_get_company_financials(self, mock__get_single_company_financials):
mock__get_single_company_financials.return_value = expected_output = 'baz'
symbol, quarterly, mock_session = 'foo', False, MagicMock() symbol, quarterly, mock_session = 'foo', False, MagicMock()
output = await functions.get_company_financials(symbol, quarterly=quarterly, session=mock_session) output = await functions.get_company_financials(symbol, quarterly=quarterly, session=mock_session)
self.assertDictEqual(expected_output, output) self.assertEqual(expected_output, output)
mock_get_bs.assert_called_once_with(symbol, quarterly=quarterly, session=mock_session) mock__get_single_company_financials.assert_called_once_with(symbol, quarterly, mock_session)
mock_get_is.assert_called_once_with(symbol, quarterly=quarterly, session=mock_session) mock__get_single_company_financials.reset_mock()
mock_get_cf.assert_called_once_with(symbol, quarterly=quarterly, session=mock_session)
mock_get_bs.reset_mock()
mock_get_is.reset_mock()
mock_get_cf.reset_mock()
test_symbol1, test_symbol2 = 'x', 'y' test_symbol1, test_symbol2 = 'x', 'y'
expected_output = {test_symbol1: expected_output, test_symbol2: expected_output} expected_output = {test_symbol1: expected_output, test_symbol2: expected_output}
output = await functions.get_company_financials(test_symbol1, test_symbol2, output = await functions.get_company_financials(test_symbol1, test_symbol2,
quarterly=quarterly, session=mock_session) quarterly=quarterly, session=mock_session)
self.assertDictEqual(expected_output, output) self.assertDictEqual(expected_output, output)
mock_get_bs.assert_has_calls([ mock__get_single_company_financials.assert_has_calls([
call(test_symbol1, quarterly=quarterly, session=mock_session), call(test_symbol1, quarterly, mock_session),
call(test_symbol2, quarterly=quarterly, session=mock_session) call(test_symbol2, quarterly, mock_session)
])
mock_get_is.assert_has_calls([
call(test_symbol1, quarterly=quarterly, session=mock_session),
call(test_symbol2, quarterly=quarterly, session=mock_session)
])
mock_get_cf.assert_has_calls([
call(test_symbol1, quarterly=quarterly, session=mock_session),
call(test_symbol2, quarterly=quarterly, session=mock_session)
]) ])
@patch.object(functions, 'ClientSession') @patch.object(functions, 'ClientSession')