defmodule Bert do
  use GenServer
  require Logger

  Nx.global_default_backend(EXLA.Backend)

  @moduledoc """
  A GenServer that loads a BERT model via Bumblebee and serves text-embedding
  requests. When a GPU is available, it also loads Phi-3.5 for text generation.
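
  ## Example

  A hypothetical usage sketch (the model downloads on `start_link/1`, and the
  768-dimension shape assumes `bert-base-uncased`):

      {:ok, pid} = Bert.start_link([])
      embedding = Bert.embed(pid, "Hello world")
      Nx.shape(embedding)
      #=> {768}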
  """

  # Client API

  @doc """
  Starts the GenServer, which loads the BERT model and tokenizer.
  """
  def start_link(opts) do
    Logger.info("Starting GenServer")
    GenServer.start_link(__MODULE__, :ok, opts)
  end

  @doc """
  Requests the text embedding for a given string.
  """
  def embed(server_pid, text) do
    Logger.info("Calling embed")
    GenServer.call(server_pid, {:embed, text})
  end
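
  # The server also handles {:embed_all, text} and {:make_text, text} calls,
  # but defines no client wrappers for them. The two functions below are a
  # minimal sketch of the assumed convenience wrappers, not part of the
  # original API.

  @doc """
  Requests the per-token embeddings (hidden states) for a given string.
  """
  def embed_all(server_pid, text) do
    Logger.info("Calling embed_all")
    GenServer.call(server_pid, {:embed_all, text})
  end

  @doc """
  Requests generated text for a given prompt. Replies with "No gpu" when no
  GPU is available.
  """
  def make_text(server_pid, text) do
    Logger.info("Calling make_text")
    # Text generation can exceed the default 5s call timeout
    GenServer.call(server_pid, {:make_text, text}, :infinity)
  end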

  # Server Callbacks

  @impl true
  def init(:ok) do
    Logger.info("Calling init")

    gpu = has_gpu_access?()

    Logger.info("Has GPU? #{gpu}")

    # TODO: Get these working as batched runs

    # Load the model and tokenizer from the Hugging Face repository
    model_name = "google-bert/bert-base-uncased"
    {:ok, model_info} = Bumblebee.load_model({:hf, model_name}, architecture: :base)
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, model_name})

    # Serving that computes a single pooled embedding per input
    serving =
      Bumblebee.Text.TextEmbedding.text_embedding(
        model_info,
        tokenizer,
        defn_options: [compiler: EXLA]
      )

    # Serving that returns the per-token hidden states instead of the pooled embedding
    serving2 =
      Bumblebee.Text.TextEmbedding.text_embedding(
        model_info,
        tokenizer,
        defn_options: [compiler: EXLA],
        output_attribute: :hidden_state
      )

    # TODO: This is a bit much currently
    if gpu do
      repo = {:hf, "microsoft/Phi-3.5-mini-instruct"}

      {:ok, model_info} = Bumblebee.load_model(repo, backend: EXLA.Backend)
      {:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
      {:ok, generation_config} = Bumblebee.load_generation_config(repo)

      generation_config =
        Bumblebee.configure(generation_config,
          max_new_tokens: 256,
          strategy: %{type: :multinomial_sampling, top_p: 0.6}
        )

      text_gen =
        Bumblebee.Text.generation(model_info, tokenizer, generation_config,
          # compile: [batch_size: 1, sequence_length: 128_000],
          defn_options: [compiler: EXLA]
        )

      # Store the serving functions in the state
      {:ok, %{flat_serving: serving, all_serving: serving2, text_serving: text_gen}}
    else
      {:ok, %{flat_serving: serving, all_serving: serving2}}
    end
  end

  @impl true
  def handle_call({:embed, text}, _from, state) do
    Logger.info("Handling single output embedding call")
    # Run the serving function with the text input
    result = Nx.Serving.run(state.flat_serving, text)
    embedding = result.embedding
    {:reply, embedding, state}
  end

  @impl true
  def handle_call({:embed_all, text}, _from, state) do
    Logger.info("Handling all token output embedding call")
    # Run the serving function with the text input
    result = Nx.Serving.run(state.all_serving, text)
    embedding = result.embedding
    {:reply, embedding, state}
  end

  @impl true
  def handle_call({:make_text, text}, _from, state) do
    Logger.info("Handling text gen call")

    if has_gpu_access?() do
      # Run the serving function with the text input
      result = Nx.Serving.run(state.text_serving, text)
      {:reply, result, state}
    else
      {:reply, "No gpu", state}
    end
  end

  @doc """
  Example of using Axon directly to get all of the output token embeddings.
  Note that this loads the model on every call, so it is only for demonstration.
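
  Hypothetical usage (the `{seq_len, 768}` shape assumes `bert-base-uncased`,
  whose hidden size is 768):

      hidden = Bert.get_embedding("Hello world")
      Nx.shape(hidden)
      #=> {seq_len, 768}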
  """
  def get_embedding(text) do
    {:ok, model_info} =
      Bumblebee.load_model({:hf, "google-bert/bert-base-uncased"}, architecture: :base)

    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "google-bert/bert-base-uncased"})
    inputs = Bumblebee.apply_tokenizer(tokenizer, text)

    # Take the hidden states for the first (and only) input in the batch
    Axon.predict(model_info.model, model_info.params, inputs).hidden_state[0]
  end

  @doc """
  Returns whether Elixir has access to the GPU.
  """
  @spec has_gpu_access? :: boolean()
  def has_gpu_access?() do
    try do
      case Nx.tensor(0) do
        # :host == CPU
        %Nx.Tensor{data: %EXLA.Backend{buffer: %EXLA.DeviceBuffer{client_name: :host}}} ->
          false

        # :cuda == GPU
        %Nx.Tensor{data: %EXLA.Backend{buffer: %EXLA.DeviceBuffer{client_name: :cuda}}} ->
          true

        _other ->
          false
      end
    rescue
      _exception ->
        Logger.error("Error trying to determine GPU access!")
        false
    end
  end
end
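
# Example: starting Bert under a supervision tree (a hypothetical child spec;
# the registered name is an assumption, adjust to your application):
#
#     children = [
#       {Bert, name: Bert}
#     ]
#
#     Supervisor.start_link(children, strategy: :one_for_one)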