defmodule Bert do
  @moduledoc """
  A GenServer that wraps Bumblebee servings: BERT text embeddings
  (pooled and per-token) and, when a GPU is available, Phi-3.5 text
  generation.
  """

  use GenServer
  require Logger

  # Note: this call runs at compile time. For runtime configuration, prefer
  # `config :nx, default_backend: EXLA.Backend` in config/config.exs.
  Nx.global_default_backend(EXLA.Backend)

  # Client API

  @doc """
  Starts the GenServer, which loads the BERT model and tokenizer.
  """
  def start_link(opts) do
    Logger.info("Starting GenServer")
    GenServer.start_link(__MODULE__, :ok, opts)
  end

  @doc """
  Requests the pooled text embedding for a given string.
  """
  def embed(server_pid, text) do
    Logger.info("Calling embed")
    GenServer.call(server_pid, {:embed, text})
  end

  @doc """
  Requests per-token embeddings (the full hidden state) for a given string.
  """
  def embed_all(server_pid, text) do
    Logger.info("Calling embed_all")
    GenServer.call(server_pid, {:embed_all, text})
  end

  @doc """
  Requests generated text for a given prompt. Replies with "No gpu" when no
  text-generation serving was loaded at startup.
  """
  def make_text(server_pid, text) do
    Logger.info("Calling make_text")
    # Generation can take well over the default 5s call timeout
    GenServer.call(server_pid, {:make_text, text}, :infinity)
  end

  # Server Callbacks

  @impl true
  def init(:ok) do
    Logger.info("Calling init")
    gpu = has_gpu_access?()
    Logger.info("Has GPU? #{gpu}")

    # TODO: Get these working as batched runs
    # (see the batched-serving sketch after get_embedding/1 below)

    # Load the model and tokenizer from the Hugging Face repository
    model_name = "google-bert/bert-base-uncased"
    {:ok, model_info} = Bumblebee.load_model({:hf, model_name}, architecture: :base)
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, model_name})

    # Serving that returns the pooled (single-vector) embedding
    serving =
      Bumblebee.Text.text_embedding(model_info, tokenizer,
        defn_options: [compiler: EXLA]
      )

    # Serving that returns the full hidden state (one embedding per token)
    serving2 =
      Bumblebee.Text.text_embedding(model_info, tokenizer,
        defn_options: [compiler: EXLA],
        output_attribute: :hidden_state
      )

    # TODO: This is a bit much currently
    if gpu do
      repo = {:hf, "microsoft/Phi-3.5-mini-instruct"}

      {:ok, model_info} = Bumblebee.load_model(repo, backend: EXLA.Backend)
      {:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
      {:ok, generation_config} = Bumblebee.load_generation_config(repo)

      generation_config =
        Bumblebee.configure(generation_config,
          max_new_tokens: 256,
          strategy: %{type: :multinomial_sampling, top_p: 0.6}
        )

      text_gen =
        Bumblebee.Text.generation(model_info, tokenizer, generation_config,
          # compile: [batch_size: 1, sequence_length: 128_000],
          defn_options: [compiler: EXLA]
        )

      # Store the serving functions in the state
      {:ok, %{flat_serving: serving, all_serving: serving2, text_serving: text_gen}}
    else
      {:ok, %{flat_serving: serving, all_serving: serving2}}
    end
  end

  @impl true
  def handle_call({:embed, text}, _from, state) do
    Logger.info("Handling single output embedding call")
    # Run the serving function with the text input
    result = Nx.Serving.run(state.flat_serving, text)
    {:reply, result.embedding, state}
  end

  @impl true
  def handle_call({:embed_all, text}, _from, state) do
    Logger.info("Handling all token output embedding call")
    # Run the serving function with the text input
    result = Nx.Serving.run(state.all_serving, text)
    {:reply, result.embedding, state}
  end

  @impl true
  def handle_call({:make_text, text}, _from, %{text_serving: text_serving} = state) do
    Logger.info("Handling text gen call")
    # Run the serving function with the text input
    result = Nx.Serving.run(text_serving, text)
    {:reply, result, state}
  end

  # No text-generation serving was loaded at init (no GPU at startup)
  def handle_call({:make_text, _text}, _from, state) do
    {:reply, "No gpu", state}
  end

  @doc """
  Example using Axon directly to get all the output embeddings.
  """
  def get_embedding(text) do
    {:ok, model_info} =
      Bumblebee.load_model({:hf, "google-bert/bert-base-uncased"}, architecture: :base)

    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "google-bert/bert-base-uncased"})

    inputs = Bumblebee.apply_tokenizer(tokenizer, text)
    Axon.predict(model_info.model, model_info.params, inputs).hidden_state[0]
  end
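  # A minimal sketch of the batched-run setup noted in the TODO in init/1,
  # assuming the host application supervises the serving. The Bert.Serving
  # name, batch size, and sequence length below are illustrative, not part
  # of the current module.
  @doc """
  Builds a child spec for a batched embedding serving.

  Place the returned tuple under the application's supervision tree, then
  call `Nx.Serving.batched_run(Bert.Serving, text)` from any process;
  concurrent requests are batched together up to the compiled batch size.
  """
  def batched_embedding_child_spec() do
    model_name = "google-bert/bert-base-uncased"
    {:ok, model_info} = Bumblebee.load_model({:hf, model_name}, architecture: :base)
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, model_name})

    serving =
      Bumblebee.Text.text_embedding(model_info, tokenizer,
        # Compile a fixed shape up front so concurrent requests can share a batch
        compile: [batch_size: 8, sequence_length: 512],
        defn_options: [compiler: EXLA]
      )

    # batch_timeout: how long (ms) to wait for a batch to fill before running
    {Nx.Serving, serving: serving, name: Bert.Serving, batch_timeout: 100}
  end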
  @doc """
  Returns whether Nx (via EXLA) has access to a GPU.
  """
  @spec has_gpu_access?() :: boolean()
  def has_gpu_access?() do
    try do
      # Allocate a throwaway tensor and inspect which EXLA client owns it
      case Nx.tensor(0) do
        # :host == CPU
        %Nx.Tensor{data: %EXLA.Backend{buffer: %EXLA.DeviceBuffer{client_name: :host}}} ->
          false

        # :cuda == GPU
        %Nx.Tensor{data: %EXLA.Backend{buffer: %EXLA.DeviceBuffer{client_name: :cuda}}} ->
          true

        _other ->
          false
      end
    rescue
      _exception ->
        Logger.error("Error trying to determine GPU access!")
        false
    end
  end
end
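
# A minimal usage sketch (model downloads must succeed on first run; the
# process name below is illustrative):
#
#     {:ok, pid} = Bert.start_link(name: :bert)
#     embedding = Bert.embed(pid, "Hello, world!")
#     Nx.shape(embedding)
#     #=> e.g. {768}, the pooled output size for bert-base-uncased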