diff --git a/apps/candlestick-patterns/app.py b/apps/candlestick-patterns/app.py
index 3c019678..ffd1a6e9 100644
--- a/apps/candlestick-patterns/app.py
+++ b/apps/candlestick-patterns/app.py
@@ -479,6 +479,20 @@ def section_summary(text: str):
                                                         ),
                                                     ]
                                                 ),
+                                                dbc.Col(
+                                                    children=[
+                                                        html.Label("Source:"),
+                                                        # NOTE: display only for now; fetch_data() follows data_mode.
+                                                        dcc.Dropdown(
+                                                            id="datasource",
+                                                            options=[{"value": 'yahoo', "label": 'yahoo'},
+                                                                     {"value": 'csv', "label": 'csv'},
+                                                                     {"value": 'csv_all', "label": 'csv_all'},
+                                                                     {"value": 'sql', "label": 'sql'}],
+                                                            value='sql',
+                                                        ),
+                                                    ]
+                                                ),
                                             ],
                                         ),
                                         html.Label("Filter period:"),
@@ -906,8 +920,35 @@ def section_summary(text: str):
 )
 
 
+# Backend used by fetch_data(): 'yahoo', 'csv', 'csv_all' or 'sql'.
+data_mode = 'sql'
+# Multi-symbol CSV, loaded on first use by fetch_data_csv_all().
+df_all_symbol = None
+# MarketDataRepository, created on first use by fetch_data_sql().
+sql_dal = None
+
+
+def fetch_data(symbol, period, interval, auto_adjust, back_adjust):
+    """Fetch OHLCV data from the backend selected by ``data_mode``.
+
+    Not memoized here: every backend below is memoized on its own, and a
+    cache at this level would be keyed on the arguments only, so it would
+    keep serving the previous backend's rows after ``data_mode`` changes.
+    """
+    if data_mode == 'csv_all':
+        fetcher = fetch_data_csv_all
+    elif data_mode == 'csv':
+        fetcher = fetch_data_csv
+    elif data_mode == 'sql':
+        fetcher = fetch_data_sql
+    else:
+        # Any other value falls back to Yahoo! Finance.
+        fetcher = fetch_data_yf
+    return fetcher(symbol, period, interval, auto_adjust, back_adjust)
+
+
 @cache.memoize()
