concurrency for getting all three financial statements; improved imports; renamed functions

This commit is contained in:
Daniil Fajnberg 2021-12-03 15:38:21 +01:00
parent 8d18e03018
commit 55a56c2f9c
5 changed files with 45 additions and 44 deletions

View File

@ -1 +1 @@
from .functions import get_company_financials from .functions import get_balance_sheet, get_income_statement, get_cash_flow_statement, get_all_financials

View File

@ -6,7 +6,7 @@ from argparse import ArgumentParser
from pathlib import Path from pathlib import Path
from typing import Dict from typing import Dict
from .functions import get_company_financials, ResultDict from .functions import get_all_financials, ResultDict
from .constants import END_DATE, MAIN_LOGGER_NAME from .constants import END_DATE, MAIN_LOGGER_NAME
@ -78,7 +78,7 @@ def write_to_csv(data: Dict[str, ResultDict], file_obj) -> None:
async def main() -> None: async def main() -> None:
args = parse_cli() args = parse_cli()
configure_logging(args[VERBOSE]) configure_logging(args[VERBOSE])
data = await get_company_financials(args[TICKER_SYMBOL], quarterly=args[QUARTERLY]) data = await get_all_financials(args[TICKER_SYMBOL], quarterly=args[QUARTERLY])
path: Path = args[TO_FILE] path: Path = args[TO_FILE]
if path is None: if path is None:
print(json.dumps(data, indent=args[JSON_INDENT])) print(json.dumps(data, indent=args[JSON_INDENT]))

View File

