Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions apps/candlestick-patterns/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,19 @@ def section_summary(text: str):
),
]
),
dbc.Col(
children=[
html.Label("Source:"),
dcc.Dropdown(
id="datasource",
options=[{"value": 'yahoo', "label": 'yahoo'},
{"value": 'csv', "label": 'csv'},
{"value": 'csv_all', "label": 'csv_all'},
{"value": 'sql', "label": 'sql'}],
value='csv',
),
]
),
],
),
html.Label("Filter period:"),
Expand Down Expand Up @@ -906,8 +919,26 @@ def section_summary(text: str):
)


data_mode = 'sql'
df_all_symbol = None
sql_dal = None


@cache.memoize()
def fetch_data(symbol, period, interval, auto_adjust, back_adjust):
"""Fetch OHLCV data from backend."""
global data_mode
if data_mode == 'csv_all':
return fetch_data_csv_all(symbol, period, interval, auto_adjust, back_adjust)
elif data_mode == 'csv':
return fetch_data_csv(symbol, period, interval, auto_adjust, back_adjust)
elif data_mode == 'sql':
return fetch_data_sql(symbol, period, interval, auto_adjust, back_adjust)
return fetch_data_yf(symbol, period, interval, auto_adjust, back_adjust)


@cache.memoize()
def fetch_data_yf(symbol, period, interval, auto_adjust, back_adjust):
"""Fetch OHLCV data from Yahoo! Finance."""
df = yf.Ticker(symbol).history(
period=period,
Expand All @@ -921,6 +952,73 @@ def fetch_data(symbol, period, interval, auto_adjust, back_adjust):
return df


@cache.memoize()
def fetch_data_csv(symbol, period, interval, auto_adjust, back_adjust, csvdir='~/trade/data/kaggle_allUS_daily_with_volume_yahoo/stocks'):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

csvdir is using ~ this shouldn't be hard coded

"""Fetch OHLCV data from csv containing one symbols."""
csvfile = csvdir + '/' + symbol + '.csv'
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change to f string

_start = time.perf_counter()
df = pd.read_csv(csvfile, parse_dates=['Date'], index_col='Date')
_duration = time.perf_counter() - _start
return df


@cache.memoize()
def fetch_data_csv_all(symbol, period, interval, auto_adjust, back_adjust, csvfile='~/trade/data/test.csv'):
"""Fetch OHLCV data from csv containing multiple symbols."""
global df_all_symbol
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a global doesn't feel appropriate there.

csvfile = '~/trade/data/test.csv'
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we shouldn't use ~ anywhere

if df_all_symbol is None:
_start = time.perf_counter()
df_all_symbol = pd.read_csv(csvfile, parse_dates=['date'])
df_all_symbol = df_all_symbol.rename(
columns={"date": "Date", "open": "Open", "high": "High", "low": "Low", "close": "Close",
"volume": "Volume"})
df_all_symbol['Date'] = df_all_symbol['Date'].map(
lambda t: pd.to_datetime(t.replace(tzinfo=None)).to_pydatetime())
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

total nitpick but don't say 't' be more descriptive about this variable

df_all_symbol = df_all_symbol.set_index(pd.DatetimeIndex(df_all_symbol['Date']))
_duration = time.perf_counter() - _start
res = df_all_symbol[df_all_symbol['symbol'] == symbol]
res = res.drop(['symbol'], axis=1)
return res


@cache.memoize()
def fetch_data_sql(symbol, period, interval, auto_adjust, back_adjust):
"""Fetch OHLCV data from SQL."""
global sql_dal
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove global

if interval == '60m':
dbtableRead = 'stock_market.ohlcv_1_hour'
elif interval == '15m':
dbtableRead = 'stock_market.ohlcv_15_minute'
elif interval == '1d':
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems counter intuitive. You have args "period" and "interval" but you're checking interval=='xyx' shouldn't it be some f-string combination of interval and period?

dbtableRead = 'stock_market.ohlcv_1_day'
else:
raise(ValueError(f"error unknown interval {interval}"))
if sql_dal is None:
user = getpass.getuser()
password = os.getenv('pg_password')
sql_dal = MarketDataRepository()
sql_dal.init(user=user, password=password)
res = sql_dal.get_one_symbol(dbtableRead, symbol)
res = res.rename(
columns={"date": "Date", "open": "Open", "high": "High", "low": "Low", "close": "Close",
"volume": "Volume"})
return res


@app.callback(
[Output('clean_button', 'children')],
[Input('clean_button', 'n_clicks')],
prevent_initial_call=True
)
def clean_cache(_):
"""Clean all data in cache."""
files = glob.glob('data/*')
for f in files:
os.remove(f)
return


@app.callback(
[
Output("data_signal", "children"),
Expand Down
126 changes: 126 additions & 0 deletions apps/candlestick-patterns/dal_stock_sql/dal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import datetime as dt
import psycopg2
from psycopg2 import pool
import pandas as pd
import logging


logger = logging.getLogger(__name__)


class MarketDataRepository:
def __init__(self):
"""Initialize basic structure but don't connect yet"""
self.connection_pool = None
self.ver_twsapi = None

def init(self, database='postgres', user='admin', password='admin', host='127.0.0.1', port='5432',
ver_twsapi=10, maxconn=2):
self.ver_twsapi = ver_twsapi
# Establishing the connection
logger.info('connecting to host %s and DB %s' % (host, database))
self.connection_pool = psycopg2.pool.ThreadedConnectionPool(
1, maxconn, # minconn, maxconn
user=user,
password=password,
host=host,
port=port,
database=database
)

def get_connection(self):
try:
conn = self.connection_pool.getconn()
except Exception as e:
raise e
return conn

def release_connection(self, connection):
self.connection_pool.putconn(connection)

def end(self):
self.connection_pool.closeall()

def get_all_symbol(self, dbtable: str, date_start: dt = None, date_end: dt = None):
if None not in [date_start, date_end]:
sql = '''SELECT * FROM %s WHERE date >= to_timestamp(%s) AND date <= to_timestamp(%s);''' % \
(dbtable, date_start.strftime('%s'), date_end.strftime('%s'))
elif date_start == date_end:
sql = '''SELECT * FROM %s ;''' % (dbtable)
print(sql)
elif date_end is None:
sql = '''SELECT * FROM %s WHERE date >= to_timestamp(%s);''' % \
(dbtable, date_start.strftime('%s'))
else:
sql = '''SELECT * FROM %s WHERE date <= to_timestamp(%s);''' % \
(dbtable, date_end.strftime('%s'))
# pandas only supports SQLAlchemy connectable (engine/connection) or database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 objects are not tested. Please consider using SQLAlchemy.
conn = self.get_connection()
try:
with conn.cursor() as cursor:
cursor.execute(sql)
result = cursor.fetchall()
columns = [desc[0] for desc in cursor.description]
cursor.close()
df = pd.DataFrame(result, columns=columns)
# Specify the types directly at DataFrame creation
df = df.astype({
'date': 'datetime64[ns, UTC]',
'open': 'float',
'high': 'float',
'low': 'float',
'close': 'float',
'volume': 'int'
})
except Exception as e:
raise e
finally:
self.release_connection(conn)
# remove TZ otherwise pandas DateTimeIndex lookup in pandas do not work
#df['date'] = df['date'].map(lambda t: pd.to_datetime(t.replace(tzinfo=None)).to_pydatetime())
df = df.set_index(pd.DatetimeIndex(df['date']))
return df

def get_one_symbol(self, dbtable: str, symbol: str, date_start: dt = None, date_end: dt = None):
if ':' not in symbol:
logger.error('Incorrect symbol %s' % symbol)
return
# raise ValueError
if None not in [date_start, date_end]:
sql = '''SELECT * FROM %s WHERE symbol = '%s' AND date >= to_timestamp(%s) AND date <= to_timestamp(%s);''' % \
(dbtable, symbol, date_start.strftime('%s'), date_end.strftime('%s'))
elif date_start == date_end:
sql = '''SELECT * FROM %s WHERE symbol = '%s';''' % (dbtable, symbol)
print(sql)
elif date_end is None:
sql = '''SELECT * FROM %s WHERE symbol = '%s' AND date >= to_timestamp(%s);''' % \
(dbtable, symbol, date_start.strftime('%s'))
else:
sql = '''SELECT * FROM %s WHERE symbol = '%s' AND date <= to_timestamp(%s);''' % \
(dbtable, symbol, date_end.strftime('%s'))
# pandas only supports SQLAlchemy connectable (engine/connection) or database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 objects are not tested. Please consider using SQLAlchemy.
conn = self.get_connection()
try:
with conn.cursor() as cursor:
cursor.execute(sql)
result = cursor.fetchall()
columns = [desc[0] for desc in cursor.description]
cursor.close()
df = pd.DataFrame(result, columns=columns)
# Specify the types directly at DataFrame creation
df = df.astype({
'date': 'datetime64[ns, UTC]',
'open': 'float',
'high': 'float',
'low': 'float',
'close': 'float',
'volume': 'int'
})
except Exception as e:
raise e
finally:
self.release_connection(conn)
# remove TZ otherwise pandas DateTimeIndex lookup in pandas do not work
#df['date'] = df['date'].map(lambda t: pd.to_datetime(t.replace(tzinfo=None)).to_pydatetime())
df = df.set_index(pd.DatetimeIndex(df['date']))
return df