-def fetch_data(symbol, period, interval, auto_adjust, back_adjust):
+def fetch_data_yf(symbol, period, interval, auto_adjust, back_adjust):
     """Fetch OHLCV data from Yahoo! Finance."""
     df = yf.Ticker(symbol).history(
         period=period,
@@ -921,6 +962,89 @@ def fetch_data(symbol, period, interval, auto_adjust, back_adjust):
     return df
 
 
+@cache.memoize()
+def fetch_data_csv(symbol, period, interval, auto_adjust, back_adjust,
+                   csvdir='~/trade/data/kaggle_allUS_daily_with_volume_yahoo/stocks'):
+    """Fetch OHLCV data from a CSV file holding a single symbol.
+
+    Reads ``<csvdir>/<symbol>.csv`` and indexes it by its ``Date`` column; the
+    other arguments are not used by this backend.
+    """
+    csvfile = csvdir + '/' + symbol + '.csv'
+    return pd.read_csv(csvfile, parse_dates=['Date'], index_col='Date')
+
+
+@cache.memoize()
+def fetch_data_csv_all(symbol, period, interval, auto_adjust, back_adjust,
+                       csvfile='~/trade/data/test.csv'):
+    """Fetch OHLCV data for one symbol from a CSV holding several symbols.
+
+    The CSV is read once per process and kept in ``df_all_symbol``; later calls
+    only filter that frame, so ``csvfile`` matters on the first load only.
+    """
+    global df_all_symbol
+    if df_all_symbol is None:
+        frame = pd.read_csv(csvfile, parse_dates=['date'])
+        frame = frame.rename(
+            columns={"date": "Date", "open": "Open", "high": "High", "low": "Low", "close": "Close",
+                     "volume": "Volume"})
+        # Strip the timezone from every timestamp before indexing.
+        frame['Date'] = frame['Date'].map(
+            lambda t: pd.to_datetime(t.replace(tzinfo=None)).to_pydatetime())
+        # Publish only the fully prepared frame so a failed load is retried.
+        df_all_symbol = frame.set_index(pd.DatetimeIndex(frame['Date']))
+    res = df_all_symbol[df_all_symbol['symbol'] == symbol]
+    return res.drop(['symbol'], axis=1)
+
+
+@cache.memoize()
+def fetch_data_sql(symbol, period, interval, auto_adjust, back_adjust):
+    """Fetch OHLCV data from the SQL backend.
+
+    Raises ``ValueError`` for intervals other than ``60m``, ``15m`` and ``1d``.
+    The repository is created on first use with the current OS user and the
+    password found in the ``pg_password`` environment variable.
+    """
+    global sql_dal
+    tables = {
+        '60m': 'stock_market.ohlcv_1_hour',
+        '15m': 'stock_market.ohlcv_15_minute',
+        '1d': 'stock_market.ohlcv_1_day',
+    }
+    if interval not in tables:
+        raise ValueError(f"error unknown interval {interval}")
+    if sql_dal is None:
+        # Publish the repository only once its pool is open, so a failed
+        # connection attempt is retried instead of leaving a dead object.
+        repository = MarketDataRepository()
+        repository.init(user=getpass.getuser(), password=os.getenv('pg_password'))
+        sql_dal = repository
+    res = sql_dal.get_one_symbol(tables[interval], symbol)
+    return res.rename(
+        columns={"date": "Date", "open": "Open", "high": "High", "low": "Low", "close": "Close",
+                 "volume": "Volume"})
+
+
+@app.callback(
+    [Output('clean_button', 'children')],
+    [Input('clean_button', 'n_clicks')],
+    prevent_initial_call=True
+)
+def clean_cache(_):
+    """Clean all data in cache by deleting the files under ``data``.
+
+    Raises ``PreventUpdate`` when done so that the button keeps its label.
+    """
+    # Imported here so the patch does not depend on the module's import block.
+    from dash.exceptions import PreventUpdate
+
+    for path in glob.glob('data/*'):
+        if os.path.isfile(path):
+            os.remove(path)
+    # Dash rejects ``None`` for a list of outputs; report "no update" instead.
+    raise PreventUpdate
+
+
 @app.callback(
     [
         Output("data_signal", "children"),
diff --git a/apps/candlestick-patterns/dal_stock_sql/dal.py b/apps/candlestick-patterns/dal_stock_sql/dal.py
new file mode 100644
index 00000000..fe675b03
--- /dev/null
+++ b/apps/candlestick-patterns/dal_stock_sql/dal.py
@@ -0,0 +1,107 @@
+import datetime as dt
+import psycopg2
+from psycopg2 import pool
+from psycopg2 import sql
+import pandas as pd
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class MarketDataRepository:
+    """Read OHLCV rows from PostgreSQL through a psycopg2 connection pool."""
+
+    def __init__(self):
+        """Initialize basic structure but don't connect yet"""
+        self.connection_pool = None
+        self.ver_twsapi = None
+
+    def init(self, database='postgres', user='admin', password='admin', host='127.0.0.1', port='5432',
+             ver_twsapi=10, maxconn=2):
+        """Open the connection pool; must be called before any query."""
+        self.ver_twsapi = ver_twsapi
+        # Establishing the connection
+        logger.info('connecting to host %s and DB %s', host, database)
+        self.connection_pool = psycopg2.pool.ThreadedConnectionPool(
+            1, maxconn,  # minconn, maxconn
+            user=user,
+            password=password,
+            host=host,
+            port=port,
+            database=database
+        )
+
+    def get_connection(self):
+        """Borrow a connection from the pool; pair with release_connection()."""
+        return self.connection_pool.getconn()
+
+    def release_connection(self, connection):
+        """Hand a borrowed connection back to the pool."""
+        self.connection_pool.putconn(connection)
+
+    def end(self):
+        """Close every pooled connection; does nothing before init()."""
+        if self.connection_pool is not None:
+            self.connection_pool.closeall()
+
+    def get_all_symbol(self, dbtable: str, date_start: dt.datetime = None, date_end: dt.datetime = None):
+        """Return the rows of every symbol in ``dbtable``, optionally date-bounded."""
+        return self._read_ohlcv(dbtable, None, date_start, date_end)
+
+    def get_one_symbol(self, dbtable: str, symbol: str, date_start: dt.datetime = None,
+                       date_end: dt.datetime = None):
+        """Return the rows of ``symbol`` in ``dbtable``, optionally date-bounded.
+
+        Raises ``ValueError`` when ``symbol`` contains no ``:``.
+        """
+        if ':' not in symbol:
+            logger.error('Incorrect symbol %s', symbol)
+            raise ValueError(f'Incorrect symbol {symbol}')
+        return self._read_ohlcv(dbtable, symbol, date_start, date_end)
+
+    def _read_ohlcv(self, dbtable, symbol, date_start, date_end):
+        """Run the SELECT shared by both getters and return a ``date``-indexed frame.
+
+        ``symbol``, ``date_start`` and ``date_end`` each add a filter unless ``None``;
+        the date bounds are inclusive.
+        """
+        filters = []
+        params = []
+        if symbol is not None:
+            filters.append(sql.SQL('symbol = %s'))
+            params.append(symbol)
+        if date_start is not None:
+            filters.append(sql.SQL('date >= to_timestamp(%s)'))
+            # timestamp() is portable, unlike the glibc-only strftime('%s').
+            params.append(date_start.timestamp())
+        if date_end is not None:
+            filters.append(sql.SQL('date <= to_timestamp(%s)'))
+            params.append(date_end.timestamp())
+        # Values travel as bound parameters and the (optionally schema-qualified)
+        # table name is quoted as an identifier, so no input is spliced into SQL.
+        query = sql.SQL('SELECT * FROM {}').format(sql.Identifier(*dbtable.split('.')))
+        if filters:
+            query = query + sql.SQL(' WHERE ') + sql.SQL(' AND ').join(filters)
+        # A manual cursor is used because pandas only supports SQLAlchemy connectables
+        # (or sqlite3) for its own SQL readers.
+        conn = self.get_connection()
+        try:
+            with conn.cursor() as cursor:
+                logger.debug('executing %s with %s', query.as_string(cursor), params)
+                cursor.execute(query, params or None)
+                result = cursor.fetchall()
+                columns = [desc[0] for desc in cursor.description]
+        finally:
+            self.release_connection(conn)
+        df = pd.DataFrame(result, columns=columns)
+        df = df.astype({
+            'date': 'datetime64[ns, UTC]',
+            'open': 'float',
+            'high': 'float',
+            'low': 'float',
+            'close': 'float',
+            'volume': 'int'
+        })
+        # NOTE: the index keeps its UTC timezone; stripping it is currently disabled.
+        return df.set_index(pd.DatetimeIndex(df['date']))