Initial

bert/lib/bert.ex (new file, 156 lines)
@@ -0,0 +1,156 @@
defmodule Bert do
  @moduledoc """
  A GenServer that wraps Bumblebee servings for BERT text embeddings,
  with optional Phi-3.5 text generation when a GPU is available.
  """

  use GenServer
  require Logger

  Nx.global_default_backend(EXLA.Backend)

  # Client API

  @doc """
  Starts the GenServer, which loads the BERT model and tokenizer.
  """
  def start_link(opts) do
    Logger.info("Starting GenServer")
    GenServer.start_link(__MODULE__, :ok, opts)
  end
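
  # A minimal usage sketch (illustrative; the registered name `Bert` is an
  # assumption, not something this module sets up itself):
  #
  #     children = [{Bert, name: Bert}]
  #     Supervisor.start_link(children, strategy: :one_for_one)
  #     Bert.embed(Bert, "Hello world")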

  @doc """
  Requests the text embedding for a given string.
  """
  def embed(server_pid, text) do
    Logger.info("Calling embed")
    GenServer.call(server_pid, {:embed, text})
  end
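
  # The server handles {:embed_all, text} and {:make_text, text} below but
  # defines no client wrappers for them; these two sketches are assumptions
  # added for symmetry with embed/2.

  @doc """
  Requests the per-token embeddings (full hidden state) for a given string.
  """
  def embed_all(server_pid, text) do
    GenServer.call(server_pid, {:embed_all, text})
  end

  @doc """
  Requests generated text for a given prompt. Generation can easily exceed
  the default 5-second call timeout, hence `:infinity`.
  """
  def make_text(server_pid, text) do
    GenServer.call(server_pid, {:make_text, text}, :infinity)
  end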

  # Server Callbacks

  @impl true
  def init(:ok) do
    Logger.info("Calling init")

    gpu = has_gpu_access?()

    Logger.info("Has GPU? #{gpu}")

    # TODO: Get these working as batched runs

    # Load the model and tokenizer from the Hugging Face repository
    model_name = "google-bert/bert-base-uncased"
    {:ok, model_info} = Bumblebee.load_model({:hf, model_name}, architecture: :base)
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, model_name})

    # Create the serving function for computing pooled sentence embeddings
    serving =
      Bumblebee.Text.TextEmbedding.text_embedding(
        model_info,
        tokenizer,
        defn_options: [compiler: EXLA]
      )

    # Create a second serving that returns the full hidden state,
    # i.e. one embedding per input token rather than a pooled vector
    serving2 =
      Bumblebee.Text.TextEmbedding.text_embedding(
        model_info,
        tokenizer,
        defn_options: [compiler: EXLA],
        output_attribute: :hidden_state
      )

    # TODO: This is a bit much currently
    if gpu do
      repo = {:hf, "microsoft/Phi-3.5-mini-instruct"}

      {:ok, model_info} = Bumblebee.load_model(repo, backend: EXLA.Backend)
      {:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
      {:ok, generation_config} = Bumblebee.load_generation_config(repo)

      generation_config =
        Bumblebee.configure(generation_config,
          max_new_tokens: 256,
          strategy: %{type: :multinomial_sampling, top_p: 0.6}
        )

      text_gen =
        Bumblebee.Text.generation(model_info, tokenizer, generation_config,
          # compile: [batch_size: 1, sequence_length: 128_000],
          defn_options: [compiler: EXLA]
        )

      # Store the serving functions in the state
      {:ok, %{flat_serving: serving, all_serving: serving2, text_serving: text_gen}}
    else
      {:ok, %{flat_serving: serving, all_serving: serving2}}
    end
  end

  @impl true
  def handle_call({:embed, text}, _from, state) do
    Logger.info("Handling single output embedding call")
    # Run the serving function with the text input
    result = Nx.Serving.run(state.flat_serving, text)
    embedding = result.embedding
    {:reply, embedding, state}
  end

  @impl true
  def handle_call({:embed_all, text}, _from, state) do
    Logger.info("Handling all token output embedding call")
    # Run the serving function with the text input
    result = Nx.Serving.run(state.all_serving, text)
    embedding = result.embedding
    {:reply, embedding, state}
  end

  @impl true
  def handle_call({:make_text, text}, _from, state) do
    Logger.info("Handling text gen call")

    # state.text_serving is only present when init/1 detected a GPU
    if has_gpu_access?() do
      # Run the serving function with the text input
      result = Nx.Serving.run(state.text_serving, text)
      {:reply, result, state}
    else
      {:reply, "No gpu", state}
    end
  end

  @doc """
  Example using Axon directly to get all of the output token embeddings.
  """
  def get_embedding(text) do
    {:ok, model_info} =
      Bumblebee.load_model({:hf, "google-bert/bert-base-uncased"}, architecture: :base)

    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "google-bert/bert-base-uncased"})
    inputs = Bumblebee.apply_tokenizer(tokenizer, text)

    # Run the model and return the hidden state for the first (only) batch entry
    Axon.predict(model_info.model, model_info.params, inputs).hidden_state[0]
  end
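
  # A sketch of collapsing the per-token output of get_embedding/1 into a
  # single sentence vector via mean pooling. This helper is an assumption,
  # not part of the original API, and it ignores the attention mask, so
  # padding tokens (if any) are averaged in as well.
  def mean_pool(token_embeddings) do
    # {sequence_length, hidden_size} -> {hidden_size}
    Nx.mean(token_embeddings, axes: [0])
  end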

  @doc """
  Returns whether Elixir has access to the GPU.
  """
  @spec has_gpu_access?() :: boolean()
  def has_gpu_access?() do
    try do
      case Nx.tensor(0) do
        # :host == CPU
        %Nx.Tensor{data: %EXLA.Backend{buffer: %EXLA.DeviceBuffer{client_name: :host}}} ->
          false

        # :cuda == GPU
        %Nx.Tensor{data: %EXLA.Backend{buffer: %EXLA.DeviceBuffer{client_name: :cuda}}} ->
          true

        _other ->
          false
      end
    rescue
      _exception ->
        Logger.error("Error trying to determine GPU access!")
        false
    end
  end
end