diff --git a/app.py b/app.py
index 6681b8218ac29a0b56b9bd2a21ad6aba59acc9e8..74d80308b4ea4fb1010fe2d9faa6643fa40d2f72 100644
--- a/app.py
+++ b/app.py
@@ -32,7 +32,6 @@
 from sqlalchemy.exc import InvalidRequestError
 from groq import Groq
 from mistralai import Mistral
-from openai import OpenAI
 
 app = Flask(__name__)
 
@@ -48,6 +47,118 @@ socketio = SocketIO(app, async_mode="gevent")
 # Global dictionary to keep track of cancellation requests
 cancellation_requests = {}
 
+from openai import OpenAI
+
+# Global openai inference compatible client endpoints.
+ENDPOINTS = [
+    # vLLM clusters
+    {
+        "name": "vllm1",
+        "base_url": os.environ.get("VLLM_ENDPOINT"),
+        "api_key": os.environ.get("VLLM_API_KEY", "not-needed"),
+    },
+    {
+        "name": "vllm2",
+        "base_url": os.environ.get("VLLM_ENDPOINT2"),
+        "api_key": os.environ.get("VLLM_API_KEY2", "not-needed"),
+    },
+    {
+        "name": "vllm3",
+        "base_url": os.environ.get("VLLM_ENDPOINT3"),
+        "api_key": os.environ.get("VLLM_API_KEY3", "not-needed"),
+    },
+    # Ollama
+    {
+        "name": "ollama",
+        "base_url": os.environ.get("OLLAMA_ENDPOINT"),
+        "api_key": os.environ.get("OLLAMA_API_KEY", "not-needed"),
+    },
+    # x.ai (for grok, etc.)
+    {
+        "name": "xai",
+        "base_url": "https://api.x.ai/v1",
+        "api_key": os.environ.get("XAI_API_KEY", "not-needed"),
+    },
+    # Google generative language (Gemini)
+    {
+        "name": "google",
+        "base_url": "https://generativelanguage.googleapis.com/v1beta/openai/",
+        "api_key": os.environ.get("GOOGLE_API_KEY"),
+    },
+    # Fallback: official OpenAI
+    {
+        "name": "public-openai",
+        "base_url": None,  # Means library defaults to https://api.openai.com/v1
+        "api_key": os.environ.get("OPENAI_API_KEY"),
+    },
+]
+
+
+# 2) Initialization function: Build the map by listing models on each endpoint
+def initialize_model_map():
+    MODEL_CLIENT_MAP.clear()
+
+    for ep in ENDPOINTS:
+        base_url = ep["base_url"]
+        api_key = ep["api_key"]
+        endpoint_name = ep["name"]
+
+        # Create a dedicated client for this endpoint
+        client = OpenAI(base_url=base_url, api_key=api_key)
+
+        # Attempt to list the models from this endpoint
+        try:
+            response = client.models.list()
+            # vLLM returns a SyncPage[Model], so 'response.data' is a list of Model() objects
+            model_list = response.data
+            print(f"[DEBUG] {endpoint_name} => {model_list}")
+        except Exception as e:
+            print(f"[WARN] Could not list models for endpoint '{endpoint_name}': {e}")
+            continue
+
+        # For each discovered model, store it in the global map
+        for m in model_list:
+            model_id = m.id
+            if model_id and model_id not in MODEL_CLIENT_MAP:
+                MODEL_CLIENT_MAP[model_id] = client
+
+    print("loaded models:")
+    print(MODEL_CLIENT_MAP)
+
+
+# 3) A global dictionary: model_name -> dedicated OpenAI client
+MODEL_CLIENT_MAP = {}
+
+if MODEL_CLIENT_MAP:
+    pass
+else:
+    initialize_model_map()
+
+
+# 4) Lookup function: get an OpenAI client for a given model name
+def get_client_for_model(model_name: str):
+    """
+    If the model name is known, return its dedicated client.
+    Otherwise, fallback to public openai usage.
+    """
+    # Return the matching client if we have it
+    if model_name in MODEL_CLIENT_MAP:
+        return MODEL_CLIENT_MAP[model_name]
+    # Otherwise, fallback to official
+    # print(f"[INFO] Unknown model '{model_name}'. Using fallback (public-openai).")
+    fallback_client = OpenAI(
+        base_url=None,  # official openai base
+        api_key=os.environ.get("OPENAI_API_KEY"),
+    )
+    return fallback_client
+
+
+def get_openai_client_and_model(
+    model_name="adamo1139/Hermes-3-Llama-3.1-8B-FP8-Dynamic",
+):
+    return get_client_for_model(model_name), model_name
+
+
 system_users = [
     "anthropic.claude-3-haiku-20240307-v1:0",
     "anthropic.claude-3-sonnet-20240229-v1:0",
@@ -74,7 +185,7 @@ system_users = [
     "mistral-small-latest",
     "mistral-medium",
     "mistral-large-latest",
-    "codestral-latest"
+    "codestral-latest",
     "mistralai/Mixtral-8x7B-v0.1",
     "mistralai/Mistral-7B-Instruct-v0.1",
     "mixtral-8x7b-32768",
@@ -1225,54 +1336,6 @@ def chat_claude(
     socketio.emit("delete_processing_message", msg_id, room=room.name)
 
 
-def get_openai_client_and_model(model_name="adamo1139/Hermes-3-Llama-3.1-8B-FP8-Dynamic"):
-    vllm_endpoint = os.environ.get("VLLM_ENDPOINT")
-    vllm_api_key = os.environ.get("VLLM_API_KEY", "not-needed")
-    vllm_endpoint2 = os.environ.get("VLLM_ENDPOINT2")
-    vllm_api_key2 = os.environ.get("VLLM_ENDPOINTAPI_KEY2", "not-needed")
-    vllm_endpoint3 = os.environ.get("VLLM_ENDPOINT3")
-    vllm_api_key3 = os.environ.get("VLLM_ENDPOINTAPI_KEY3", "not-needed")
-    ollama_endpoint = os.environ.get("OLLAMA_ENDPOINT")
-    ollama_api_key = os.environ.get("OLLAMA_API_KEY", "not-needed")
-    xai_api_key = os.environ.get("XAI_API_KEY")
-    google_api_key = os.environ.get("GOOGLE_API_KEY")
-    is_openai_model = (
-        model_name.lower().startswith(('gpt', 'o1', 'o3'))
-    )
-    is_xai_model = "grok-" in model_name.lower()
-    is_google_model = "gemini-" in model_name.lower()
-    is_ollama_model = "hf.co" in model_name.lower()
-    is_qwq_model = "qwq" in model_name.lower()
-    is_r1_model = "deepseek" in model_name.lower()
-    is_vllm_model = not (
-        is_openai_model
-        or is_xai_model
-        or is_google_model
-        or is_ollama_model
-        or is_qwq_model
-        or is_r1_model
-    )
-    # clearly this isn't the ideal way to grow our open source endpoints...
-    if is_vllm_model:
-        openai_client = OpenAI(base_url=vllm_endpoint, api_key=vllm_api_key)
-    elif is_qwq_model:
-        openai_client = OpenAI(base_url=vllm_endpoint2, api_key=vllm_api_key2)
-    elif is_r1_model:
-        openai_client = OpenAI(base_url=vllm_endpoint3, api_key=vllm_api_key3)
-    elif is_ollama_model:
-        openai_client = OpenAI(base_url=ollama_endpoint, api_key=ollama_api_key)
-    elif is_xai_model:
-        openai_client = OpenAI(base_url="https://api.x.ai/v1", api_key=xai_api_key)
-    elif is_google_model:
-        openai_client = OpenAI(
-            base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
-            api_key=google_api_key,
-        )
-    else:
-        openai_client = OpenAI()
-    return openai_client, model_name
-
-
 def chat_gpt(username, room_name, model_name="gpt-4o-mini"):
     openai_client, model_name = get_openai_client_and_model(model_name)
 
@@ -3248,9 +3311,13 @@ def translate_text(text, target_language):
     except Exception as e:
         return f"Error: {e}"
 
+
 if __name__ == "__main__":
     import argparse
-    parser = argparse.ArgumentParser(description="Run the SocketIO application with optional configurations.")
+
+    parser = argparse.ArgumentParser(
+        description="Run the SocketIO application with optional configurations."
+    )
     parser.add_argument("--profile", help="AWS profile name", default=None)
     parser.add_argument(
         "--local-activities",