From 323144200bab062bc22e763c78a54a6adbb1da0d Mon Sep 17 00:00:00 2001 From: Daniil Fajnberg Date: Thu, 11 Nov 2021 16:21:34 +0100 Subject: [PATCH] only use one http session in the main function --- src/stock-symbol-scraper/main.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/src/stock-symbol-scraper/main.py b/src/stock-symbol-scraper/main.py index 0f90c43..5cf1714 100644 --- a/src/stock-symbol-scraper/main.py +++ b/src/stock-symbol-scraper/main.py @@ -61,10 +61,11 @@ def get_str_from_td(td: Tag) -> str: return str(content).strip() -async def all_trs_from_page(url: str) -> ResultSet: - async with ClientSession() as session: - async with session.get(url) as response: - html = await response.text() +async def all_trs_from_page(url: str, session: ClientSession = None) -> ResultSet: + if session is None: + session = ClientSession() + async with session.get(url) as response: + html = await response.text() soup = BeautifulSoup(html, 'html.parser') try: return soup.find('div', {'id': 'marketsindex'}).table.tbody.find_all('tr') @@ -76,24 +77,27 @@ async def all_trs_from_page(url: str) -> ResultSet: raise UnexpectedMarkupError -async def all_data_from_category(category: str) -> list[row_type]: +async def all_data_from_category(category: str, session: ClientSession = None) -> list[row_type]: log.info(f"Getting companies starting with '{category}'") + if session is None: + session = ClientSession() data: list[row_type] = [] page = 1 - trs = await all_trs_from_page(f'{BASE_URL}{category}') + trs = await all_trs_from_page(f'{BASE_URL}{category}', session) while len(trs) > 0: - log.info(f"Scraping page {page}") + log.info(f"Scraping '{category}' page {page}") data.extend(data_from_rows(trs)) page += 1 - trs = await all_trs_from_page(f'{BASE_URL}{category}/{page}') + trs = await all_trs_from_page(f'{BASE_URL}{category}/{page}', session) return data async def get_all_data(asynchronous: bool = False) -> list[row_type]: - if asynchronous: - results = await asyncio.gather(*(all_data_from_category(category) for category in CATEGORIES)) - else: - results = [await all_data_from_category(category) for category in CATEGORIES] + async with ClientSession() as session: + if asynchronous: + results = await asyncio.gather(*(all_data_from_category(category, session) for category in CATEGORIES)) + else: + results = [await all_data_from_category(category, session) for category in CATEGORIES] data = [] for result in results: data.extend(result)