From 55a56c2f9ca4b8c987e8965ea9da956017f4210d Mon Sep 17 00:00:00 2001 From: Daniil Fajnberg Date: Fri, 3 Dec 2021 15:38:21 +0100 Subject: [PATCH] concurrency for getting all three financial statements; improved imports; renamed functions --- src/mwfin/__init__.py | 2 +- src/mwfin/__main__.py | 4 +-- src/mwfin/constants.py | 2 +- src/mwfin/functions.py | 57 +++++++++++++++++++++-------------------- tests/test_functions.py | 24 ++++++++--------- 5 files changed, 45 insertions(+), 44 deletions(-) diff --git a/src/mwfin/__init__.py b/src/mwfin/__init__.py index 2c60e41..3f3bb90 100644 --- a/src/mwfin/__init__.py +++ b/src/mwfin/__init__.py @@ -1 +1 @@ -from .functions import get_company_financials +from .functions import get_balance_sheet, get_income_statement, get_cash_flow_statement, get_all_financials diff --git a/src/mwfin/__main__.py b/src/mwfin/__main__.py index b11d440..59311ec 100644 --- a/src/mwfin/__main__.py +++ b/src/mwfin/__main__.py @@ -6,7 +6,7 @@ from argparse import ArgumentParser from pathlib import Path from typing import Dict -from .functions import get_company_financials, ResultDict +from .functions import get_all_financials, ResultDict from .constants import END_DATE, MAIN_LOGGER_NAME @@ -78,7 +78,7 @@ def write_to_csv(data: Dict[str, ResultDict], file_obj) -> None: async def main() -> None: args = parse_cli() configure_logging(args[VERBOSE]) - data = await get_company_financials(args[TICKER_SYMBOL], quarterly=args[QUARTERLY]) + data = await get_all_financials(args[TICKER_SYMBOL], quarterly=args[QUARTERLY]) path: Path = args[TO_FILE] if path is None: print(json.dumps(data, indent=args[JSON_INDENT])) diff --git a/src/mwfin/constants.py b/src/mwfin/constants.py index 278ed1b..2f92abb 100644 --- a/src/mwfin/constants.py +++ b/src/mwfin/constants.py @@ -15,7 +15,7 @@ END_DATE = 'End Date' # All items marked `False` do not need to be scraped # because they are calculated from other items (e.g. growth or ratios). -FINANCIAL_STATEMENT_ITEMS = { +FIN_STMT_ITEMS = { ################# # Balance Sheet # ################# diff --git a/src/mwfin/functions.py b/src/mwfin/functions.py index 4098fd0..cc6f745 100644 --- a/src/mwfin/functions.py +++ b/src/mwfin/functions.py @@ -1,4 +1,5 @@ import logging +import asyncio from typing import Union, List, Dict from aiohttp.client import ClientSession @@ -6,7 +7,8 @@ from bs4 import BeautifulSoup from bs4.element import Tag from webutils import in_async_session, gather_in_batches -from . import constants +from .constants import (HTML_PARSER, BASE_URL, END_DATE, BS, IS, CF, FIN_STMT_URL_SUFFIX, FIN_STMT_ITEMS, + DEFAULT_CONCURRENT_BATCH_SIZE) log = logging.getLogger(__name__) @@ -25,7 +27,7 @@ async def soup_from_url(url: str, session: ClientSession = None) -> BeautifulSou """ async with session.get(url) as response: html = await response.text() - return BeautifulSoup(html, constants.HTML_PARSER) + return BeautifulSoup(html, HTML_PARSER) def extract_end_dates(soup: BeautifulSoup) -> tuple[str]: @@ -43,7 +45,7 @@ def is_relevant_table_row(tr: Tag) -> bool: """ item_name = str(tr.td.div.string).strip() try: - return constants.FINANCIAL_STATEMENT_ITEMS[item_name] + return FIN_STMT_ITEMS[item_name] except KeyError: log.warning(f"Unknown item name '{item_name}' found in financial statement.") return False @@ -73,7 +75,7 @@ def extract_all_data(soup: BeautifulSoup) -> ResultDict: """ Extracts financials from the page. """ - output = {constants.END_DATE: extract_end_dates(soup)} + output = {END_DATE: extract_end_dates(soup)} for row in find_relevant_table_rows(soup): row_data = extract_row_data(row) output[row_data[0]] = row_data[1] @@ -86,7 +88,7 @@ async def _get_single_company_fin_stmt(statement: str, ticker_symbol: str, quart """ Returns data from the specified financial statement of the specified company. """ - url = f'{constants.BASE_URL}/{ticker_symbol}/financials{constants.FIN_STMT_URL_SUFFIX[statement]}' + url = f'{BASE_URL}/{ticker_symbol}/financials{FIN_STMT_URL_SUFFIX[statement]}' if quarterly: url += '/quarter' soup = await soup_from_url(url, session) @@ -95,70 +97,69 @@ async def _get_single_company_fin_stmt(statement: str, ticker_symbol: str, quart @in_async_session async def _get_multi_companies_fin_stmt(statement: str, *ticker_symbols: str, quarterly: bool = False, - concurrent_batch_size: int = constants.DEFAULT_CONCURRENT_BATCH_SIZE, + concurrent_batch_size: int = DEFAULT_CONCURRENT_BATCH_SIZE, session: ClientSession = None) -> Union[ResultDict, Dict[str, ResultDict]]: if len(ticker_symbols) == 1: return await _get_single_company_fin_stmt(statement, ticker_symbols[0], quarterly, session) - result_list = await gather_in_batches( - concurrent_batch_size, - *(_get_single_company_fin_stmt(statement, symbol, quarterly, session) for symbol in ticker_symbols) - ) + coroutines = (_get_single_company_fin_stmt(statement, symbol, quarterly, session) for symbol in ticker_symbols) + result_list = await gather_in_batches(concurrent_batch_size, *coroutines) return {symbol: data for symbol, data in zip(ticker_symbols, result_list)} @in_async_session async def get_balance_sheet(*ticker_symbols: str, quarterly: bool = False, - concurrent_batch_size: int = constants.DEFAULT_CONCURRENT_BATCH_SIZE, + concurrent_batch_size: int = DEFAULT_CONCURRENT_BATCH_SIZE, session: ClientSession = None) -> Union[ResultDict, Dict[str, ResultDict]]: """ Returns data from the balance sheet of the specified company. """ - return await _get_multi_companies_fin_stmt(constants.BS, *ticker_symbols, + return await _get_multi_companies_fin_stmt(BS, *ticker_symbols, quarterly=quarterly, concurrent_batch_size=concurrent_batch_size, session=session) @in_async_session async def get_income_statement(*ticker_symbols: str, quarterly: bool = False, - concurrent_batch_size: int = constants.DEFAULT_CONCURRENT_BATCH_SIZE, + concurrent_batch_size: int = DEFAULT_CONCURRENT_BATCH_SIZE, session: ClientSession = None) -> Union[ResultDict, Dict[str, ResultDict]]: """ Returns data from the income statement of the specified company. """ - return await _get_multi_companies_fin_stmt(constants.IS, *ticker_symbols, + return await _get_multi_companies_fin_stmt(IS, *ticker_symbols, quarterly=quarterly, concurrent_batch_size=concurrent_batch_size, session=session) @in_async_session async def get_cash_flow_statement(*ticker_symbols: str, quarterly: bool = False, - concurrent_batch_size: int = constants.DEFAULT_CONCURRENT_BATCH_SIZE, + concurrent_batch_size: int = DEFAULT_CONCURRENT_BATCH_SIZE, session: ClientSession = None) -> Union[ResultDict, Dict[str, ResultDict]]: """ Returns data from the cash flow statement of the specified company. """ - return await _get_multi_companies_fin_stmt(constants.CF, *ticker_symbols, + return await _get_multi_companies_fin_stmt(CF, *ticker_symbols, quarterly=quarterly, concurrent_batch_size=concurrent_batch_size, session=session) @in_async_session -async def _get_single_company_financials(ticker_symbol: str, quarterly: bool = False, - session: ClientSession = None) -> Dict[str, ResultDict]: - return { - constants.BS: await _get_single_company_fin_stmt(constants.BS, ticker_symbol, quarterly, session), - constants.IS: await _get_single_company_fin_stmt(constants.IS, ticker_symbol, quarterly, session), - constants.CF: await _get_single_company_fin_stmt(constants.CF, ticker_symbol, quarterly, session) - } +async def _get_single_company_all_financials(ticker_symbol: str, quarterly: bool = False, + session: ClientSession = None) -> Dict[str, ResultDict]: + coroutines = (_get_single_company_fin_stmt(stmt, ticker_symbol, quarterly, session) for stmt in (BS, IS, CF)) + results = await asyncio.gather(*coroutines) + return {stmt: data for stmt, data in zip((BS, IS, CF), results)} @in_async_session -async def get_company_financials(*ticker_symbols: str, quarterly: bool = False, - session: ClientSession = None) -> Union[Dict[str, ResultDict], - Dict[str, Dict[str, ResultDict]]]: +async def get_all_financials(*ticker_symbols: str, quarterly: bool = False, + concurrent_batch_size: int = DEFAULT_CONCURRENT_BATCH_SIZE, + session: ClientSession = None) -> Union[Dict[str, ResultDict], + Dict[str, Dict[str, ResultDict]]]: """ Returns all fundamentals (balance sheet, income statement and cash flow statement) of the specified company. """ if len(ticker_symbols) == 1: - return await _get_single_company_financials(ticker_symbols[0], quarterly, session) - return {symbol: await _get_single_company_financials(symbol, quarterly, session) for symbol in ticker_symbols} + return await _get_single_company_all_financials(ticker_symbols[0], quarterly, session) + coroutines = (_get_single_company_all_financials(symbol, quarterly, session) for symbol in ticker_symbols) + result_list = await gather_in_batches(concurrent_batch_size, *coroutines) + return {symbol: data for symbol, data in zip(ticker_symbols, result_list)} diff --git a/tests/test_functions.py b/tests/test_functions.py index 5a55d78..1ba8200 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -171,11 +171,11 @@ class FunctionsTestCase(IsolatedAsyncioTestCase): await self._helper_test_get_any_statement(CF, mock__get_multi_companies_fin_stmt) @patch.object(functions, '_get_single_company_fin_stmt') - async def test__get_single_company_financials(self, mock__get_single_company_fin_stmt): + async def test__get_single_company_all_financials(self, mock__get_single_company_fin_stmt): symbol, quarterly, mock_session = 'foo', False, MagicMock() mock__get_single_company_fin_stmt.return_value = bar = 'bar' expected_output = {BS: bar, IS: bar, CF: bar} - output = await functions._get_single_company_financials(symbol, quarterly, mock_session) + output = await functions._get_single_company_all_financials(symbol, quarterly, mock_session) self.assertDictEqual(expected_output, output) mock__get_single_company_fin_stmt.assert_has_calls([ call(BS, symbol, quarterly, mock_session), @@ -183,21 +183,21 @@ class FunctionsTestCase(IsolatedAsyncioTestCase): call(CF, symbol, quarterly, mock_session) ]) - @patch.object(functions, '_get_single_company_financials') - async def test_get_company_financials(self, mock__get_single_company_financials): - mock__get_single_company_financials.return_value = expected_output = 'baz' + @patch.object(functions, '_get_single_company_all_financials') + async def test_get_company_financials(self, mock__get_single_company_all_financials): + mock__get_single_company_all_financials.return_value = expected_output = 'baz' symbol, quarterly, mock_session = 'foo', False, MagicMock() - output = await functions.get_company_financials(symbol, quarterly=quarterly, session=mock_session) + output = await functions.get_all_financials(symbol, quarterly=quarterly, session=mock_session) self.assertEqual(expected_output, output) - mock__get_single_company_financials.assert_called_once_with(symbol, quarterly, mock_session) - mock__get_single_company_financials.reset_mock() + mock__get_single_company_all_financials.assert_called_once_with(symbol, quarterly, mock_session) + mock__get_single_company_all_financials.reset_mock() test_sym1, test_sym2 = 'x', 'y' expected_output = {test_sym1: expected_output, test_sym2: expected_output} - output = await functions.get_company_financials(test_sym1, test_sym2, - quarterly=quarterly, session=mock_session) + output = await functions.get_all_financials(test_sym1, test_sym2, + quarterly=quarterly, session=mock_session) self.assertDictEqual(expected_output, output) - mock__get_single_company_financials.assert_has_calls([ + mock__get_single_company_all_financials.assert_has_calls([ call(test_sym1, quarterly, mock_session), call(test_sym2, quarterly, mock_session) ]) @@ -212,7 +212,7 @@ class FunctionsTestCase(IsolatedAsyncioTestCase): IS: {END_DATE: ('End_Date_1', 'End_Date_2'), 'Cash & Short Term Investments': (11000000, -22000000)}, CF: {END_DATE: ('End_Date_1', 'End_Date_2'), 'Cash & Short Term Investments': (11000000, -22000000)} } - output = await functions.get_company_financials(symbol, session=mock_session_obj) + output = await functions.get_all_financials(symbol, session=mock_session_obj) self.assertDictEqual(expected_output, output) mock_session_obj.get.assert_has_calls([ call(f'{BASE_URL}/{symbol}/financials{FIN_STMT_URL_SUFFIX[BS]}'),