Fetch all three financial statements concurrently; import constants directly; rename functions and constants
This commit is contained in:
		| @@ -1 +1 @@ | ||||
| from .functions import get_company_financials | ||||
| from .functions import get_balance_sheet, get_income_statement, get_cash_flow_statement, get_all_financials | ||||
|   | ||||
| @@ -6,7 +6,7 @@ from argparse import ArgumentParser | ||||
| from pathlib import Path | ||||
| from typing import Dict | ||||
|  | ||||
| from .functions import get_company_financials, ResultDict | ||||
| from .functions import get_all_financials, ResultDict | ||||
| from .constants import END_DATE, MAIN_LOGGER_NAME | ||||
|  | ||||
|  | ||||
| @@ -78,7 +78,7 @@ def write_to_csv(data: Dict[str, ResultDict], file_obj) -> None: | ||||
| async def main() -> None: | ||||
|     args = parse_cli() | ||||
|     configure_logging(args[VERBOSE]) | ||||
|     data = await get_company_financials(args[TICKER_SYMBOL], quarterly=args[QUARTERLY]) | ||||
|     data = await get_all_financials(args[TICKER_SYMBOL], quarterly=args[QUARTERLY]) | ||||
|     path: Path = args[TO_FILE] | ||||
|     if path is None: | ||||
|         print(json.dumps(data, indent=args[JSON_INDENT])) | ||||
|   | ||||
| @@ -15,7 +15,7 @@ END_DATE = 'End Date' | ||||
|  | ||||
| # All items marked `False` do not need to be scraped | ||||
| # because they are calculated from other items (e.g. growth or ratios). | ||||
| FINANCIAL_STATEMENT_ITEMS = { | ||||
| FIN_STMT_ITEMS = { | ||||
|     ################# | ||||
|     # Balance Sheet # | ||||
|     ################# | ||||
|   | ||||
| @@ -1,4 +1,5 @@ | ||||
| import logging | ||||
| import asyncio | ||||
| from typing import Union, List, Dict | ||||
|  | ||||
| from aiohttp.client import ClientSession | ||||
| @@ -6,7 +7,8 @@ from bs4 import BeautifulSoup | ||||
| from bs4.element import Tag | ||||
| from webutils import in_async_session, gather_in_batches | ||||
|  | ||||
| from . import constants | ||||
| from .constants import (HTML_PARSER, BASE_URL, END_DATE, BS, IS, CF, FIN_STMT_URL_SUFFIX, FIN_STMT_ITEMS, | ||||
|                         DEFAULT_CONCURRENT_BATCH_SIZE) | ||||
|  | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
| @@ -25,7 +27,7 @@ async def soup_from_url(url: str, session: ClientSession = None) -> BeautifulSou | ||||
|     """ | ||||
|     async with session.get(url) as response: | ||||
|         html = await response.text() | ||||
|     return BeautifulSoup(html, constants.HTML_PARSER) | ||||
|     return BeautifulSoup(html, HTML_PARSER) | ||||
|  | ||||
|  | ||||
| def extract_end_dates(soup: BeautifulSoup) -> tuple[str]: | ||||
| @@ -43,7 +45,7 @@ def is_relevant_table_row(tr: Tag) -> bool: | ||||
|     """ | ||||
|     item_name = str(tr.td.div.string).strip() | ||||
|     try: | ||||
|         return constants.FINANCIAL_STATEMENT_ITEMS[item_name] | ||||
|         return FIN_STMT_ITEMS[item_name] | ||||
|     except KeyError: | ||||
|         log.warning(f"Unknown item name '{item_name}' found in financial statement.") | ||||
|         return False | ||||
| @@ -73,7 +75,7 @@ def extract_all_data(soup: BeautifulSoup) -> ResultDict: | ||||
|     """ | ||||
|     Extracts financials from the page. | ||||
|     """ | ||||
|     output = {constants.END_DATE: extract_end_dates(soup)} | ||||
|     output = {END_DATE: extract_end_dates(soup)} | ||||
|     for row in find_relevant_table_rows(soup): | ||||
|         row_data = extract_row_data(row) | ||||
|         output[row_data[0]] = row_data[1] | ||||
| @@ -86,7 +88,7 @@ async def _get_single_company_fin_stmt(statement: str, ticker_symbol: str, quart | ||||
|     """ | ||||
|     Returns data from the specified financial statement of the specified company. | ||||
|     """ | ||||
|     url = f'{constants.BASE_URL}/{ticker_symbol}/financials{constants.FIN_STMT_URL_SUFFIX[statement]}' | ||||
|     url = f'{BASE_URL}/{ticker_symbol}/financials{FIN_STMT_URL_SUFFIX[statement]}' | ||||
|     if quarterly: | ||||
|         url += '/quarter' | ||||
|     soup = await soup_from_url(url, session) | ||||
| @@ -95,70 +97,69 @@ async def _get_single_company_fin_stmt(statement: str, ticker_symbol: str, quart | ||||
|  | ||||
| @in_async_session | ||||
| async def _get_multi_companies_fin_stmt(statement: str, *ticker_symbols: str, quarterly: bool = False, | ||||
|                                         concurrent_batch_size: int = constants.DEFAULT_CONCURRENT_BATCH_SIZE, | ||||
|                                         concurrent_batch_size: int = DEFAULT_CONCURRENT_BATCH_SIZE, | ||||
|                                         session: ClientSession = None) -> Union[ResultDict, Dict[str, ResultDict]]: | ||||
|     if len(ticker_symbols) == 1: | ||||
|         return await _get_single_company_fin_stmt(statement, ticker_symbols[0], quarterly, session) | ||||
|     result_list = await gather_in_batches( | ||||
|         concurrent_batch_size, | ||||
|         *(_get_single_company_fin_stmt(statement, symbol, quarterly, session) for symbol in ticker_symbols) | ||||
|     ) | ||||
|     coroutines = (_get_single_company_fin_stmt(statement, symbol, quarterly, session) for symbol in ticker_symbols) | ||||
|     result_list = await gather_in_batches(concurrent_batch_size, *coroutines) | ||||
|     return {symbol: data for symbol, data in zip(ticker_symbols, result_list)} | ||||
|  | ||||
|  | ||||
| @in_async_session | ||||
| async def get_balance_sheet(*ticker_symbols: str, quarterly: bool = False, | ||||
|                             concurrent_batch_size: int = constants.DEFAULT_CONCURRENT_BATCH_SIZE, | ||||
|                             concurrent_batch_size: int = DEFAULT_CONCURRENT_BATCH_SIZE, | ||||
|                             session: ClientSession = None) -> Union[ResultDict, Dict[str, ResultDict]]: | ||||
|     """ | ||||
|     Returns data from the balance sheet of the specified company. | ||||
|     """ | ||||
|     return await _get_multi_companies_fin_stmt(constants.BS, *ticker_symbols, | ||||
|     return await _get_multi_companies_fin_stmt(BS, *ticker_symbols, | ||||
|                                                quarterly=quarterly, concurrent_batch_size=concurrent_batch_size, | ||||
|                                                session=session) | ||||
|  | ||||
|  | ||||
| @in_async_session | ||||
| async def get_income_statement(*ticker_symbols: str, quarterly: bool = False, | ||||
|                                concurrent_batch_size: int = constants.DEFAULT_CONCURRENT_BATCH_SIZE, | ||||
|                                concurrent_batch_size: int = DEFAULT_CONCURRENT_BATCH_SIZE, | ||||
|                                session: ClientSession = None) -> Union[ResultDict, Dict[str, ResultDict]]: | ||||
|     """ | ||||
|     Returns data from the income statement of the specified company. | ||||
|     """ | ||||
|     return await _get_multi_companies_fin_stmt(constants.IS, *ticker_symbols, | ||||
|     return await _get_multi_companies_fin_stmt(IS, *ticker_symbols, | ||||
|                                                quarterly=quarterly, concurrent_batch_size=concurrent_batch_size, | ||||
|                                                session=session) | ||||
|  | ||||
|  | ||||
| @in_async_session | ||||
| async def get_cash_flow_statement(*ticker_symbols: str, quarterly: bool = False, | ||||
|                                   concurrent_batch_size: int = constants.DEFAULT_CONCURRENT_BATCH_SIZE, | ||||
|                                   concurrent_batch_size: int = DEFAULT_CONCURRENT_BATCH_SIZE, | ||||
|                                   session: ClientSession = None) -> Union[ResultDict, Dict[str, ResultDict]]: | ||||
|     """ | ||||
|     Returns data from the cash flow statement of the specified company. | ||||
|     """ | ||||
|     return await _get_multi_companies_fin_stmt(constants.CF, *ticker_symbols, | ||||
|     return await _get_multi_companies_fin_stmt(CF, *ticker_symbols, | ||||
|                                                quarterly=quarterly, concurrent_batch_size=concurrent_batch_size, | ||||
|                                                session=session) | ||||
|  | ||||
|  | ||||
| @in_async_session | ||||
| async def _get_single_company_financials(ticker_symbol: str, quarterly: bool = False, | ||||
|                                          session: ClientSession = None) -> Dict[str, ResultDict]: | ||||
|     return { | ||||
|         constants.BS: await _get_single_company_fin_stmt(constants.BS, ticker_symbol, quarterly, session), | ||||
|         constants.IS: await _get_single_company_fin_stmt(constants.IS, ticker_symbol, quarterly, session), | ||||
|         constants.CF: await _get_single_company_fin_stmt(constants.CF, ticker_symbol, quarterly, session) | ||||
|     } | ||||
| async def _get_single_company_all_financials(ticker_symbol: str, quarterly: bool = False, | ||||
|                                              session: ClientSession = None) -> Dict[str, ResultDict]: | ||||
|     coroutines = (_get_single_company_fin_stmt(stmt, ticker_symbol, quarterly, session) for stmt in (BS, IS, CF)) | ||||
|     results = await asyncio.gather(*coroutines) | ||||
|     return {stmt: data for stmt, data in zip((BS, IS, CF), results)} | ||||
|  | ||||
|  | ||||
| @in_async_session | ||||
| async def get_company_financials(*ticker_symbols: str, quarterly: bool = False, | ||||
|                                  session: ClientSession = None) -> Union[Dict[str, ResultDict], | ||||
|                                                                          Dict[str, Dict[str, ResultDict]]]: | ||||
| async def get_all_financials(*ticker_symbols: str, quarterly: bool = False, | ||||
|                              concurrent_batch_size: int = DEFAULT_CONCURRENT_BATCH_SIZE, | ||||
|                              session: ClientSession = None) -> Union[Dict[str, ResultDict], | ||||
|                                                                      Dict[str, Dict[str, ResultDict]]]: | ||||
|     """ | ||||
|     Returns all fundamentals (balance sheet, income statement and cash flow statement) of the specified company. | ||||
|     """ | ||||
|     if len(ticker_symbols) == 1: | ||||
|         return await _get_single_company_financials(ticker_symbols[0], quarterly, session) | ||||
|     return {symbol: await _get_single_company_financials(symbol, quarterly, session) for symbol in ticker_symbols} | ||||
|         return await _get_single_company_all_financials(ticker_symbols[0], quarterly, session) | ||||
|     coroutines = (_get_single_company_all_financials(symbol, quarterly, session) for symbol in ticker_symbols) | ||||
|     result_list = await gather_in_batches(concurrent_batch_size, *coroutines) | ||||
|     return {symbol: data for symbol, data in zip(ticker_symbols, result_list)} | ||||
|   | ||||
| @@ -171,11 +171,11 @@ class FunctionsTestCase(IsolatedAsyncioTestCase): | ||||
|         await self._helper_test_get_any_statement(CF, mock__get_multi_companies_fin_stmt) | ||||
|  | ||||
|     @patch.object(functions, '_get_single_company_fin_stmt') | ||||
|     async def test__get_single_company_financials(self, mock__get_single_company_fin_stmt): | ||||
|     async def test__get_single_company_all_financials(self, mock__get_single_company_fin_stmt): | ||||
|         symbol, quarterly, mock_session = 'foo', False, MagicMock() | ||||
|         mock__get_single_company_fin_stmt.return_value = bar = 'bar' | ||||
|         expected_output = {BS: bar, IS: bar, CF: bar} | ||||
|         output = await functions._get_single_company_financials(symbol, quarterly, mock_session) | ||||
|         output = await functions._get_single_company_all_financials(symbol, quarterly, mock_session) | ||||
|         self.assertDictEqual(expected_output, output) | ||||
|         mock__get_single_company_fin_stmt.assert_has_calls([ | ||||
|             call(BS, symbol, quarterly, mock_session), | ||||
| @@ -183,21 +183,21 @@ class FunctionsTestCase(IsolatedAsyncioTestCase): | ||||
|             call(CF, symbol, quarterly, mock_session) | ||||
|         ]) | ||||
|  | ||||
|     @patch.object(functions, '_get_single_company_financials') | ||||
|     async def test_get_company_financials(self, mock__get_single_company_financials): | ||||
|         mock__get_single_company_financials.return_value = expected_output = 'baz' | ||||
|     @patch.object(functions, '_get_single_company_all_financials') | ||||
|     async def test_get_company_financials(self, mock__get_single_company_all_financials): | ||||
|         mock__get_single_company_all_financials.return_value = expected_output = 'baz' | ||||
|         symbol, quarterly, mock_session = 'foo', False, MagicMock() | ||||
|         output = await functions.get_company_financials(symbol, quarterly=quarterly, session=mock_session) | ||||
|         output = await functions.get_all_financials(symbol, quarterly=quarterly, session=mock_session) | ||||
|         self.assertEqual(expected_output, output) | ||||
|         mock__get_single_company_financials.assert_called_once_with(symbol, quarterly, mock_session) | ||||
|         mock__get_single_company_financials.reset_mock() | ||||
|         mock__get_single_company_all_financials.assert_called_once_with(symbol, quarterly, mock_session) | ||||
|         mock__get_single_company_all_financials.reset_mock() | ||||
|  | ||||
|         test_sym1, test_sym2 = 'x', 'y' | ||||
|         expected_output = {test_sym1: expected_output, test_sym2: expected_output} | ||||
|         output = await functions.get_company_financials(test_sym1, test_sym2, | ||||
|                                                         quarterly=quarterly, session=mock_session) | ||||
|         output = await functions.get_all_financials(test_sym1, test_sym2, | ||||
|                                                     quarterly=quarterly, session=mock_session) | ||||
|         self.assertDictEqual(expected_output, output) | ||||
|         mock__get_single_company_financials.assert_has_calls([ | ||||
|         mock__get_single_company_all_financials.assert_has_calls([ | ||||
|             call(test_sym1, quarterly, mock_session), | ||||
|             call(test_sym2, quarterly, mock_session) | ||||
|         ]) | ||||
| @@ -212,7 +212,7 @@ class FunctionsTestCase(IsolatedAsyncioTestCase): | ||||
|             IS: {END_DATE: ('End_Date_1', 'End_Date_2'), 'Cash & Short Term Investments': (11000000, -22000000)}, | ||||
|             CF: {END_DATE: ('End_Date_1', 'End_Date_2'), 'Cash & Short Term Investments': (11000000, -22000000)} | ||||
|         } | ||||
|         output = await functions.get_company_financials(symbol, session=mock_session_obj) | ||||
|         output = await functions.get_all_financials(symbol, session=mock_session_obj) | ||||
|         self.assertDictEqual(expected_output, output) | ||||
|         mock_session_obj.get.assert_has_calls([ | ||||
|             call(f'{BASE_URL}/{symbol}/financials{FIN_STMT_URL_SUFFIX[BS]}'), | ||||
|   | ||||
		Reference in New Issue
	
	Block a user