diff --git a/app.py b/app.py index 969dcd2..d4da490 100644 --- a/app.py +++ b/app.py @@ -1,139 +1,19 @@ -import os -from threading import Thread -import json -from rethinkdb import RethinkDB -from pybit.unified_trading import WebSocket from time import sleep -from messages.TickerData import Message,TickerData -from messages.Condition import Condition -from typing import List, Set +from messages.DbConnector import DbConnector +from messages.Condition import Condition +from messages.Bybit import Bybit +from messages.TickerData import TickerData -r = RethinkDB() - -# RethinkDB settings -RDB_HOST = os.environ.get('RETHINKDB_URL') -RDB_PORT = os.environ.get('RETHINKDB_PORT') -DB_NAME = "finfree" -TABLE_NAME = 'conditions' - -# Initialize RethinkDB connection -rdb_conn = r.connect(host=RDB_HOST, port=RDB_PORT, db=DB_NAME) - -# WebSocket setup -WS_URL = "wss://stream.bybit.com/v5/public/linear" -ws = WebSocket( - testnet=True, - channel_type="linear", -) - -_subscribed = set() -_conditions = list() -_last = dict() - -def getSubscribed() -> Set[str]: - return _subscribed - -def addToSubscribed(symbol): - global _subscribed - _subscribed.add(symbol) - -def getConditions() -> List[Condition]: - return _conditions - -def setConditions(condition): - global _conditions - _conditions = condition - -def getLast(): - return _last - -def setLast(key, value): - global _last - _last[key] = value - -def replace(new_object, array): - index = next((i for i, obj in enumerate(array) if obj.id == new_object.id), None) - if index is not None: - array[index] = new_object - -def fetch_conditions(): - try: - symbols = r.table(TABLE_NAME).run(rdb_conn) - return list(symbols) - except Exception as e: - print(f"Error fetching symbols: {e}") - return [] - -def get_symbols(conditions: List[Condition]): - return list(set([symbol['symbol'] for symbol in conditions])) - - -def get_symbol(symbol: str, conditions: List[Condition]) -> List[Condition]: - return list(filter(lambda c: c["symbol"] == symbol, conditions)) - -def handle_message(message: dict): - msg = Message.from_json(message) - if (isinstance(msg.data, TickerData)): - ticker = msg.data - symbol = ticker.symbol - price = ticker.lastPrice - last = getLast() - - if symbol not in last or last[symbol] != price: - - conditions = getConditions() - - filtered = get_symbol(ticker.symbol, conditions) - - for condition in filtered: - if (condition["condition"] == "lt"): - lower_than(condition, price) - elif (condition["condition"] == "gt"): - greater_than(condition, price) - - setLast(symbol, price) - -def watch_symbols_table(): - feed = r.table(TABLE_NAME).changes().run(rdb_conn) - for change in feed: - print(f"Change detected: {change}") - - replace(change['new_val'], getConditions()) - - if change['new_val'] and not change['old_val']: # New symbol added - symbol = change['new_val']['symbol'] - subscribe_to_symbol(symbol) - -def subscribe_to_symbol(symbol): - subscribed = getSubscribed() - - if symbol not in subscribed: - addToSubscribed(symbol) - ws.ticker_stream( - symbol=symbol, - callback=handle_message - ) - -def lower_than(condition: Condition, price: float): - print(f"is {condition['symbol']} price {condition['value']} lower than {price}? {condition['value'] < price}") - -def greater_than(condition: Condition, price): - print(f"is {condition['symbol']} price {condition['value']} greater than {price}? {condition['value'] > price}") +def handle_tickerdata(data: TickerData, condition: Condition): + print(data.symbol + ": " + data.lastPrice) def main(): - - setConditions(fetch_conditions()) - - for symbol in get_symbols(getConditions()): - subscribe_to_symbol(symbol) - - watch_symbols_table() + DbConnector.watch_conditions(lambda c: + Bybit.subscribe_symbol(c.symbol, lambda d: + handle_tickerdata(d, c))) while True: sleep(1) if __name__ == "__main__": main() - - - diff --git a/messages/Bybit.py b/messages/Bybit.py new file mode 100644 index 0000000..008688c --- /dev/null +++ b/messages/Bybit.py @@ -0,0 +1,43 @@ +from pybit.unified_trading import WebSocket +from typing import Callable +from functools import partial +from .TickerData import Message, TickerData + +WS_URL = "wss://stream.bybit.com/v5/public/linear" +ws = WebSocket( + testnet=True, + channel_type="linear", +) + +subscribed: set[str] = set() +lastPrice: dict[str, float] = dict() + +class Bybit: + @staticmethod + def subscribe_symbol(symbol: str, callback: Callable[[TickerData], None]): + global subscribed + + if symbol not in subscribed: + subscribed.add(symbol) + ws.ticker_stream( + symbol=symbol, + callback=lambda m: Bybit.handle_message(m, callback) + ) + + @staticmethod + def handle_message(message, callback: Callable[[TickerData], None]): + msg = Message.from_json(message) + if msg.data is not None and isinstance(msg.data, TickerData): + Bybit.handle_tickerdata(msg.data, callback) + + @staticmethod + def handle_tickerdata(data: TickerData, callback: Callable[[TickerData], None]): + global lastPrice + if data.symbol not in lastPrice or lastPrice[data.symbol] != data.lastPrice: + lastPrice[data.symbol] = data.lastPrice + Bybit.handle_new_tickerdata(data, callback) + + @staticmethod + def handle_new_tickerdata(data: TickerData, callback: Callable[[TickerData], None]): + callback(data) + diff --git a/messages/DbConnector.py b/messages/DbConnector.py new file mode 100644 index 0000000..60b8da3 --- /dev/null +++ b/messages/DbConnector.py @@ -0,0 +1,48 @@ +import os +from typing import Callable +from rethinkdb import RethinkDB +from .Condition import Condition + +RDB_HOST = os.environ.get('RETHINKDB_URL') +RDB_PORT = os.environ.get('RETHINKDB_PORT') +DB_NAME = "finfree" +TABLE_NAME = 'conditions' + +r = RethinkDB() +connection = None + +def get_connection(): + global connection + + if connection is None or not connection.is_open(): + connection = r.connect(RDB_HOST, RDB_PORT, db=DB_NAME) + + return connection + +def getRethinkDB(): + return r + +def fetch_conditions() -> list[Condition]: + try: + cursor = r.table(TABLE_NAME).run(get_connection()) + return [Condition(**doc) for doc in cursor] + except Exception as e: + print(f"Error fetching symbols: {e}") + return [] + +conditions: list[Condition] = list() + +class DbConnector: + def watch_conditions(callback: Callable[[Condition], None]): + global conditions + + conditions = fetch_conditions() + + for cond in conditions: + callback(cond) + + feed = r.table(TABLE_NAME).changes().run(get_connection()) + for change in feed: + if change['new_val'] and not change['old_val']: # New symbol added + cond: Condition = change['new_val'] + callback(cond)