From 2715f78e556f32ce7b946cda4f9a7cb3a042b9fe Mon Sep 17 00:00:00 2001 From: Daniil Fajnberg Date: Thu, 11 Nov 2021 18:05:12 +0100 Subject: [PATCH] renamed some functions and variables; added optional parameters --- src/stock-symbol-scraper/main.py | 41 ++++++++++++++++---------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/src/stock-symbol-scraper/main.py b/src/stock-symbol-scraper/main.py index 45e1d87..96a0327 100644 --- a/src/stock-symbol-scraper/main.py +++ b/src/stock-symbol-scraper/main.py @@ -7,6 +7,7 @@ from argparse import ArgumentParser from pathlib import Path from datetime import datetime from string import ascii_uppercase +from math import inf from aiohttp import ClientSession from bs4 import BeautifulSoup @@ -34,20 +35,17 @@ class UnexpectedMarkupError(Exception): pass -def data_from_rows(trs: ResultSet) -> list[row_type]: - data: list[row_type] = [] - for row in trs: - data.append(get_single_row_data(row)) - return data +def extract_row_data(*table_rows: Tag) -> list[row_type]: + return [get_single_tr_data(tr) for tr in table_rows] -def get_single_row_data(table_row: Tag) -> row_type: +def get_single_tr_data(table_row: Tag) -> row_type: tds = table_row.find_all('td') company_name = str(tds[0].a.contents[0]).strip() stock_symbol = str(tds[0].a.contents[1].contents[0]).strip() m = re.search(STOCK_SYMBOL_PATTERN, stock_symbol) if m is None: - log.error(f"{stock_symbol} did not match the stock symbol pattern") + log.warning(f"{stock_symbol} did not match the stock symbol pattern; saving as is") else: stock_symbol = m.group(1) country = get_str_from_td(tds[1]) @@ -64,14 +62,14 @@ def get_str_from_td(td: Tag) -> str: return str(content).strip() -async def all_trs_from_page(url: str, session: ClientSession = None) -> ResultSet: +async def trs_from_page(url: str, session: ClientSession = None, limit: int = None) -> ResultSet: if session is None: session = ClientSession() async with session.get(url) as response: html = await response.text() soup = BeautifulSoup(html, 'html.parser') try: - return soup.find('div', {'id': 'marketsindex'}).table.tbody.find_all('tr') + return soup.find('div', {'id': 'marketsindex'}).table.tbody.find_all('tr', limit=limit) except AttributeError: log.error("Unexpected HTML markup!") file_name = f'unexpected_response_at_{datetime.now().strftime("%Y-%m-%d_%H-%M")}.html' @@ -80,27 +78,28 @@ async def all_trs_from_page(url: str, session: ClientSession = None) -> ResultSe raise UnexpectedMarkupError -async def all_data_from_category(category: str, session: ClientSession = None) -> list[row_type]: +async def get_data_from_category(category: str, session: ClientSession = None, + first_page: int = 1, last_page: int = inf) -> list[row_type]: log.info(f"Getting companies starting with '{category}'") if session is None: session = ClientSession() data: list[row_type] = [] - page = 1 - trs = await all_trs_from_page(f'{BASE_URL}{category}', session) - while len(trs) > 0: - log.info(f"Scraping '{category}' page {page}") - data.extend(data_from_rows(trs)) + page = first_page + trs = await trs_from_page(f'{BASE_URL}{category}', session) + while page <= last_page and len(trs) > 0: + data.extend(extract_row_data(*trs)) + log.info(f"Scraped '{category}' page {page}") page += 1 - trs = await all_trs_from_page(f'{BASE_URL}{category}/{page}', session) + trs = await trs_from_page(f'{BASE_URL}{category}/{page}', session) return data -async def get_all_data(asynchronous: bool = False) -> list[row_type]: +async def get_all_data(sequential: bool = False) -> list[row_type]: async with ClientSession() as session: - if asynchronous: - results = await asyncio.gather(*(all_data_from_category(category, session) for category in CATEGORIES)) + if sequential: + results = [await get_data_from_category(category, session) for category in CATEGORIES] else: - results = [await all_data_from_category(category, session) for category in CATEGORIES] + results = await asyncio.gather(*(get_data_from_category(category, session) for category in CATEGORIES)) data = [] for result in results: data.extend(result) @@ -128,7 +127,7 @@ def main() -> None: if args.verbose: log.setLevel(logging.DEBUG) - data = asyncio.run(get_all_data(not args.sequential)) + data = asyncio.run(get_all_data(args.sequential)) if args.to_file is None: csv.writer(sys.stdout).writerows(data)