bertex/bert/lib/bert.ex
defmodule Bert do
  @moduledoc """
  A GenServer that loads a BERT model and tokenizer via Bumblebee and serves
  text embeddings (and, when a GPU is available, text generation).
  """

  use GenServer
  require Logger

  # Run Nx computations on the EXLA backend by default.
  Nx.global_default_backend(EXLA.Backend)

  # Client API

  @doc """
  Starts the GenServer, which loads the BERT model and tokenizer.
  """
  def start_link(opts) do
    Logger.info("Starting GenServer")
    GenServer.start_link(__MODULE__, :ok, opts)
  end
@doc """
Requests the text embedding for a given string.
"""
def embed(server_pid, text) do
Logger.info("Calling embed")
GenServer.call(server_pid, {:embed, text})
end
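
  # Convenience client wrappers for the other calls handled by this server;
  # a minimal sketch mirroring embed/2 above.

  @doc """
  Requests the per-token embeddings (all hidden states) for a given string.
  """
  def embed_all(server_pid, text) do
    Logger.info("Calling embed_all")
    GenServer.call(server_pid, {:embed_all, text})
  end

  @doc """
  Requests generated text for a given prompt (only available when a GPU is detected).
  """
  def make_text(server_pid, text) do
    Logger.info("Calling make_text")
    GenServer.call(server_pid, {:make_text, text})
  end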

  # Server Callbacks

  @impl true
  def init(:ok) do
    Logger.info("Calling init")
    gpu = has_gpu_access?()
    Logger.info("Has GPU? #{gpu}")

    # TODO: Get these working as batched runs
    # Load the model and tokenizer from the Hugging Face repository
    model_name = "google-bert/bert-base-uncased"
    {:ok, model_info} = Bumblebee.load_model({:hf, model_name}, architecture: :base)
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, model_name})

    # Serving that returns a single pooled embedding per input
    flat_serving =
      Bumblebee.Text.TextEmbedding.text_embedding(
        model_info,
        tokenizer,
        defn_options: [compiler: EXLA]
      )

    # Serving that returns the full hidden state (one vector per token)
    all_serving =
      Bumblebee.Text.TextEmbedding.text_embedding(
        model_info,
        tokenizer,
        defn_options: [compiler: EXLA],
        output_attribute: :hidden_state
      )

    # TODO: This is a bit much currently
    if gpu do
      repo = {:hf, "microsoft/Phi-3.5-mini-instruct"}
      {:ok, phi_model_info} = Bumblebee.load_model(repo, backend: EXLA.Backend)
      {:ok, phi_tokenizer} = Bumblebee.load_tokenizer(repo)
      {:ok, generation_config} = Bumblebee.load_generation_config(repo)

      generation_config =
        Bumblebee.configure(generation_config,
          max_new_tokens: 256,
          strategy: %{type: :multinomial_sampling, top_p: 0.6}
        )

      text_gen =
        Bumblebee.Text.generation(phi_model_info, phi_tokenizer, generation_config,
          # compile: [batch_size: 1, sequence_length: 128_000],
          defn_options: [compiler: EXLA]
        )

      # Store the serving functions in the state
      {:ok, %{flat_serving: flat_serving, all_serving: all_serving, text_serving: text_gen}}
    else
      {:ok, %{flat_serving: flat_serving, all_serving: all_serving}}
    end
  end
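
  # The state map holds :flat_serving (pooled embedding), :all_serving (per-token
  # hidden states), and, when a GPU was detected at init, :text_serving.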

  @impl true
  def handle_call({:embed, text}, _from, state) do
    Logger.info("Handling single output embedding call")
    # Run the serving function with the text input
    result = Nx.Serving.run(state.flat_serving, text)
    embedding = result.embedding
    {:reply, embedding, state}
  end

  @impl true
  def handle_call({:embed_all, text}, _from, state) do
    Logger.info("Handling all token output embedding call")
    # Run the serving function with the text input
    result = Nx.Serving.run(state.all_serving, text)
    embedding = result.embedding
    {:reply, embedding, state}
  end

  @impl true
  def handle_call({:make_text, text}, _from, state) do
    Logger.info("Handling text gen call")

    if has_gpu_access?() do
      # Run the text generation serving with the text input
      result = Nx.Serving.run(state.text_serving, text)
      {:reply, result, state}
    else
      {:reply, "No gpu", state}
    end
  end
@doc """
Example using Axon to get all the output embeddings
"""
def get_embedding(text) do
{:ok, model_info} =
Bumblebee.load_model({:hf, "google-bert/bert-base-uncased"}, architecture: :base)
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "google-bert/bert-base-uncased"})
inputs = Bumblebee.apply_tokenizer(tokenizer, text)
outputs = Axon.predict(model_info.model, model_info.params, inputs).hidden_state[0]
outputs
end
@doc """
Return if Elixir has access to the GPU or not.
"""
@spec has_gpu_access? :: boolean()
def has_gpu_access?() do
try do
case Nx.tensor(0) do
# :host == CPU
%Nx.Tensor{data: %EXLA.Backend{buffer: %EXLA.DeviceBuffer{client_name: :host}}} ->
false
# :cuda == GPU
%Nx.Tensor{data: %EXLA.Backend{buffer: %EXLA.DeviceBuffer{client_name: :cuda}}} ->
true
_other ->
false
end
rescue
_exception ->
Logger.error("Error trying to determine GPU access!")
false
end
end
end
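
# A minimal usage sketch (the prompt text is illustrative):
#
#     {:ok, pid} = Bert.start_link([])
#     embedding = Bert.embed(pid, "Hello, world!")
#     # `embedding` is an Nx tensor holding one pooled vector for the input;
#     # embed_all/2 returns one vector per token, and make_text/2 needs a GPU.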