renamed some functions and variables; added optional parameters
commit 2715f78e55 (parent d98e00a8ae)
@@ -7,6 +7,7 @@ from argparse import ArgumentParser
 from pathlib import Path
 from datetime import datetime
 from string import ascii_uppercase
+from math import inf
 
 from aiohttp import ClientSession
 from bs4 import BeautifulSoup
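The only genuinely new import is math.inf, which later becomes the open-ended default for last_page; an integer page counter compares against it cleanly, so the default cap never ends the paging loop on its own. A tiny sketch of that behaviour:

from math import inf

page = 1
# inf is a float, but ints compare against it as expected, so a
# `page <= last_page` check with the inf default never stops the loop by itself.
print(page <= inf)        # True
print(10 ** 100 <= inf)   # True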
@@ -34,20 +35,17 @@ class UnexpectedMarkupError(Exception):
     pass
 
 
-def data_from_rows(trs: ResultSet) -> list[row_type]:
-    data: list[row_type] = []
-    for row in trs:
-        data.append(get_single_row_data(row))
-    return data
+def extract_row_data(*table_rows: Tag) -> list[row_type]:
+    return [get_single_tr_data(tr) for tr in table_rows]
 
 
-def get_single_row_data(table_row: Tag) -> row_type:
+def get_single_tr_data(table_row: Tag) -> row_type:
     tds = table_row.find_all('td')
     company_name = str(tds[0].a.contents[0]).strip()
     stock_symbol = str(tds[0].a.contents[1].contents[0]).strip()
     m = re.search(STOCK_SYMBOL_PATTERN, stock_symbol)
     if m is None:
-        log.error(f"{stock_symbol} did not match the stock symbol pattern")
+        log.warning(f"{stock_symbol} did not match the stock symbol pattern; saving as is")
     else:
         stock_symbol = m.group(1)
     country = get_str_from_td(tds[1])
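The replacement pair keeps the same per-row logic but switches to variadic arguments and a list comprehension, so call sites now unpack the ResultSet with `*`. A self-contained sketch of that calling convention using toy markup (the stand-in helpers below are illustrative, not the scraper's real column parsing):

from bs4 import BeautifulSoup
from bs4.element import Tag

def first_cell_text(row: Tag) -> str:
    # stand-in for get_single_tr_data(), which parses the real columns
    return row.td.get_text(strip=True)

def collect(*table_rows: Tag) -> list[str]:
    # same shape as extract_row_data(): varargs in, list comprehension out
    return [first_cell_text(tr) for tr in table_rows]

soup = BeautifulSoup('<table><tr><td>a</td></tr><tr><td>b</td></tr></table>', 'html.parser')
trs = soup.find_all('tr')
print(collect(*trs))  # ['a', 'b'], note the star-unpacking at the call site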
@@ -64,14 +62,14 @@ def get_str_from_td(td: Tag) -> str:
     return str(content).strip()
 
 
-async def all_trs_from_page(url: str, session: ClientSession = None) -> ResultSet:
+async def trs_from_page(url: str, session: ClientSession = None, limit: int = None) -> ResultSet:
     if session is None:
         session = ClientSession()
     async with session.get(url) as response:
         html = await response.text()
     soup = BeautifulSoup(html, 'html.parser')
     try:
-        return soup.find('div', {'id': 'marketsindex'}).table.tbody.find_all('tr')
+        return soup.find('div', {'id': 'marketsindex'}).table.tbody.find_all('tr', limit=limit)
     except AttributeError:
         log.error("Unexpected HTML markup!")
         file_name = f'unexpected_response_at_{datetime.now().strftime("%Y-%m-%d_%H-%M")}.html'
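The added limit parameter is handed straight to BeautifulSoup's find_all, which stops collecting after that many matches and behaves as before while it stays None. A minimal standalone demonstration with made-up markup:

from bs4 import BeautifulSoup

html = '<table><tr><td>1</td></tr><tr><td>2</td></tr><tr><td>3</td></tr></table>'
soup = BeautifulSoup(html, 'html.parser')

# limit=None returns every match; a positive limit truncates the ResultSet,
# which is the knob trs_from_page(url, session, limit=...) now exposes.
print(len(soup.find_all('tr')))           # 3
print(len(soup.find_all('tr', limit=2)))  # 2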
@@ -80,27 +78,28 @@ async def all_trs_from_page(url: str, session: ClientSession = None) -> ResultSet:
         raise UnexpectedMarkupError
 
 
-async def all_data_from_category(category: str, session: ClientSession = None) -> list[row_type]:
+async def get_data_from_category(category: str, session: ClientSession = None,
+                                  first_page: int = 1, last_page: int = inf) -> list[row_type]:
     log.info(f"Getting companies starting with '{category}'")
     if session is None:
         session = ClientSession()
     data: list[row_type] = []
-    page = 1
-    trs = await all_trs_from_page(f'{BASE_URL}{category}', session)
-    while len(trs) > 0:
-        log.info(f"Scraping '{category}' page {page}")
-        data.extend(data_from_rows(trs))
+    page = first_page
+    trs = await trs_from_page(f'{BASE_URL}{category}', session)
+    while page <= last_page and len(trs) > 0:
+        data.extend(extract_row_data(*trs))
+        log.info(f"Scraped '{category}' page {page}")
         page += 1
-        trs = await all_trs_from_page(f'{BASE_URL}{category}/{page}', session)
+        trs = await trs_from_page(f'{BASE_URL}{category}/{page}', session)
     return data
 
 
-async def get_all_data(asynchronous: bool = False) -> list[row_type]:
+async def get_all_data(sequential: bool = False) -> list[row_type]:
     async with ClientSession() as session:
-        if asynchronous:
-            results = await asyncio.gather(*(all_data_from_category(category, session) for category in CATEGORIES))
+        if sequential:
+            results = [await get_data_from_category(category, session) for category in CATEGORIES]
         else:
-            results = [await all_data_from_category(category, session) for category in CATEGORIES]
+            results = await asyncio.gather(*(get_data_from_category(category, session) for category in CATEGORIES))
         data = []
         for result in results:
             data.extend(result)
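With last_page defaulting to inf, the loop keeps the old "run until an empty page" behaviour unless a caller caps it, and first_page allows resuming a category part-way through; get_all_data(sequential=True) now awaits categories one by one, while the default fans out with asyncio.gather. A hedged usage sketch, assuming the script is importable as a module named scraper (a hypothetical name) and that 'A' and 'B' are valid categories:

import asyncio
from aiohttp import ClientSession

from scraper import get_data_from_category  # module name is an assumption

async def sample_run() -> None:
    async with ClientSession() as session:
        # pages 2-4 of one category only ...
        partial = await get_data_from_category('A', session, first_page=2, last_page=4)
        # ... and every page of another, via the inf default for last_page
        complete = await get_data_from_category('B', session)
        print(len(partial), len(complete))

# asyncio.run(sample_run())  # commented out: hits the network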
@@ -128,7 +127,7 @@ def main() -> None:
     if args.verbose:
         log.setLevel(logging.DEBUG)
 
-    data = asyncio.run(get_all_data(not args.sequential))
+    data = asyncio.run(get_all_data(args.sequential))
 
     if args.to_file is None:
         csv.writer(sys.stdout).writerows(data)
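The main() change reads most naturally next to the argument parser; the exact flag spellings below are an assumption, reconstructed from the attribute names the visible code uses (args.sequential, args.verbose, args.to_file):

from argparse import ArgumentParser

def build_parser() -> ArgumentParser:
    # Hypothetical reconstruction of the CLI; only the dest names are
    # grounded in the diff, the short options and help text are guesses.
    parser = ArgumentParser(description='Scrape stock listings to CSV')
    parser.add_argument('-s', '--sequential', action='store_true',
                        help='scrape categories one after another instead of concurrently')
    parser.add_argument('-v', '--verbose', action='store_true',
                        help='enable debug logging')
    parser.add_argument('-f', '--to-file', default=None,
                        help='write the CSV here instead of stdout')
    return parser

After the rename from asynchronous to sequential, main() can pass args.sequential through unchanged instead of negating it.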