This commit is contained in:
2025-12-05 16:59:24 -05:00
commit 1fc90a0319
13 changed files with 298 additions and 0 deletions

2
README.md Normal file
View File

@@ -0,0 +1,2 @@
# bertex

4
bert/.formatter.exs Normal file
View File

@@ -0,0 +1,4 @@
# Used by "mix format"
[
inputs: ["{mix,.formatter}.exs", "{config,lib,test}/**/*.{ex,exs}"]
]

23
bert/.gitignore vendored Normal file
View File

@@ -0,0 +1,23 @@
# The directory Mix will write compiled artifacts to.
/_build/
# If you run "mix test --cover", coverage assets end up here.
/cover/
# The directory Mix downloads your dependencies sources to.
/deps/
# Where third-party dependencies like ExDoc output generated docs.
/doc/
# If the VM crashes, it generates a dump, let's ignore it too.
erl_crash.dump
# Also ignore archive artifacts (built via "mix archive.build").
*.ez
# Ignore package tarball (built via "mix hex.build").
bert-*.tar
# Temporary files, for example, from tests.
/tmp/

21
bert/README.md Normal file
View File

@@ -0,0 +1,21 @@
# Bert
**TODO: Add description**
## Installation
If [available in Hex](https://hex.pm/docs/publish), the package can be installed
by adding `bert` to your list of dependencies in `mix.exs`:
```elixir
def deps do
[
{:bert, "~> 0.1.0"}
]
end
```
Documentation can be generated with [ExDoc](https://github.com/elixir-lang/ex_doc)
and published on [HexDocs](https://hexdocs.pm). Once published, the docs can
be found at <https://hexdocs.pm/bert>.

4
bert/call.exs Normal file
View File

@@ -0,0 +1,4 @@
#elixir --sname node1 -S mix run --no-halt
#iex --sname node2
Node.connect(:"node1@kittykat")
GenServer.call({BertEmbedding, :"node1@kittykat"}, {:embed, "I'm going to embed this text"})

3
bert/config/config.exs Normal file
View File

@@ -0,0 +1,3 @@
import Config
config :nx, default_backend: EXLA.Backend

156
bert/lib/bert.ex Normal file
View File

@@ -0,0 +1,156 @@
defmodule Bert do
use GenServer
require Logger
Nx.global_default_backend(EXLA.Backend)
@moduledoc """
Documentation for `Bert`.
"""
# Client API
@doc """
Starts the GenServer, which loads the BERT model and tokenizer.
"""
def start_link(opts) do
Logger.info("Starting GenServer")
GenServer.start_link(__MODULE__, :ok, opts)
end
@doc """
Requests the text embedding for a given string.
"""
def embed(server_pid, text) do
Logger.info("Calling embed")
GenServer.call(server_pid, {:embed, text})
end
# Server Callbacks
@impl true
def init(:ok) do
Logger.info("Calling init")
gpu = has_gpu_access?()
Logger.info("Has GPU? #{gpu}")
# TODO: Get these working as batched runs
# Load the model and tokenizer from the Hugging Face repository
model_name = "google-bert/bert-base-uncased"
{:ok, model_info} = Bumblebee.load_model({:hf, model_name}, architecture: :base)
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, model_name})
# Create the serving function for computing embeddings
serving =
Bumblebee.Text.TextEmbedding.text_embedding(
model_info,
tokenizer,
defn_options: [compile: EXLA]
)
# Create the serving function for computing embeddings
serving2 =
Bumblebee.Text.TextEmbedding.text_embedding(
model_info,
tokenizer,
defn_options: [compile: EXLA],
output_attribute: :hidden_state
)
# TODO: This is a bit much currently
if gpu do
repo = {:hf, "microsoft/Phi-3.5-mini-instruct"}
{:ok, model_info} = Bumblebee.load_model(repo, backend: EXLA.Backend)
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
{:ok, generation_config} = Bumblebee.load_generation_config(repo)
generation_config =
Bumblebee.configure(generation_config,
max_new_tokens: 256,
strategy: %{type: :multinomial_sampling, top_p: 0.6}
)
text_gen =
Bumblebee.Text.generation(model_info, tokenizer, generation_config,
#compile: [batch_size: 1, sequence_length: 128_000],
defn_options: [compiler: EXLA]
)
# Store the serving function in the state
{:ok, %{flat_serving: serving, all_serving: serving2, text_serving: text_gen}}
else
{:ok, %{flat_serving: serving, all_serving: serving2}}
end
end
@impl true
def handle_call({:embed, text}, _from, state) do
Logger.info("Handling single output embedding call")
# Run the serving function with the text input
result = Nx.Serving.run(state.flat_serving, text)
embedding = result.embedding
{:reply, embedding, state}
end
@impl true
def handle_call({:embed_all, text}, _from, state) do
Logger.info("Handling all token output embedding call")
# Run the serving function with the text input
result = Nx.Serving.run(state.all_serving, text)
embedding = result.embedding
{:reply, embedding, state}
end
@impl true
def handle_call({:make_text, text}, _from, state) do
Logger.info("Handling text gen call")
if has_gpu_access?() == false do
{:reply, "No gpu", state}
else
# Run the serving function with the text input
result = Nx.Serving.run(state.text_serving, text)
{:reply, result, state}
end
end
@doc """
Example using Axon to get all the output embeddings
"""
def get_embedding(text) do
{:ok, model_info} =
Bumblebee.load_model({:hf, "google-bert/bert-base-uncased"}, architecture: :base)
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "google-bert/bert-base-uncased"})
inputs = Bumblebee.apply_tokenizer(tokenizer, text)
outputs = Axon.predict(model_info.model, model_info.params, inputs).hidden_state[0]
outputs
end
@doc """
Return if Elixir has access to the GPU or not.
"""
@spec has_gpu_access? :: boolean()
def has_gpu_access?() do
try do
case Nx.tensor(0) do
# :host == CPU
%Nx.Tensor{data: %EXLA.Backend{buffer: %EXLA.DeviceBuffer{client_name: :host}}} ->
false
# :cuda == GPU
%Nx.Tensor{data: %EXLA.Backend{buffer: %EXLA.DeviceBuffer{client_name: :cuda}}} ->
true
_other ->
false
end
rescue
_exception ->
Logger.error("Error trying to determine GPU access!")
false
end
end
end

