commit 1fc90a03193e95c218c4b6753f85e145a77e27ea Author: Jeffrey Ward Date: Fri Dec 5 16:59:24 2025 -0500 Initial diff --git a/README.md b/README.md new file mode 100644 index 0000000..3a6c670 --- /dev/null +++ b/README.md @@ -0,0 +1,2 @@ +# bertex + diff --git a/bert/.formatter.exs b/bert/.formatter.exs new file mode 100644 index 0000000..d2cda26 --- /dev/null +++ b/bert/.formatter.exs @@ -0,0 +1,4 @@ +# Used by "mix format" +[ + inputs: ["{mix,.formatter}.exs", "{config,lib,test}/**/*.{ex,exs}"] +] diff --git a/bert/.gitignore b/bert/.gitignore new file mode 100644 index 0000000..bd6a041 --- /dev/null +++ b/bert/.gitignore @@ -0,0 +1,23 @@ +# The directory Mix will write compiled artifacts to. +/_build/ + +# If you run "mix test --cover", coverage assets end up here. +/cover/ + +# The directory Mix downloads your dependencies sources to. +/deps/ + +# Where third-party dependencies like ExDoc output generated docs. +/doc/ + +# If the VM crashes, it generates a dump, let's ignore it too. +erl_crash.dump + +# Also ignore archive artifacts (built via "mix archive.build"). +*.ez + +# Ignore package tarball (built via "mix hex.build"). +bert-*.tar + +# Temporary files, for example, from tests. +/tmp/ diff --git a/bert/README.md b/bert/README.md new file mode 100644 index 0000000..58a1ccd --- /dev/null +++ b/bert/README.md @@ -0,0 +1,21 @@ +# Bert + +An OTP application that serves BERT text embeddings through a GenServer built on Bumblebee and EXLA. + +## Installation + +If [available in Hex](https://hex.pm/docs/publish), the package can be installed +by adding `bert` to your list of dependencies in `mix.exs`: + +```elixir +def deps do + [ + {:bert, "~> 0.1.0"} + ] +end +``` + +Documentation can be generated with [ExDoc](https://github.com/elixir-lang/ex_doc) +and published on [HexDocs](https://hexdocs.pm). Once published, the docs can +be found at https://hexdocs.pm/bert.
+
diff --git a/bert/call.exs b/bert/call.exs
new file mode 100644
index 0000000..eb2f4f7
--- /dev/null
+++ b/bert/call.exs
@@ -0,0 +1,4 @@
+#elixir --sname node1 -S mix run --no-halt
+#iex --sname node2
+Node.connect(:"node1@kittykat")
+GenServer.call({BertEmbedding, :"node1@kittykat"}, {:embed, "I'm going to embed this text"})
diff --git a/bert/config/config.exs b/bert/config/config.exs
new file mode 100644
index 0000000..8e3a070
--- /dev/null
+++ b/bert/config/config.exs
@@ -0,0 +1,3 @@
+import Config
+
+config :nx, default_backend: EXLA.Backend
diff --git a/bert/lib/bert.ex b/bert/lib/bert.ex
new file mode 100644
index 0000000..2a63f90
--- /dev/null
+++ b/bert/lib/bert.ex
@@ -0,0 +1,190 @@
+defmodule Bert do
+  @moduledoc """
+  GenServer that serves BERT text embeddings and, on a CUDA GPU, text generation.
+
+  `init/1` loads `google-bert/bert-base-uncased` and keeps two `Nx.Serving`
+  structs in the state: `:flat_serving`, used by `{:embed, text}` calls, and
+  `:all_serving`, built with `output_attribute: :hidden_state` and used by
+  `{:embed_all, text}` calls. When `has_gpu_access?/0` is true at start-up, a
+  `:text_serving` for `microsoft/Phi-3.5-mini-instruct` is added and used by
+  `{:make_text, text}` calls.
+  """
+
+  use GenServer
+  require Logger
+
+  @embedding_repo {:hf, "google-bert/bert-base-uncased"}
+  @generation_repo {:hf, "microsoft/Phi-3.5-mini-instruct"}
+
+  # Default GenServer.call/3 timeout in milliseconds.
+  @default_timeout 5_000
+
+  # Client API
+
+  @doc """
+  Starts the GenServer, which loads the BERT model and tokenizer.
+
+  `opts` is passed unchanged to `GenServer.start_link/3`, for example
+  `name: BertEmbedding`.
+  """
+  @spec start_link(GenServer.options()) :: GenServer.on_start()
+  def start_link(opts) do
+    Logger.info("Starting GenServer")
+    GenServer.start_link(__MODULE__, :ok, opts)
+  end
+
+  @doc """
+  Requests the text embedding for a given string.
+
+  `timeout` is the `GenServer.call/3` timeout in milliseconds.
+  """
+  @spec embed(GenServer.server(), String.t(), timeout()) :: Nx.Tensor.t()
+  def embed(server_pid, text, timeout \\ @default_timeout) do
+    Logger.info("Calling embed")
+    GenServer.call(server_pid, {:embed, text}, timeout)
+  end
+
+  @doc """
+  Requests the all-token (`:hidden_state`) embedding for a given string.
+  """
+  @spec embed_all(GenServer.server(), String.t(), timeout()) :: Nx.Tensor.t()
+  def embed_all(server_pid, text, timeout \\ @default_timeout) do
+    Logger.info("Calling embed_all")
+    GenServer.call(server_pid, {:embed_all, text}, timeout)
+  end
+
+  @doc """
+  Requests generated text for a prompt.
+
+  Replies with `"No gpu"` when the server holds no text generation serving.
+  """
+  @spec make_text(GenServer.server(), String.t(), timeout()) :: term()
+  def make_text(server_pid, text, timeout \\ @default_timeout) do
+    Logger.info("Calling make_text")
+    GenServer.call(server_pid, {:make_text, text}, timeout)
+  end
+
+  # Server Callbacks
+
+  @impl true
+  def init(:ok) do
+    Logger.info("Calling init")
+
+    # Select the backend when the server starts. A call in the module body runs
+    # only while the module is being compiled.
+    Nx.global_default_backend(EXLA.Backend)
+
+    gpu = has_gpu_access?()
+    Logger.info("Has GPU? #{gpu}")
+
+    # TODO: Get these working as batched runs
+    state = load_embedding_servings()
+
+    # TODO: This is a bit much currently
+    if gpu do
+      {:ok, Map.put(state, :text_serving, load_text_generation_serving())}
+    else
+      {:ok, state}
+    end
+  end
+
+  @impl true
+  def handle_call({:embed, text}, _from, state) do
+    Logger.info("Handling single output embedding call")
+    result = Nx.Serving.run(state.flat_serving, text)
+    {:reply, result.embedding, state}
+  end
+
+  def handle_call({:embed_all, text}, _from, state) do
+    Logger.info("Handling all token output embedding call")
+    result = Nx.Serving.run(state.all_serving, text)
+    {:reply, result.embedding, state}
+  end
+
+  # Dispatch on the state built by init/1 instead of probing the GPU on every
+  # call: :text_serving is present exactly when init/1 loaded it.
+  def handle_call({:make_text, text}, _from, %{text_serving: serving} = state) do
+    Logger.info("Handling text gen call")
+    {:reply, Nx.Serving.run(serving, text), state}
+  end
+
+  def handle_call({:make_text, _text}, _from, state) do
+    Logger.info("Handling text gen call")
+    {:reply, "No gpu", state}
+  end
+
+  @doc """
+  Example using Axon to get all the output embeddings.
+
+  Loads the model and tokenizer on every call, so it does not need a running
+  server.
+  """
+  @spec get_embedding(String.t()) :: Nx.Tensor.t()
+  def get_embedding(text) do
+    {:ok, model_info} = Bumblebee.load_model(@embedding_repo, architecture: :base)
+    {:ok, tokenizer} = Bumblebee.load_tokenizer(@embedding_repo)
+    inputs = Bumblebee.apply_tokenizer(tokenizer, text)
+    Axon.predict(model_info.model, model_info.params, inputs).hidden_state[0]
+  end
+
+  @doc """
+  Return if Elixir has access to the GPU or not.
+
+  Only an EXLA tensor on the `:cuda` client counts as GPU access; any other
+  backend or client, and any exception raised while checking, yields `false`.
+  """
+  @spec has_gpu_access? :: boolean()
+  def has_gpu_access? do
+    case Nx.tensor(0) do
+      # :cuda == GPU
+      %Nx.Tensor{data: %EXLA.Backend{buffer: %EXLA.DeviceBuffer{client_name: :cuda}}} ->
+        true
+
+      # :host == CPU, as well as every other backend or client
+      _other ->
+        false
+    end
+  rescue
+    exception ->
+      # Best effort: report why detection failed instead of hiding the reason.
+      Logger.error("Error trying to determine GPU access! #{Exception.message(exception)}")
+      false
+  end
+
+  # Loads BERT once and builds both embedding servings from it.
+  defp load_embedding_servings do
+    {:ok, model_info} = Bumblebee.load_model(@embedding_repo, architecture: :base)
+    {:ok, tokenizer} = Bumblebee.load_tokenizer(@embedding_repo)
+
+    # :compiler (not :compile) is the defn option that selects EXLA.
+    defn_options = [compiler: EXLA]
+
+    %{
+      flat_serving:
+        Bumblebee.Text.text_embedding(model_info, tokenizer, defn_options: defn_options),
+      all_serving:
+        Bumblebee.Text.text_embedding(model_info, tokenizer,
+          defn_options: defn_options,
+          output_attribute: :hidden_state
+        )
+    }
+  end
+
+  # Loads Phi-3.5 and builds the text generation serving.
+  defp load_text_generation_serving do
+    {:ok, model_info} = Bumblebee.load_model(@generation_repo, backend: EXLA.Backend)
+    {:ok, tokenizer} = Bumblebee.load_tokenizer(@generation_repo)
+    {:ok, generation_config} = Bumblebee.load_generation_config(@generation_repo)
+
+    generation_config =
+      Bumblebee.configure(generation_config,
+        max_new_tokens: 256,
+        strategy: %{type: :multinomial_sampling, top_p: 0.6}
+      )
+
+    Bumblebee.Text.generation(model_info, tokenizer, generation_config,
+      # compile: [batch_size: 1, sequence_length: 128_000],
+      defn_options: [compiler: EXLA]
+    )
+  end
+end
diff --git a/bert/lib/bert/application.ex b/bert/lib/bert/application.ex
new file mode 100644
index 0000000..cbaa716
--- /dev/null
+++ b/bert/lib/bert/application.ex
@@ -0,0 +1,20 @@
+defmodule Bert.Application do
+  # See https://hexdocs.pm/elixir/Application.html
+  # for more information on OTP Applications
+  @moduledoc false
+
+  use Application
+
+  @impl true
+  def start(_type, _args) do
+    children = [
+      # Starts the Bert GenServer by calling: Bert.start_link(name: BertEmbedding)
+      {Bert, name: BertEmbedding}
+    ]
+
+    # See https://hexdocs.pm/elixir/Supervisor.html
+    # for other strategies and
supported options + opts = [strategy: :one_for_one, name: Bert.Supervisor] + Supervisor.start_link(children, opts) + end +end diff --git a/bert/mix.exs b/bert/mix.exs new file mode 100644 index 0000000..4a2e637 --- /dev/null +++ b/bert/mix.exs @@ -0,0 +1,31 @@ +defmodule Bert.MixProject do + use Mix.Project + + def project do + [ + app: :bert, + version: "0.1.0", + elixir: "~> 1.18", + start_permanent: Mix.env() == :prod, + deps: deps() + ] + end + + # OTP application spec: boots Bert.Application and the :logger extra application. Run "mix help compile.app" for details. + def application do + [ + extra_applications: [:logger], + mod: {Bert.Application, []} + ] + end + + # Project dependencies. Run "mix help deps" for the requirement syntax. + defp deps do + [ + {:bumblebee, "~> 0.6.0"}, + {:exla, ">= 0.0.0"} + # NOTE(review): :exla carries no upper bound; the exact release is pinned by mix.lock. + # Presumably left open so it tracks the :nx release Bumblebee resolves to - confirm. + ] + end +end diff --git a/bert/mix.lock b/bert/mix.lock new file mode 100644 index 0000000..e841041 --- /dev/null +++ b/bert/mix.lock @@ -0,0 +1,24 @@ +%{ + "axon": {:hex, :axon, "0.7.0", "2e2c6d93b4afcfa812566b8922204fa022b60081e86ebd411df4db7ea30f5457", [:mix], [{:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}, {:kino_vega_lite, "~> 0.1.7", [hex: :kino_vega_lite, repo: "hexpm", optional: true]}, {:nx, "~> 0.9", [hex: :nx, repo: "hexpm", optional: false]}, {:polaris, "~> 0.1", [hex: :polaris, repo: "hexpm", optional: false]}, {:table_rex, "~> 3.1.1", [hex: :table_rex, repo: "hexpm", optional: true]}], "hexpm", "ee9857a143c9486597ceff434e6ca833dc1241be6158b01025b8217757ed1036"}, + "bumblebee": {:hex, :bumblebee, "0.6.3", "c0028643c92de93258a9804da1d4d48797eaf7911b702464b3b3dd2cc7f938f1", [:mix], [{:axon, "~> 0.7.0", [hex: :axon, repo: "hexpm", optional: false]}, {:jason, "~> 1.4.0", [hex: :jason, repo: "hexpm", optional: false]}, {:nx, "~> 0.9.0 or ~> 0.10.0", [hex: :nx, repo: "hexpm", optional: false]}, {:nx_image, "~> 0.1.0", [hex: :nx_image, repo: "hexpm", optional: false]}, {:nx_signal, "~> 0.2.0", [hex:
:nx_signal, repo: "hexpm", optional: false]}, {:progress_bar, "~> 3.0", [hex: :progress_bar, repo: "hexpm", optional: false]}, {:safetensors, "~> 0.1.3", [hex: :safetensors, repo: "hexpm", optional: false]}, {:tokenizers, "~> 0.4", [hex: :tokenizers, repo: "hexpm", optional: false]}, {:unpickler, "~> 0.1.0", [hex: :unpickler, repo: "hexpm", optional: false]}, {:unzip, "~> 0.12.0", [hex: :unzip, repo: "hexpm", optional: false]}], "hexpm", "c619197787561f8e5fb2ffba269c341654accaec9d591999b7fddd55761dd079"}, + "castore": {:hex, :castore, "1.0.15", "8aa930c890fe18b6fe0a0cff27b27d0d4d231867897bd23ea772dee561f032a3", [:mix], [], "hexpm", "96ce4c69d7d5d7a0761420ef743e2f4096253931a3ba69e5ff8ef1844fe446d3"}, + "complex": {:hex, :complex, "0.6.0", "b0130086a7a8c33574d293b2e0e250f4685580418eac52a5658a4bd148f3ccf1", [:mix], [], "hexpm", "0a5fa95580dcaf30fcd60fe1aaf24327c0fe401e98c24d892e172e79498269f9"}, + "decimal": {:hex, :decimal, "2.3.0", "3ad6255aa77b4a3c4f818171b12d237500e63525c2fd056699967a3e7ea20f62", [:mix], [], "hexpm", "a4d66355cb29cb47c3cf30e71329e58361cfcb37c34235ef3bf1d7bf3773aeac"}, + "elixir_make": {:hex, :elixir_make, "0.9.0", "6484b3cd8c0cee58f09f05ecaf1a140a8c97670671a6a0e7ab4dc326c3109726", [:mix], [], "hexpm", "db23d4fd8b757462ad02f8aa73431a426fe6671c80b200d9710caf3d1dd0ffdb"}, + "exla": {:hex, :exla, "0.10.0", "93e7d75a774fbc06ce05b96de20c4b01bda413b315238cb3c727c09a05d2bc3a", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:fine, "~> 0.1.0", [hex: :fine, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:nx, "~> 0.10.0", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.9.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "16fffdb64667d7f0a3bc683fdcd2792b143a9b345e4b1f1d5cd50330c63d8119"}, + "fine": {:hex, :fine, "0.1.4", 
"b19a89c1476c7c57afb5f9314aed5960b5bc95d5277de4cb5ee8e1d1616ce379", [:mix], [], "hexpm", "be3324cc454a42d80951cf6023b9954e9ff27c6daa255483b3e8d608670303f5"}, + "jason": {:hex, :jason, "1.4.4", "b9226785a9aa77b6857ca22832cffa5d5011a667207eb2a0ad56adb5db443b8a", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "c5eb0cab91f094599f94d55bc63409236a8ec69a21a67814529e8d5f6cc90b3b"}, + "nimble_pool": {:hex, :nimble_pool, "1.1.0", "bf9c29fbdcba3564a8b800d1eeb5a3c58f36e1e11d7b7fb2e084a643f645f06b", [:mix], [], "hexpm", "af2e4e6b34197db81f7aad230c1118eac993acc0dae6bc83bac0126d4ae0813a"}, + "nx": {:hex, :nx, "0.10.0", "128e4a094cb790f663e20e1334b127c1f2a4df54edfb8b13c22757ec33133b4f", [:mix], [{:complex, "~> 0.6", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "3db8892c124aeee091df0e6fbf8e5bf1b81f502eb0d4f5ba63e6378ebcae7da4"}, + "nx_image": {:hex, :nx_image, "0.1.2", "0c6e3453c1dc30fc80c723a54861204304cebc8a89ed3b806b972c73ee5d119d", [:mix], [{:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "9161863c42405ddccb6dbbbeae078ad23e30201509cc804b3b3a7c9e98764b81"}, + "nx_signal": {:hex, :nx_signal, "0.2.0", "e1ca0318877b17c81ce8906329f5125f1e2361e4c4235a5baac8a95ee88ea98e", [:mix], [{:nx, "~> 0.6", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "7247e5e18a177a59c4cb5355952900c62fdeadeb2bad02a9a34237b68744e2bb"}, + "polaris": {:hex, :polaris, "0.1.0", "dca61b18e3e801ecdae6ac9f0eca5f19792b44a5cb4b8d63db50fc40fc038d22", [:mix], [{:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "13ef2b166650e533cb24b10e2f3b8ab4f2f449ba4d63156e8c569527f206e2c2"}, + "progress_bar": {:hex, :progress_bar, "3.0.0", "f54ff038c2ac540cfbb4c2bfe97c75e7116ead044f3c2b10c9f212452194b5cd", [:mix], [{:decimal, "~> 2.0", [hex: :decimal, repo: "hexpm", optional: false]}], "hexpm", 
"6981c2b25ab24aecc91a2dc46623658e1399c21a2ae24db986b90d678530f2b7"}, + "rustler_precompiled": {:hex, :rustler_precompiled, "0.8.3", "4e741024b0b097fe783add06e53ae9a6f23ddc78df1010f215df0c02915ef5a8", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "c23f5f33cb6608542de4d04faf0f0291458c352a4648e4d28d17ee1098cddcc4"}, + "safetensors": {:hex, :safetensors, "0.1.3", "7ff3c22391e213289c713898481d492c9c28a49ab1d0705b72630fb8360426b2", [:mix], [{:jason, "~> 1.4", [hex: :jason, repo: "hexpm", optional: false]}, {:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "fe50b53ea59fde4e723dd1a2e31cfdc6013e69343afac84c6be86d6d7c562c14"}, + "telemetry": {:hex, :telemetry, "1.3.0", "fedebbae410d715cf8e7062c96a1ef32ec22e764197f70cda73d82778d61e7a2", [:rebar3], [], "hexpm", "7015fc8919dbe63764f4b4b87a95b7c0996bd539e0d499be6ec9d7f3875b79e6"}, + "tokenizers": {:hex, :tokenizers, "0.5.1", "b0975d92b4ee5b18e8f47b5d65b9d5f1e583d9130189b1a2620401af4e7d4b35", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, ">= 0.0.0", [hex: :rustler, repo: "hexpm", optional: true]}, {:rustler_precompiled, "~> 0.6", [hex: :rustler_precompiled, repo: "hexpm", optional: false]}], "hexpm", "5f08d97cc7f2ed3d71d370d68120da6d3de010948ccf676c9c0eb591ba4bacc9"}, + "unpickler": {:hex, :unpickler, "0.1.0", "c2262c0819e6985b761e7107546cef96a485f401816be5304a65fdd200d5bd6a", [:mix], [], "hexpm", "e2b3f61e62406187ac52afead8a63bfb4e49394028993f3c4c42712743cab79e"}, + "unzip": {:hex, :unzip, "0.12.0", "beed92238724732418b41eba77dcb7f51e235b707406c05b1732a3052d1c0f36", [:mix], [], "hexpm", "95655b72db368e5a84951f0bed586ac053b55ee3815fd96062fce10ce4fc998d"}, + "xla": {:hex, :xla, "0.9.1", "cca0040ff94902764007a118871bfc667f1a0085d4a5074533a47d6b58bec61e", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", 
optional: false]}], "hexpm", "eb5e443ae5391b1953f253e051f2307bea183b59acee138053a9300779930daf"},
+}
diff --git a/bert/start.txt b/bert/start.txt
new file mode 100644
index 0000000..b34ed6b
--- /dev/null
+++ b/bert/start.txt
@@ -0,0 +1 @@
+elixir --sname gpu -S mix run --no-halt
diff --git a/bert/test/bert_test.exs b/bert/test/bert_test.exs
new file mode 100644
index 0000000..ebf1d95
--- /dev/null
+++ b/bert/test/bert_test.exs
@@ -0,0 +1,10 @@
+defmodule BertTest do
+  use ExUnit.Case
+  doctest Bert
+
+  # The generated placeholder called Bert.hello/0, which Bert never defines.
+  # Exercise a public function that exists and needs no server process.
+  test "has_gpu_access?/0 returns a boolean" do
+    assert is_boolean(Bert.has_gpu_access?())
+  end
+end
diff --git a/bert/test/test_helper.exs b/bert/test/test_helper.exs
new file mode 100644
index 0000000..869559e
--- /dev/null
+++ b/bert/test/test_helper.exs
@@ -0,0 +1 @@
+ExUnit.start()