@ -15,7 +15,7 @@ END_DATE = 'End Date'
# All items marked `False` do not need to be scraped # All items marked `False` do not need to be scraped
# because they are calculated from other items (e.g. growth or ratios). # because they are calculated from other items (e.g. growth or ratios).
FINANCIAL_STATEMENT_ITEMS = { FIN_STMT_ITEMS = {
################# #################
# Balance Sheet # # Balance Sheet #
################# #################

View File

@ -1,4 +1,5 @@
import logging import logging
import asyncio
from typing import Union, List, Dict from typing import Union, List, Dict
from aiohttp.client import ClientSession from aiohttp.client import ClientSession
@ -6,7 +7,8 @@ from bs4 import BeautifulSoup
from bs4.element import Tag from bs4.element import Tag
from webutils import in_async_session, gather_in_batches from webutils import in_async_session, gather_in_batches
from . import constants from .constants import (HTML_PARSER, BASE_URL, END_DATE, BS, IS, CF, FIN_STMT_URL_SUFFIX, FIN_STMT_ITEMS,
DEFAULT_CONCURRENT_BATCH_SIZE)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -25,7 +27,7 @@ async def soup_from_url(url: str, session: ClientSession = None) -> BeautifulSou
""" """
async with session.get(url) as response: async with session.get(url) as response:
html = await response.text() html = await response.text()
return BeautifulSoup(html, constants.HTML_PARSER) return BeautifulSoup(html, HTML_PARSER)
def extract_end_dates(soup: BeautifulSoup) -> tuple[str]: def extract_end_dates(soup: BeautifulSoup) -> tuple[str]:
@ -43,7 +45,7 @@ def is_relevant_table_row(tr: Tag) -> bool:
""" """
item_name = str(tr.td.div.string).strip() item_name = str(tr.td.div.string).strip()
try: try:
return constants.FINANCIAL_STATEMENT_ITEMS[item_name] return FIN_STMT_ITEMS[item_name]
except KeyError: except KeyError:
log.warning(f"Unknown item name '{item_name}' found in financial statement.") log.warning(f"Unknown item name '{item_name}' found in financial statement.")
return False return False
@ -73,7 +75,7 @@ def extract_all_data(soup: BeautifulSoup) -> ResultDict:
""" """
Extracts financials from the page. Extracts financials from the page.
""" """
output = {constants.END_DATE: extract_end_dates(soup)} output = {END_DATE: extract_end_dates(soup)}
for row in find_relevant_table_rows(soup): for row in find_relevant_table_rows(soup):
row_data = extract_row_data(row) row_data = extract_row_data(row)
output[row_data[0]] = row_data[1] output[row_data[0]] = row_data[1]
@ -86,7 +88,7 @@ async def _get_single_company_fin_stmt(statement: str, ticker_symbol: str, quart
""" """
Returns data from the specified financial statement of the specified company. Returns data from the specified financial statement of the specified company.
""" """
url = f'{constants.BASE_URL}/{ticker_symbol}/financials{constants.FIN_STMT_URL_SUFFIX[statement]}' url = f'{BASE_URL}/{ticker_symbol}/financials{FIN_STMT_URL_SUFFIX[statement]}'
if quarterly: if quarterly:
url += '/quarter' url += '/quarter'
soup = await soup_from_url(url, session) soup = await soup_from_url(url, session)
@ -95,70 +97,69 @@ async def _get_single_company_fin_stmt(statement: str, ticker_symbol: str, quart
@in_async_session @in_async_session
async def _get_multi_companies_fin_stmt(statement: str, *ticker_symbols: str, quarterly: bool = False, async def _get_multi_companies_fin_stmt(statement: str, *ticker_symbols: str, quarterly: bool = False,
concurrent_batch_size: int = constants.DEFAULT_CONCURRENT_BATCH_SIZE, concurrent_batch_size: int = DEFAULT_CONCURRENT_BATCH_SIZE,
session: ClientSession = None) -> Union[ResultDict, Dict[str, ResultDict]]: session: ClientSession = None) -> Union[ResultDict, Dict[str, ResultDict]]:
if len(ticker_symbols) == 1: if len(ticker_symbols) == 1:
return await _get_single_company_fin_stmt(statement, ticker_symbols[0], quarterly, session) return await _get_single_company_fin_stmt(statement, ticker_symbols[0], quarterly, session)
result_list = await gather_in_batches( coroutines = (_get_single_company_fin_stmt(statement, symbol, quarterly, session) for symbol in ticker_symbols)
concurrent_batch_size, result_list = await gather_in_batches(concurrent_batch_size, *coroutines)
*(_get_single_company_fin_stmt(statement, symbol, quarterly, session) for symbol in ticker_symbols)
)
return {symbol: data for symbol, data in zip(ticker_symbols, result_list)} return {symbol: data for symbol, data in zip(ticker_symbols, result_list)}
@in_async_session @in_async_session
async def get_balance_sheet(*ticker_symbols: str, quarterly: bool = False, async def get_balance_sheet(*ticker_symbols: str, quarterly: bool = False,
concurrent_batch_size: int = constants.DEFAULT_CONCURRENT_BATCH_SIZE, concurrent_batch_size: int = DEFAULT_CONCURRENT_BATCH_SIZE,
session: ClientSession = None) -> Union[ResultDict, Dict[str, ResultDict]]: session: ClientSession = None) -> Union[ResultDict, Dict[str, ResultDict]]:
""" """
Returns data from the balance sheet of the specified company. Returns data from the balance sheet of the specified company.
""" """
return await _get_multi_companies_fin_stmt(constants.BS, *ticker_symbols, return await _get_multi_companies_fin_stmt(BS, *ticker_symbols,
quarterly=quarterly, concurrent_batch_size=concurrent_batch_size, quarterly=quarterly, concurrent_batch_size=concurrent_batch_size,
session=session) session=session)
@in_async_session @in_async_session
async def get_income_statement(*ticker_symbols: str, quarterly: bool = False, async def get_income_statement(*ticker_symbols: str, quarterly: bool = False,
concurrent_batch_size: int = constants.DEFAULT_CONCURRENT_BATCH_SIZE, concurrent_batch_size: int = DEFAULT_CONCURRENT_BATCH_SIZE,
session: ClientSession = None) -> Union[ResultDict, Dict[str, ResultDict]]: session: ClientSession = None) -> Union[ResultDict, Dict[str, ResultDict]]:
""" """
Returns data from the income statement of the specified company. Returns data from the income statement of the specified company.
""" """
return await _get_multi_companies_fin_stmt(constants.IS, *ticker_symbols, return await _get_multi_companies_fin_stmt(IS, *ticker_symbols,
quarterly=quarterly, concurrent_batch_size=concurrent_batch_size, quarterly=quarterly, concurrent_batch_size=concurrent_batch_size,
session=session) session=session)
@in_async_session @in_async_session
async def get_cash_flow_statement(*ticker_symbols: str, quarterly: bool = False, async def get_cash_flow_statement(*ticker_symbols: str, quarterly: bool = False,
concurrent_batch_size: int = constants.DEFAULT_CONCURRENT_BATCH_SIZE, concurrent_batch_size: int = DEFAULT_CONCURRENT_BATCH_SIZE,
session: ClientSession = None) -> Union[ResultDict, Dict[str, ResultDict]]: session: ClientSession = None) -> Union[ResultDict, Dict[str, ResultDict]]:
""" """
Returns data from the cash flow statement of the specified company. Returns data from the cash flow statement of the specified company.
""" """
return await _get_multi_companies_fin_stmt(constants.CF, *ticker_symbols, return await _get_multi_companies_fin_stmt(CF, *ticker_symbols,
quarterly=quarterly, concurrent_batch_size=concurrent_batch_size, quarterly=quarterly, concurrent_batch_size=concurrent_batch_size,
session=session) session=session)
@in_async_session @in_async_session
async def _get_single_company_financials(ticker_symbol: str, quarterly: bool = False, async def _get_single_company_all_financials(ticker_symbol: str, quarterly: bool = False,
session: ClientSession = None) -> Dict[str, ResultDict]: session: ClientSession = None) -> Dict[str, ResultDict]:
return { coroutines = (_get_single_company_fin_stmt(stmt, ticker_symbol, quarterly, session) for stmt in (BS, IS, CF))
constants.BS: await _get_single_company_fin_stmt(constants.BS, ticker_symbol, quarterly, session), results = await asyncio.gather(*coroutines)
constants.IS: await _get_single_company_fin_stmt(constants.IS, ticker_symbol, quarterly, session), return {stmt: data for stmt, data in zip((BS, IS, CF), results)}
constants.CF: await _get_single_company_fin_stmt(constants.CF, ticker_symbol, quarterly, session)
}
@in_async_session @in_async_session
async def get_company_financials(*ticker_symbols: str, quarterly: bool = False, async def get_all_financials(*ticker_symbols: str, quarterly: bool = False,
session: ClientSession = None) -> Union[Dict[str, ResultDict], concurrent_batch_size: int = DEFAULT_CONCURRENT_BATCH_SIZE,
Dict[str, Dict[str, ResultDict]]]: session: ClientSession = None) -> Union[Dict[str, ResultDict],
Dict[str, Dict[str, ResultDict]]]:
""" """
Returns all fundamentals (balance sheet, income statement and cash flow statement) of the specified company. Returns all fundamentals (balance sheet, income statement and cash flow statement) of the specified company.
""" """
if len(ticker_symbols) == 1: if len(ticker_symbols) == 1:
return await _get_single_company_financials(ticker_symbols[0], quarterly, session) return await _get_single_company_all_financials(ticker_symbols[0], quarterly, session)
return {symbol: await _get_single_company_financials(symbol, quarterly, session) for symbol in ticker_symbols} coroutines = (_get_single_company_all_financials(symbol, quarterly, session) for symbol in ticker_symbols)
result_list = await gather_in_batches(concurrent_batch_size, *coroutines)
return {symbol: data for symbol, data in zip(ticker_symbols, result_list)}

View File

@ -171,11 +171,11 @@ class FunctionsTestCase(IsolatedAsyncioTestCase):
await self._helper_test_get_any_statement(CF, mock__get_multi_companies_fin_stmt) await self._helper_test_get_any_statement(CF, mock__get_multi_companies_fin_stmt)
@patch.object(functions, '_get_single_company_fin_stmt') @patch.object(functions, '_get_single_company_fin_stmt')
async def test__get_single_company_financials(self, mock__get_single_company_fin_stmt): async def test__get_single_company_all_financials(self, mock__get_single_company_fin_stmt):
symbol, quarterly, mock_session = 'foo', False, MagicMock() symbol, quarterly, mock_session = 'foo', False, MagicMock()
mock__get_single_company_fin_stmt.return_value = bar = 'bar' mock__get_single_company_fin_stmt.return_value = bar = 'bar'
expected_output = {BS: bar, IS: bar, CF: bar} expected_output = {BS: bar, IS: bar, CF: bar}
output = await functions._get_single_company_financials(symbol, quarterly, mock_session) output = await functions._get_single_company_all_financials(symbol, quarterly, mock_session)
self.assertDictEqual(expected_output, output) self.assertDictEqual(expected_output, output)
mock__get_single_company_fin_stmt.assert_has_calls([ mock__get_single_company_fin_stmt.assert_has_calls([
call(BS, symbol, quarterly, mock_session), call(BS, symbol, quarterly, mock_session),
@ -183,21 +183,21 @@ class FunctionsTestCase(IsolatedAsyncioTestCase):
call(CF, symbol, quarterly, mock_session) call(CF, symbol, quarterly, mock_session)
]) ])
@patch.object(functions, '_get_single_company_financials') @patch.object(functions, '_get_single_company_all_financials')
async def test_get_company_financials(self, mock__get_single_company_financials): async def test_get_company_financials(self, mock__get_single_company_all_financials):
mock__get_single_company_financials.return_value = expected_output = 'baz' mock__get_single_company_all_financials.return_value = expected_output = 'baz'
symbol, quarterly, mock_session = 'foo', False, MagicMock() symbol, quarterly, mock_session = 'foo', False, MagicMock()
output = await functions.get_company_financials(symbol, quarterly=quarterly, session=mock_session) output = await functions.get_all_financials(symbol, quarterly=quarterly, session=mock_session)
self.assertEqual(expected_output, output) self.assertEqual(expected_output, output)
mock__get_single_company_financials.assert_called_once_with(symbol, quarterly, mock_session) mock__get_single_company_all_financials.assert_called_once_with(symbol, quarterly, mock_session)
mock__get_single_company_financials.reset_mock() mock__get_single_company_all_financials.reset_mock()
test_sym1, test_sym2 = 'x', 'y' test_sym1, test_sym2 = 'x', 'y'
expected_output = {test_sym1: expected_output, test_sym2: expected_output} expected_output = {test_sym1: expected_output, test_sym2: expected_output}
output = await functions.get_company_financials(test_sym1, test_sym2, output = await functions.get_all_financials(test_sym1, test_sym2,
quarterly=quarterly, session=mock_session) quarterly=quarterly, session=mock_session)
self.assertDictEqual(expected_output, output) self.assertDictEqual(expected_output, output)
mock__get_single_company_financials.assert_has_calls([ mock__get_single_company_all_financials.assert_has_calls([
call(test_sym1, quarterly, mock_session), call(test_sym1, quarterly, mock_session),
call(test_sym2, quarterly, mock_session) call(test_sym2, quarterly, mock_session)
]) ])
@ -212,7 +212,7 @@ class FunctionsTestCase(IsolatedAsyncioTestCase):
IS: {END_DATE: ('End_Date_1', 'End_Date_2'), 'Cash & Short Term Investments': (11000000, -22000000)}, IS: {END_DATE: ('End_Date_1', 'End_Date_2'), 'Cash & Short Term Investments': (11000000, -22000000)},
CF: {END_DATE: ('End_Date_1', 'End_Date_2'), 'Cash & Short Term Investments': (11000000, -22000000)} CF: {END_DATE: ('End_Date_1', 'End_Date_2'), 'Cash & Short Term Investments': (11000000, -22000000)}
} }
output = await functions.get_company_financials(symbol, session=mock_session_obj) output = await functions.get_all_financials(symbol, session=mock_session_obj)
self.assertDictEqual(expected_output, output) self.assertDictEqual(expected_output, output)
mock_session_obj.get.assert_has_calls([ mock_session_obj.get.assert_has_calls([
call(f'{BASE_URL}/{symbol}/financials{FIN_STMT_URL_SUFFIX[BS]}'), call(f'{BASE_URL}/{symbol}/financials{FIN_STMT_URL_SUFFIX[BS]}'),