View File

@@ -0,0 +1,20 @@
defmodule Bert.Application do
# See https://hexdocs.pm/elixir/Application.html
# for more information on OTP Applications
@moduledoc false
use Application
@impl true
def start(_type, _args) do
children = [
# Starts a worker by calling: Bert.Worker.start_link(arg)
{Bert, name: BertEmbedding}
]
# See https://hexdocs.pm/elixir/Supervisor.html
# for other strategies and supported options
opts = [strategy: :one_for_one, name: Bert.Supervisor]
Supervisor.start_link(children, opts)
end
end

31
bert/mix.exs Normal file
View File

@@ -0,0 +1,31 @@
defmodule Bert.MixProject do
use Mix.Project
def project do
[
app: :bert,
version: "0.1.0",
elixir: "~> 1.18",
start_permanent: Mix.env() == :prod,
deps: deps()
]
end
# Run "mix help compile.app" to learn about applications.
def application do
[
extra_applications: [:logger],
mod: {Bert.Application, []}
]
end
# Run "mix help deps" to learn about dependencies.
defp deps do
[
{:bumblebee, "~> 0.6.0"},
{:exla, ">= 0.0.0"}
# {:dep_from_hexpm, "~> 0.3.0"},
# {:dep_from_git, git: "https://github.com/elixir-lang/my_dep.git", tag: "0.1.0"}
]
end
end

24
bert/mix.lock Normal file
View File

