-
-
Notifications
You must be signed in to change notification settings - Fork 1k
feat : add datasource sql and csv for webapp #772
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
b304468
89b2523
4710b1f
7a88324
73dc7a9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -479,6 +479,19 @@ def section_summary(text: str): | |
| ), | ||
| ] | ||
| ), | ||
| dbc.Col( | ||
| children=[ | ||
| html.Label("Source:"), | ||
| dcc.Dropdown( | ||
| id="datasource", | ||
| options=[{"value": 'yahoo', "label": 'yahoo'}, | ||
| {"value": 'csv', "label": 'csv'}, | ||
| {"value": 'csv_all', "label": 'csv_all'}, | ||
| {"value": 'sql', "label": 'sql'}], | ||
| value='csv', | ||
| ), | ||
| ] | ||
| ), | ||
| ], | ||
| ), | ||
| html.Label("Filter period:"), | ||
|
|
@@ -906,8 +919,26 @@ def section_summary(text: str): | |
| ) | ||
|
|
||
|
|
||
| data_mode = 'sql' | ||
| df_all_symbol = None | ||
| sql_dal = None | ||
|
|
||
|
|
||
| @cache.memoize() | ||
| def fetch_data(symbol, period, interval, auto_adjust, back_adjust): | ||
| """Fetch OHLCV data from backend.""" | ||
| global data_mode | ||
| if data_mode == 'csv_all': | ||
| return fetch_data_csv_all(symbol, period, interval, auto_adjust, back_adjust) | ||
| elif data_mode == 'csv': | ||
| return fetch_data_csv(symbol, period, interval, auto_adjust, back_adjust) | ||
| elif data_mode == 'sql': | ||
| return fetch_data_sql(symbol, period, interval, auto_adjust, back_adjust) | ||
| return fetch_data_yf(symbol, period, interval, auto_adjust, back_adjust) | ||
|
|
||
|
|
||
| @cache.memoize() | ||
| def fetch_data_yf(symbol, period, interval, auto_adjust, back_adjust): | ||
| """Fetch OHLCV data from Yahoo! Finance.""" | ||
| df = yf.Ticker(symbol).history( | ||
| period=period, | ||
|
|
@@ -921,6 +952,73 @@ def fetch_data(symbol, period, interval, auto_adjust, back_adjust): | |
| return df | ||
|
|
||
|
|
||
| @cache.memoize() | ||
| def fetch_data_csv(symbol, period, interval, auto_adjust, back_adjust, csvdir='~/trade/data/kaggle_allUS_daily_with_volume_yahoo/stocks'): | ||
| """Fetch OHLCV data from csv containing one symbols.""" | ||
| csvfile = csvdir + '/' + symbol + '.csv' | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. change to f string |
||
| _start = time.perf_counter() | ||
| df = pd.read_csv(csvfile, parse_dates=['Date'], index_col='Date') | ||
| _duration = time.perf_counter() - _start | ||
| return df | ||
|
|
||
|
|
||
| @cache.memoize() | ||
| def fetch_data_csv_all(symbol, period, interval, auto_adjust, back_adjust, csvfile='~/trade/data/test.csv'): | ||
| """Fetch OHLCV data from csv containing multiple symbols.""" | ||
| global df_all_symbol | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. a global doesn't feel appropriate there. |
||
| csvfile = '~/trade/data/test.csv' | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we shouldn't use |
||
| if df_all_symbol is None: | ||
| _start = time.perf_counter() | ||
| df_all_symbol = pd.read_csv(csvfile, parse_dates=['date']) | ||
| df_all_symbol = df_all_symbol.rename( | ||
| columns={"date": "Date", "open": "Open", "high": "High", "low": "Low", "close": "Close", | ||
| "volume": "Volume"}) | ||
| df_all_symbol['Date'] = df_all_symbol['Date'].map( | ||
| lambda t: pd.to_datetime(t.replace(tzinfo=None)).to_pydatetime()) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. total nitpick but don't say 't' be more descriptive about this variable |
||
| df_all_symbol = df_all_symbol.set_index(pd.DatetimeIndex(df_all_symbol['Date'])) | ||
| _duration = time.perf_counter() - _start | ||
| res = df_all_symbol[df_all_symbol['symbol'] == symbol] | ||
| res = res.drop(['symbol'], axis=1) | ||
| return res | ||
|
|
||
|
|
||
| @cache.memoize() | ||
| def fetch_data_sql(symbol, period, interval, auto_adjust, back_adjust): | ||
| """Fetch OHLCV data from SQL.""" | ||
| global sql_dal | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove global |
||
| if interval == '60m': | ||
| dbtableRead = 'stock_market.ohlcv_1_hour' | ||
| elif interval == '15m': | ||
| dbtableRead = 'stock_market.ohlcv_15_minute' | ||
| elif interval == '1d': | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this seems counter intuitive. You have args "period" and "interval" but you're checking |
||
| dbtableRead = 'stock_market.ohlcv_1_day' | ||
| else: | ||
| raise(ValueError(f"error unknown interval {interval}")) | ||
| if sql_dal is None: | ||
| user = getpass.getuser() | ||
| password = os.getenv('pg_password') | ||
| sql_dal = MarketDataRepository() | ||
| sql_dal.init(user=user, password=password) | ||
| res = sql_dal.get_one_symbol(dbtableRead, symbol) | ||
| res = res.rename( | ||
| columns={"date": "Date", "open": "Open", "high": "High", "low": "Low", "close": "Close", | ||
| "volume": "Volume"}) | ||
| return res | ||
|
|
||
|
|
||
| @app.callback( | ||
| [Output('clean_button', 'children')], | ||
| [Input('clean_button', 'n_clicks')], | ||
| prevent_initial_call=True | ||
| ) | ||
| def clean_cache(_): | ||
| """Clean all data in cache.""" | ||
| files = glob.glob('data/*') | ||
| for f in files: | ||
| os.remove(f) | ||
| return | ||
|
|
||
|
|
||
| @app.callback( | ||
| [ | ||
| Output("data_signal", "children"), | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,126 @@ | ||
| import datetime as dt | ||
| import psycopg2 | ||
| from psycopg2 import pool | ||
| import pandas as pd | ||
| import logging | ||
|
|
||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| class MarketDataRepository: | ||
| def __init__(self): | ||
| """Initialize basic structure but don't connect yet""" | ||
| self.connection_pool = None | ||
| self.ver_twsapi = None | ||
|
|
||
| def init(self, database='postgres', user='admin', password='admin', host='127.0.0.1', port='5432', | ||
| ver_twsapi=10, maxconn=2): | ||
| self.ver_twsapi = ver_twsapi | ||
| # Establishing the connection | ||
| logger.info('connecting to host %s and DB %s' % (host, database)) | ||
| self.connection_pool = psycopg2.pool.ThreadedConnectionPool( | ||
| 1, maxconn, # minconn, maxconn | ||
| user=user, | ||
| password=password, | ||
| host=host, | ||
| port=port, | ||
| database=database | ||
| ) | ||
|
|
||
| def get_connection(self): | ||
| try: | ||
| conn = self.connection_pool.getconn() | ||
| except Exception as e: | ||
| raise e | ||
| return conn | ||
|
|
||
| def release_connection(self, connection): | ||
| self.connection_pool.putconn(connection) | ||
|
|
||
| def end(self): | ||
| self.connection_pool.closeall() | ||
|
|
||
| def get_all_symbol(self, dbtable: str, date_start: dt = None, date_end: dt = None): | ||
| if None not in [date_start, date_end]: | ||
| sql = '''SELECT * FROM %s WHERE date >= to_timestamp(%s) AND date <= to_timestamp(%s);''' % \ | ||
| (dbtable, date_start.strftime('%s'), date_end.strftime('%s')) | ||
| elif date_start == date_end: | ||
| sql = '''SELECT * FROM %s ;''' % (dbtable) | ||
| print(sql) | ||
| elif date_end is None: | ||
| sql = '''SELECT * FROM %s WHERE date >= to_timestamp(%s);''' % \ | ||
| (dbtable, date_start.strftime('%s')) | ||
| else: | ||
| sql = '''SELECT * FROM %s WHERE date <= to_timestamp(%s);''' % \ | ||
| (dbtable, date_end.strftime('%s')) | ||
| # pandas only supports SQLAlchemy connectable (engine/connection) or database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 objects are not tested. Please consider using SQLAlchemy. | ||
| conn = self.get_connection() | ||
| try: | ||
| with conn.cursor() as cursor: | ||
| cursor.execute(sql) | ||
| result = cursor.fetchall() | ||
| columns = [desc[0] for desc in cursor.description] | ||
| cursor.close() | ||
| df = pd.DataFrame(result, columns=columns) | ||
| # Specify the types directly at DataFrame creation | ||
| df = df.astype({ | ||
| 'date': 'datetime64[ns, UTC]', | ||
| 'open': 'float', | ||
| 'high': 'float', | ||
| 'low': 'float', | ||
| 'close': 'float', | ||
| 'volume': 'int' | ||
| }) | ||
| except Exception as e: | ||
| raise e | ||
| finally: | ||
| self.release_connection(conn) | ||
| # remove TZ otherwise pandas DateTimeIndex lookup in pandas do not work | ||
| #df['date'] = df['date'].map(lambda t: pd.to_datetime(t.replace(tzinfo=None)).to_pydatetime()) | ||
| df = df.set_index(pd.DatetimeIndex(df['date'])) | ||
| return df | ||
|
|
||
| def get_one_symbol(self, dbtable: str, symbol: str, date_start: dt = None, date_end: dt = None): | ||
| if ':' not in symbol: | ||
| logger.error('Incorrect symbol %s' % symbol) | ||
| return | ||
| # raise ValueError | ||
| if None not in [date_start, date_end]: | ||
| sql = '''SELECT * FROM %s WHERE symbol = '%s' AND date >= to_timestamp(%s) AND date <= to_timestamp(%s);''' % \ | ||
| (dbtable, symbol, date_start.strftime('%s'), date_end.strftime('%s')) | ||
| elif date_start == date_end: | ||
| sql = '''SELECT * FROM %s WHERE symbol = '%s';''' % (dbtable, symbol) | ||
| print(sql) | ||
| elif date_end is None: | ||
| sql = '''SELECT * FROM %s WHERE symbol = '%s' AND date >= to_timestamp(%s);''' % \ | ||
| (dbtable, symbol, date_start.strftime('%s')) | ||
| else: | ||
| sql = '''SELECT * FROM %s WHERE symbol = '%s' AND date <= to_timestamp(%s);''' % \ | ||
| (dbtable, symbol, date_end.strftime('%s')) | ||
| # pandas only supports SQLAlchemy connectable (engine/connection) or database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 objects are not tested. Please consider using SQLAlchemy. | ||
| conn = self.get_connection() | ||
| try: | ||
| with conn.cursor() as cursor: | ||
| cursor.execute(sql) | ||
| result = cursor.fetchall() | ||
| columns = [desc[0] for desc in cursor.description] | ||
| cursor.close() | ||
| df = pd.DataFrame(result, columns=columns) | ||
| # Specify the types directly at DataFrame creation | ||
| df = df.astype({ | ||
| 'date': 'datetime64[ns, UTC]', | ||
| 'open': 'float', | ||
| 'high': 'float', | ||
| 'low': 'float', | ||
| 'close': 'float', | ||
| 'volume': 'int' | ||
| }) | ||
| except Exception as e: | ||
| raise e | ||
| finally: | ||
| self.release_connection(conn) | ||
| # remove TZ otherwise pandas DateTimeIndex lookup in pandas do not work | ||
| #df['date'] = df['date'].map(lambda t: pd.to_datetime(t.replace(tzinfo=None)).to_pydatetime()) | ||
| df = df.set_index(pd.DatetimeIndex(df['date'])) | ||
| return df |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
csvdir is using
~this shouldn't be hard coded