implemented new function and adjusted function call sequence

This commit is contained in:
Maximilian Fajnberg 2021-12-03 15:19:21 +01:00
parent a7dbdcf917
commit 8d18e03018
2 changed files with 41 additions and 39 deletions

View File

@ -4,7 +4,7 @@ from typing import Union, List, Dict
from aiohttp.client import ClientSession from aiohttp.client import ClientSession
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from bs4.element import Tag from bs4.element import Tag
from webutils import in_async_session from webutils import in_async_session, gather_in_batches
from . import constants from . import constants
@ -97,49 +97,49 @@ async def _get_single_company_fin_stmt(statement: str, ticker_symbol: str, quart
async def _get_multi_companies_fin_stmt(statement: str, *ticker_symbols: str, quarterly: bool = False, async def _get_multi_companies_fin_stmt(statement: str, *ticker_symbols: str, quarterly: bool = False,
concurrent_batch_size: int = constants.DEFAULT_CONCURRENT_BATCH_SIZE, concurrent_batch_size: int = constants.DEFAULT_CONCURRENT_BATCH_SIZE,
session: ClientSession = None) -> Union[ResultDict, Dict[str, ResultDict]]: session: ClientSession = None) -> Union[ResultDict, Dict[str, ResultDict]]:
pass if len(ticker_symbols) == 1:
return await _get_single_company_fin_stmt(statement, ticker_symbols[0], quarterly, session)
result_list = await gather_in_batches(
concurrent_batch_size,
*(_get_single_company_fin_stmt(statement, symbol, quarterly, session) for symbol in ticker_symbols)
)
return {symbol: data for symbol, data in zip(ticker_symbols, result_list)}
@in_async_session @in_async_session
async def get_balance_sheet(*ticker_symbols: str, quarterly: bool = False, async def get_balance_sheet(*ticker_symbols: str, quarterly: bool = False,
concurrent_batch_size: int = constants.DEFAULT_CONCURRENT_BATCH_SIZE,
session: ClientSession = None) -> Union[ResultDict, Dict[str, ResultDict]]: session: ClientSession = None) -> Union[ResultDict, Dict[str, ResultDict]]:
""" """
Returns data from the balance sheet of the specified company. Returns data from the balance sheet of the specified company.
""" """
if len(ticker_symbols) == 1: return await _get_multi_companies_fin_stmt(constants.BS, *ticker_symbols,
return await _get_single_company_fin_stmt(constants.BS, ticker_symbols[0], quarterly, session) quarterly=quarterly, concurrent_batch_size=concurrent_batch_size,
return { session=session)
sym: await _get_single_company_fin_stmt(constants.BS, sym, quarterly, session)
for sym in ticker_symbols
}
@in_async_session @in_async_session
async def get_income_statement(*ticker_symbols: str, quarterly: bool = False, async def get_income_statement(*ticker_symbols: str, quarterly: bool = False,
concurrent_batch_size: int = constants.DEFAULT_CONCURRENT_BATCH_SIZE,
session: ClientSession = None) -> Union[ResultDict, Dict[str, ResultDict]]: session: ClientSession = None) -> Union[ResultDict, Dict[str, ResultDict]]:
""" """
Returns data from the income statement of the specified company. Returns data from the income statement of the specified company.
""" """
if len(ticker_symbols) == 1: return await _get_multi_companies_fin_stmt(constants.IS, *ticker_symbols,
return await _get_single_company_fin_stmt(constants.IS, ticker_symbols[0], quarterly, session) quarterly=quarterly, concurrent_batch_size=concurrent_batch_size,
return { session=session)
sym: await _get_single_company_fin_stmt(constants.IS, sym, quarterly, session)
for sym in ticker_symbols
}
@in_async_session @in_async_session
async def get_cash_flow_statement(*ticker_symbols: str, quarterly: bool = False, async def get_cash_flow_statement(*ticker_symbols: str, quarterly: bool = False,
concurrent_batch_size: int = constants.DEFAULT_CONCURRENT_BATCH_SIZE,
session: ClientSession = None) -> Union[ResultDict, Dict[str, ResultDict]]: session: ClientSession = None) -> Union[ResultDict, Dict[str, ResultDict]]:
""" """
Returns data from the cash flow statement of the specified company. Returns data from the cash flow statement of the specified company.
""" """
if len(ticker_symbols) == 1: return await _get_multi_companies_fin_stmt(constants.CF, *ticker_symbols,
return await _get_single_company_fin_stmt(constants.CF, ticker_symbols[0], quarterly, session) quarterly=quarterly, concurrent_batch_size=concurrent_batch_size,
return { session=session)
sym: await _get_single_company_fin_stmt(constants.CF, sym, quarterly, session)
for sym in ticker_symbols
}
@in_async_session @in_async_session

View File

@ -124,37 +124,39 @@ class FunctionsTestCase(IsolatedAsyncioTestCase):
@patch.object(functions, '_get_single_company_fin_stmt') @patch.object(functions, '_get_single_company_fin_stmt')
async def test__get_multi_companies_fin_stmt(self, mock__get_single_company_fin_stmt): async def test__get_multi_companies_fin_stmt(self, mock__get_single_company_fin_stmt):
statement, symbol1, symbol2, quarterly, mock_session = 'xyz', 'foo', 'bar', False, MagicMock() statement, sym1, sym2, quarterly, mock_session = 'xyz', 'foo', 'bar', False, MagicMock()
mock__get_single_company_fin_stmt.return_value = expected_output = 'baz' mock__get_single_company_fin_stmt.return_value = expected_output = 'baz'
output = await functions._get_multi_companies_fin_stmt(statement, symbol1, output = await functions._get_multi_companies_fin_stmt(statement, sym1,
quarterly=quarterly, session=mock_session) quarterly=quarterly, session=mock_session)
self.assertEqual(expected_output, output) self.assertEqual(expected_output, output)
mock__get_single_company_fin_stmt.assert_called_once_with(statement, symbol1, quarterly, mock_session) mock__get_single_company_fin_stmt.assert_called_once_with(statement, sym1, quarterly, mock_session)
mock__get_single_company_fin_stmt.reset_mock() mock__get_single_company_fin_stmt.reset_mock()
expected_output = {symbol1: expected_output, symbol2: expected_output} expected_output = {sym1: expected_output, sym2: expected_output}
output = await functions._get_multi_companies_fin_stmt(symbol1, symbol2, output = await functions._get_multi_companies_fin_stmt(statement, sym1, sym2,
quarterly=quarterly, session=mock_session) quarterly=quarterly, session=mock_session)
self.assertDictEqual(expected_output, output) self.assertDictEqual(expected_output, output)
mock__get_single_company_fin_stmt.assert_has_calls([ mock__get_single_company_fin_stmt.assert_has_calls([
call(statement, symbol1, quarterly, mock_session), call(statement, sym1, quarterly, mock_session),
call(statement, symbol2, quarterly, mock_session) call(statement, sym2, quarterly, mock_session)
]) ])
async def _helper_test_get_any_statement(self, statement: str, mock__get_multi_companies_fin_stmt): async def _helper_test_get_any_statement(self, stmt: str, mock__get_multi_companies_fin_stmt):
symbol1, symbol2, quarterly, mock_session = 'foo', 'bar', False, MagicMock() sym1, sym2, quarterly, batch_size, mock_session = 'foo', 'bar', False, 2, MagicMock()
mock__get_multi_companies_fin_stmt.return_value = expected_output = 'baz' mock__get_multi_companies_fin_stmt.return_value = expected_output = 'baz'
if statement == BS: if stmt == BS:
function = functions.get_balance_sheet function = functions.get_balance_sheet
elif statement == IS: elif stmt == IS:
function = functions.get_income_statement function = functions.get_income_statement
elif statement == CF: elif stmt == CF:
function = functions.get_cash_flow_statement function = functions.get_cash_flow_statement
else: else:
raise ValueError raise ValueError
output = await function(symbol1, symbol2, quarterly=quarterly, session=mock_session) output = await function(sym1, sym2, quarterly=quarterly, concurrent_batch_size=batch_size, session=mock_session)
self.assertEqual(expected_output, output) self.assertEqual(expected_output, output)
mock__get_multi_companies_fin_stmt.assert_called_once_with(statement, symbol1, symbol2, quarterly, mock_session) mock__get_multi_companies_fin_stmt.assert_called_once_with(
stmt, sym1, sym2, quarterly=quarterly, concurrent_batch_size=batch_size, session=mock_session
)
@patch.object(functions, '_get_multi_companies_fin_stmt') @patch.object(functions, '_get_multi_companies_fin_stmt')
async def test_get_balance_sheet(self, mock__get_multi_companies_fin_stmt): async def test_get_balance_sheet(self, mock__get_multi_companies_fin_stmt):
@ -190,14 +192,14 @@ class FunctionsTestCase(IsolatedAsyncioTestCase):
mock__get_single_company_financials.assert_called_once_with(symbol, quarterly, mock_session) mock__get_single_company_financials.assert_called_once_with(symbol, quarterly, mock_session)
mock__get_single_company_financials.reset_mock() mock__get_single_company_financials.reset_mock()
test_symbol1, test_symbol2 = 'x', 'y' test_sym1, test_sym2 = 'x', 'y'
expected_output = {test_symbol1: expected_output, test_symbol2: expected_output} expected_output = {test_sym1: expected_output, test_sym2: expected_output}
output = await functions.get_company_financials(test_symbol1, test_symbol2, output = await functions.get_company_financials(test_sym1, test_sym2,
quarterly=quarterly, session=mock_session) quarterly=quarterly, session=mock_session)
self.assertDictEqual(expected_output, output) self.assertDictEqual(expected_output, output)
mock__get_single_company_financials.assert_has_calls([ mock__get_single_company_financials.assert_has_calls([
call(test_symbol1, quarterly, mock_session), call(test_sym1, quarterly, mock_session),
call(test_symbol2, quarterly, mock_session) call(test_sym2, quarterly, mock_session)
]) ])
@patch.object(functions, 'ClientSession') @patch.object(functions, 'ClientSession')