@@ -0,0 +1,24 @@
%{
"axon": {:hex, :axon, "0.7.0", "2e2c6d93b4afcfa812566b8922204fa022b60081e86ebd411df4db7ea30f5457", [:mix], [{:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}, {:kino_vega_lite, "~> 0.1.7", [hex: :kino_vega_lite, repo: "hexpm", optional: true]}, {:nx, "~> 0.9", [hex: :nx, repo: "hexpm", optional: false]}, {:polaris, "~> 0.1", [hex: :polaris, repo: "hexpm", optional: false]}, {:table_rex, "~> 3.1.1", [hex: :table_rex, repo: "hexpm", optional: true]}], "hexpm", "ee9857a143c9486597ceff434e6ca833dc1241be6158b01025b8217757ed1036"},
"bumblebee": {:hex, :bumblebee, "0.6.3", "c0028643c92de93258a9804da1d4d48797eaf7911b702464b3b3dd2cc7f938f1", [:mix], [{:axon, "~> 0.7.0", [hex: :axon, repo: "hexpm", optional: false]}, {:jason, "~> 1.4.0", [hex: :jason, repo: "hexpm", optional: false]}, {:nx, "~> 0.9.0 or ~> 0.10.0", [hex: :nx, repo: "hexpm", optional: false]}, {:nx_image, "~> 0.1.0", [hex: :nx_image, repo: "hexpm", optional: false]}, {:nx_signal, "~> 0.2.0", [hex: :nx_signal, repo: "hexpm", optional: false]}, {:progress_bar, "~> 3.0", [hex: :progress_bar, repo: "hexpm", optional: false]}, {:safetensors, "~> 0.1.3", [hex: :safetensors, repo: "hexpm", optional: false]}, {:tokenizers, "~> 0.4", [hex: :tokenizers, repo: "hexpm", optional: false]}, {:unpickler, "~> 0.1.0", [hex: :unpickler, repo: "hexpm", optional: false]}, {:unzip, "~> 0.12.0", [hex: :unzip, repo: "hexpm", optional: false]}], "hexpm", "c619197787561f8e5fb2ffba269c341654accaec9d591999b7fddd55761dd079"},
"castore": {:hex, :castore, "1.0.15", "8aa930c890fe18b6fe0a0cff27b27d0d4d231867897bd23ea772dee561f032a3", [:mix], [], "hexpm", "96ce4c69d7d5d7a0761420ef743e2f4096253931a3ba69e5ff8ef1844fe446d3"},
"complex": {:hex, :complex, "0.6.0", "b0130086a7a8c33574d293b2e0e250f4685580418eac52a5658a4bd148f3ccf1", [:mix], [], "hexpm", "0a5fa95580dcaf30fcd60fe1aaf24327c0fe401e98c24d892e172e79498269f9"},
"decimal": {:hex, :decimal, "2.3.0", "3ad6255aa77b4a3c4f818171b12d237500e63525c2fd056699967a3e7ea20f62", [:mix], [], "hexpm", "a4d66355cb29cb47c3cf30e71329e58361cfcb37c34235ef3bf1d7bf3773aeac"},
"elixir_make": {:hex, :elixir_make, "0.9.0", "6484b3cd8c0cee58f09f05ecaf1a140a8c97670671a6a0e7ab4dc326c3109726", [:mix], [], "hexpm", "db23d4fd8b757462ad02f8aa73431a426fe6671c80b200d9710caf3d1dd0ffdb"},
"exla": {:hex, :exla, "0.10.0", "93e7d75a774fbc06ce05b96de20c4b01bda413b315238cb3c727c09a05d2bc3a", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:fine, "~> 0.1.0", [hex: :fine, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:nx, "~> 0.10.0", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.9.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "16fffdb64667d7f0a3bc683fdcd2792b143a9b345e4b1f1d5cd50330c63d8119"},
"fine": {:hex, :fine, "0.1.4", "b19a89c1476c7c57afb5f9314aed5960b5bc95d5277de4cb5ee8e1d1616ce379", [:mix], [], "hexpm", "be3324cc454a42d80951cf6023b9954e9ff27c6daa255483b3e8d608670303f5"},
"jason": {:hex, :jason, "1.4.4", "b9226785a9aa77b6857ca22832cffa5d5011a667207eb2a0ad56adb5db443b8a", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "c5eb0cab91f094599f94d55bc63409236a8ec69a21a67814529e8d5f6cc90b3b"},
"nimble_pool": {:hex, :nimble_pool, "1.1.0", "bf9c29fbdcba3564a8b800d1eeb5a3c58f36e1e11d7b7fb2e084a643f645f06b", [:mix], [], "hexpm", "af2e4e6b34197db81f7aad230c1118eac993acc0dae6bc83bac0126d4ae0813a"},
"nx": {:hex, :nx, "0.10.0", "128e4a094cb790f663e20e1334b127c1f2a4df54edfb8b13c22757ec33133b4f", [:mix], [{:complex, "~> 0.6", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "3db8892c124aeee091df0e6fbf8e5bf1b81f502eb0d4f5ba63e6378ebcae7da4"},
"nx_image": {:hex, :nx_image, "0.1.2", "0c6e3453c1dc30fc80c723a54861204304cebc8a89ed3b806b972c73ee5d119d", [:mix], [{:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "9161863c42405ddccb6dbbbeae078ad23e30201509cc804b3b3a7c9e98764b81"},
"nx_signal": {:hex, :nx_signal, "0.2.0", "e1ca0318877b17c81ce8906329f5125f1e2361e4c4235a5baac8a95ee88ea98e", [:mix], [{:nx, "~> 0.6", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "7247e5e18a177a59c4cb5355952900c62fdeadeb2bad02a9a34237b68744e2bb"},
"polaris": {:hex, :polaris, "0.1.0", "dca61b18e3e801ecdae6ac9f0eca5f19792b44a5cb4b8d63db50fc40fc038d22", [:mix], [{:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "13ef2b166650e533cb24b10e2f3b8ab4f2f449ba4d63156e8c569527f206e2c2"},
"progress_bar": {:hex, :progress_bar, "3.0.0", "f54ff038c2ac540cfbb4c2bfe97c75e7116ead044f3c2b10c9f212452194b5cd", [:mix], [{:decimal, "~> 2.0", [hex: :decimal, repo: "hexpm", optional: false]}], "hexpm", "6981c2b25ab24aecc91a2dc46623658e1399c21a2ae24db986b90d678530f2b7"},
"rustler_precompiled": {:hex, :rustler_precompiled, "0.8.3", "4e741024b0b097fe783add06e53ae9a6f23ddc78df1010f215df0c02915ef5a8", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "c23f5f33cb6608542de4d04faf0f0291458c352a4648e4d28d17ee1098cddcc4"},
"safetensors": {:hex, :safetensors, "0.1.3", "7ff3c22391e213289c713898481d492c9c28a49ab1d0705b72630fb8360426b2", [:mix], [{:jason, "~> 1.4", [hex: :jason, repo: "hexpm", optional: false]}, {:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "fe50b53ea59fde4e723dd1a2e31cfdc6013e69343afac84c6be86d6d7c562c14"},
"telemetry": {:hex, :telemetry, "1.3.0", "fedebbae410d715cf8e7062c96a1ef32ec22e764197f70cda73d82778d61e7a2", [:rebar3], [], "hexpm", "7015fc8919dbe63764f4b4b87a95b7c0996bd539e0d499be6ec9d7f3875b79e6"},
"tokenizers": {:hex, :tokenizers, "0.5.1", "b0975d92b4ee5b18e8f47b5d65b9d5f1e583d9130189b1a2620401af4e7d4b35", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, ">= 0.0.0", [hex: :rustler, repo: "hexpm", optional: true]}, {:rustler_precompiled, "~> 0.6", [hex: :rustler_precompiled, repo: "hexpm", optional: false]}], "hexpm", "5f08d97cc7f2ed3d71d370d68120da6d3de010948ccf676c9c0eb591ba4bacc9"},
"unpickler": {:hex, :unpickler, "0.1.0", "c2262c0819e6985b761e7107546cef96a485f401816be5304a65fdd200d5bd6a", [:mix], [], "hexpm", "e2b3f61e62406187ac52afead8a63bfb4e49394028993f3c4c42712743cab79e"},
"unzip": {:hex, :unzip, "0.12.0", "beed92238724732418b41eba77dcb7f51e235b707406c05b1732a3052d1c0f36", [:mix], [], "hexpm", "95655b72db368e5a84951f0bed586ac053b55ee3815fd96062fce10ce4fc998d"},
"xla": {:hex, :xla, "0.9.1", "cca0040ff94902764007a118871bfc667f1a0085d4a5074533a47d6b58bec61e", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "eb5e443ae5391b1953f253e051f2307bea183b59acee138053a9300779930daf"},
}

1
bert/start.txt Normal file
View File

@@ -0,0 +1 @@
elixir --sname gpu -S mix run --no-halt

8
bert/test/bert_test.exs Normal file
View File

@@ -0,0 +1,8 @@
defmodule BertTest do
use ExUnit.Case
doctest Bert
test "greets the world" do
assert Bert.hello() == :world
end
end

View File

@@ -0,0 +1 @@
ExUnit.start()