stocksymbolscraper/src/stock-symbol-scraper/main.py
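
"""Scrape company name, stock symbol, country, exchange, and sector for every
listing on MarketWatch's A-Z stock index and emit the result as CSV."""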

import logging
import re
import csv
import sys
import asyncio
from argparse import ArgumentParser
from pathlib import Path
from datetime import datetime
from string import ascii_uppercase
from aiohttp import ClientSession
from bs4 import BeautifulSoup
from bs4.element import Tag, ResultSet
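
# Module-level logger: errors always go to stderr; -v/--verbose lowers the threshold to DEBUG in main().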
log = logging.getLogger(__name__)
log.setLevel(logging.ERROR)
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
log.addHandler(ch)
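
# One result row: (company name, stock symbol, country, exchange, sector).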
row_type = tuple[str, str, str, str, str]
DOMAIN = 'www.marketwatch.com'
BASE_URL = f'https://{DOMAIN}/tools/markets/stocks/a-z/'
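
# MarketWatch groups its A-Z index by leading character: digits, the letters A-Z, and 'Other'.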
DIGIT_CATEGORY = '0-9'
OTHER_CATEGORY = 'Other'
CATEGORIES = [DIGIT_CATEGORY] + list(ascii_uppercase) + [OTHER_CATEGORY]
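# Extracts the ticker from the parenthesised text after the company name, e.g. '(BRK.A)'.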
STOCK_SYMBOL_PATTERN = re.compile(r'\(([\w.&]+)\)')


class UnexpectedMarkupError(Exception):
    """Raised when a page does not contain the expected stock-table markup."""


def data_from_rows(trs: ResultSet) -> list[row_type]:
    """Convert every table row in the result set into a data tuple."""
    data: list[row_type] = []
    for row in trs:
        data.append(get_single_row_data(row))
    return data


def get_single_row_data(table_row: Tag) -> row_type:
    """Extract (company name, symbol, country, exchange, sector) from one table row."""
    tds = table_row.find_all('td')
    # The first cell holds the company name followed by the symbol in parentheses.
    company_name = str(tds[0].a.contents[0]).strip()
    stock_symbol = str(tds[0].a.contents[1].contents[0]).strip()
    m = STOCK_SYMBOL_PATTERN.search(stock_symbol)
    if m is None:
        log.error(f"{stock_symbol} did not match the stock symbol pattern")
    else:
        stock_symbol = m.group(1)
    country = get_str_from_td(tds[1])
    exchange = get_str_from_td(tds[2])
    sector = get_str_from_td(tds[3])
    return company_name, stock_symbol, country, exchange, sector


def get_str_from_td(td: Tag) -> str:
    """Return the stripped text of a table cell, or '' for an empty cell."""
    try:
        content = td.contents[0]
    except IndexError:
        return ''
    return str(content).strip()


async def all_trs_from_page(url: str, session: ClientSession | None = None) -> ResultSet:
    """Fetch a page and return every <tr> of its stock table."""
    if session is None:
        # Open (and close) a temporary session if the caller did not supply one.
        async with ClientSession() as own_session:
            return await all_trs_from_page(url, own_session)
    async with session.get(url) as response:
        html = await response.text()
    soup = BeautifulSoup(html, 'html.parser')
    try:
        return soup.find('div', {'id': 'marketsindex'}).table.tbody.find_all('tr')
    except AttributeError:
        # Save the unexpected page so the markup can be inspected later.
        log.error("Unexpected HTML markup!")
        file_name = f'unexpected_response_at_{datetime.now().strftime("%Y-%m-%d_%H-%M")}.html'
        with open(file_name, 'w') as f:
            f.write(html)
        raise UnexpectedMarkupError


async def all_data_from_category(category: str, session: ClientSession | None = None) -> list[row_type]:
    """Collect the rows of every results page of a single category."""
    log.info(f"Getting companies starting with '{category}'")
    if session is None:
        # Open (and close) a temporary session if the caller did not supply one.
        async with ClientSession() as own_session:
            return await all_data_from_category(category, own_session)
    data: list[row_type] = []
    page = 1
    trs = await all_trs_from_page(f'{BASE_URL}{category}', session)
    while len(trs) > 0:
        log.info(f"Scraping '{category}' page {page}")
        data.extend(data_from_rows(trs))
        # Pages after the first are addressed as '<category>/<page number>'.
        page += 1
        trs = await all_trs_from_page(f'{BASE_URL}{category}/{page}', session)
    return data


async def get_all_data(asynchronous: bool = False) -> list[row_type]:
    """Scrape every category, concurrently when `asynchronous` is True, and merge the results."""
    async with ClientSession() as session:
        if asynchronous:
            results = await asyncio.gather(*(all_data_from_category(category, session) for category in CATEGORIES))
        else:
            results = [await all_data_from_category(category, session) for category in CATEGORIES]
    # Flatten the per-category lists into one list of rows.
    data: list[row_type] = []
    for result in results:
        data.extend(result)
    return data


def main() -> None:
    parser = ArgumentParser(description="Scrape all stock symbols")
    parser.add_argument(
        '-v', '--verbose',
        action='store_true',
        help="If set, prints verbose (DEBUG-level) log output."
    )
    parser.add_argument(
        '-S', '--sequential',
        action='store_true',
        help="If set, all requests are performed sequentially; otherwise async capabilities are used for concurrency."
    )
    parser.add_argument(
        '-f', '--to-file',
        type=Path,
        help="Writes results to the specified destination file. If omitted, results are printed to stdout."
    )
    args = parser.parse_args()
    if args.verbose:
        log.setLevel(logging.DEBUG)
    # Categories are scraped concurrently unless -S/--sequential was passed.
    data = asyncio.run(get_all_data(not args.sequential))
    if args.to_file is None:
        csv.writer(sys.stdout).writerows(data)
    else:
        # newline='' keeps the csv module from inserting blank lines on Windows.
        with open(args.to_file, 'w', newline='') as f:
            csv.writer(f).writerows(data)


if __name__ == '__main__':
    main()